From 0b64ce83c8755367cb109eaf92ea49c87e5e541b Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 13 Jan 2026 16:58:17 -0800 Subject: [PATCH 01/92] skeleton of inference moe layer done --- .../text_generation_controller.py | 7 + megatron/core/models/gpt/moe_module_specs.py | 33 +++- .../core/models/mamba/mamba_layer_specs.py | 14 +- megatron/core/transformer/attention.py | 115 ++++++++---- .../transformer/moe/moe_layer_inference.py | 176 ++++++++++++++++++ megatron/training/arguments.py | 13 ++ 6 files changed, 318 insertions(+), 40 deletions(-) create mode 100644 megatron/core/transformer/moe/moe_layer_inference.py diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index f757d4b539d..035cd0ee899 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -534,6 +534,7 @@ def _dynamic_step_context_init( if model_config.transformer_impl == "inference_optimized": context.maybe_initialize_symmetric_memory() + if nccl_all_reduce_for_prefill and symmetric_ar_type is not None: if context.is_decode_only(): # Turn on symmetric all reduce when in decode mode @@ -778,6 +779,12 @@ def dummy_forward(self): context = self.inference_wrapped_model.inference_context # if no cuda graphs, directly use dummy forward if not context.cuda_graph_batch_dimensions_list: + # initialize symmetric memory if needed + unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + model_config = get_model_config(unwrapped_model) + if model_config.transformer_impl == "inference_optimized": + context.maybe_initialize_symmetric_memory() + return self.inference_wrapped_model.dummy_forward() # attempt to use cuda-graph if possible diff --git a/megatron/core/models/gpt/moe_module_specs.py b/megatron/core/models/gpt/moe_module_specs.py index 
87e4091aece..8f0e9efd3ef 100755 --- a/megatron/core/models/gpt/moe_module_specs.py +++ b/megatron/core/models/gpt/moe_module_specs.py @@ -15,8 +15,17 @@ def get_moe_module_spec( num_experts: Optional[int] = None, moe_grouped_gemm: Optional[bool] = False, moe_use_legacy_grouped_gemm: Optional[bool] = False, + inference_optimized: bool = False, ) -> ModuleSpec: - """Helper function to get module spec for MoE""" + """Helper function to get module spec for MoE + + Args: + use_te: Whether to use Transformer Engine. + num_experts: Number of experts. + moe_grouped_gemm: Whether to use grouped GEMM. + moe_use_legacy_grouped_gemm: Whether to use legacy grouped GEMM. + inference_optimized: If True, use InferenceMoELayer for optimized inference. + """ if use_te is not None and use_te: backend: BackendSpecProvider = TESpecProvider() else: @@ -26,6 +35,7 @@ def get_moe_module_spec( num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, + inference_optimized=inference_optimized, ) @@ -35,8 +45,18 @@ def get_moe_module_spec_for_backend( moe_grouped_gemm: Optional[bool] = False, moe_use_legacy_grouped_gemm: Optional[bool] = False, use_te_activation_func: bool = False, + inference_optimized: bool = False, ) -> ModuleSpec: - """Helper function to get module spec for MoE""" + """Helper function to get module spec for MoE + + Args: + backend: Backend spec provider (TE or Local). + num_experts: Number of experts. + moe_grouped_gemm: Whether to use grouped GEMM. + moe_use_legacy_grouped_gemm: Whether to use legacy grouped GEMM. + use_te_activation_func: Whether to use TE activation function. + inference_optimized: If True, use InferenceMoELayer for optimized inference. 
+ """ assert num_experts is not None linear_fc1 = backend.column_parallel_linear() @@ -59,8 +79,15 @@ def get_moe_module_spec_for_backend( # shared experts spec shared_experts = ModuleSpec(module=SharedExpertMLP, submodules=mlp) + # Select MoE layer class based on inference_optimized flag + if inference_optimized: + from megatron.core.transformer.moe.moe_layer_inference import InferenceMoELayer + moe_layer_class = InferenceMoELayer + else: + moe_layer_class = MoELayer + # MoE module spec moe_module_spec = ModuleSpec( - module=MoELayer, submodules=MoESubmodules(experts=experts, shared_experts=shared_experts) + module=moe_layer_class, submodules=MoESubmodules(experts=experts, shared_experts=shared_experts) ) return moe_module_spec diff --git a/megatron/core/models/mamba/mamba_layer_specs.py b/megatron/core/models/mamba/mamba_layer_specs.py index f83275ed9c6..e5d5ecffb33 100755 --- a/megatron/core/models/mamba/mamba_layer_specs.py +++ b/megatron/core/models/mamba/mamba_layer_specs.py @@ -22,6 +22,7 @@ from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules +# Standard MoE spec (for training) moe = get_moe_module_spec( use_te=True, num_experts=8, # Can be any positive integer (must not be None). @@ -29,6 +30,15 @@ moe_use_legacy_grouped_gemm=False, ) +# Inference-optimized MoE spec +moe_inference = get_moe_module_spec( + use_te=True, + num_experts=8, # Can be any positive integer (must not be None). + moe_grouped_gemm=True, + moe_use_legacy_grouped_gemm=False, + inference_optimized=True, +) + mamba_stack_spec = ModuleSpec( module=MambaStack, submodules=MambaStackSubmodules( @@ -138,10 +148,10 @@ ), ), moe_layer=ModuleSpec( - # TODO (rwaleffe): change this to be an "MoELayer" to work with CudaGraphs? 
+ # Use inference-optimized MoE layer for better CUDA graph support module=TransformerLayer, submodules=TransformerLayerSubmodules( - pre_mlp_layernorm=TENorm, mlp=moe, mlp_bda=get_bias_dropout_add + pre_mlp_layernorm=TENorm, mlp=moe_inference, mlp_bda=get_bias_dropout_add ), ), ), diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 8265ee83ff5..086df1fe3d9 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -1,5 +1,6 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. import copy +import inspect from abc import ABC, abstractmethod from dataclasses import dataclass from typing import NoReturn, Optional, Tuple, Union @@ -560,6 +561,74 @@ def flash_decode( ) return out + def _flash_attention_3_forward_wrapper( + self, + q: Tensor, + k: Tensor, + v: Tensor, + max_seqlen_q, + max_seqlen_k, + cu_seqlens_q, + seqlens_k, + block_table, + softmax_scale, + ): + """ + Wrapper for calling the FA3 _flash_attn_forward function. + Handles argument conversion for different versions of the _flash_attn_forward API. 
+ """ + kwargs = { + "q": q, + "k": k, + "v": v, + "k_new": None, + "v_new": None, + "qv": None, + "cu_seqlens_q": cu_seqlens_q, + "cu_seqlens_k": None, + "cu_seqlens_k_new": None, + "seqused_q": None, + "seqused_k": seqlens_k, + "max_seqlen_q": max_seqlen_q, + "max_seqlen_k": max_seqlen_k, + "page_table": block_table, + "kv_batch_idx": None, + "leftpad_k": None, + "rotary_cos": None, + "rotary_sin": None, + "seqlens_rotary": None, + "q_descale": None, + "k_descale": None, + "v_descale": None, + "softmax_scale": softmax_scale, + "causal": True, + "attention_chunk": 0, + "softcap": 0.0, + "rotary_interleaved": True, + "scheduler_metadata": None, + "num_splits": 0 if not self.batch_invariant_mode else 1, + "pack_gqa": None, + "sm_margin": 0, + } + + schema = _flash_attn_forward._schema + if "out=" in schema: + kwargs["out"] = None + else: + assert "out_=" in schema + kwargs["out_"] = None + + if "window_size=" in schema: + kwargs["window_size"] = (-1, -1) + else: + assert "window_size_left=" in schema and "window_size_right=" in schema + kwargs["window_size_left"] = -1 + kwargs["window_size_right"] = -1 + + output_total, *unused = _flash_attn_forward(**kwargs) + + return output_total + def flash_decode_and_prefill( self, q: Tensor, @@ -601,40 +670,16 @@ def flash_decode_and_prefill( if HAVE_FA3: # TODO(ksanthanam): Replace with call to flash_attn_varlen_func once # it accepts block_table - output_total, *unused = _flash_attn_forward( - q=q, - k=k, - v=v, - k_new=None, - v_new=None, - qv=None, - out=None, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=None, - cu_seqlens_k_new=None, - seqused_q=None, - seqused_k=seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - page_table=block_table, - kv_batch_idx=None, - leftpad_k=None, - rotary_cos=None, - rotary_sin=None, - seqlens_rotary=None, - q_descale=None, - k_descale=None, - v_descale=None, - softmax_scale=softmax_scale, - causal=True, - window_size=(-1, -1), - attention_chunk=0, - softcap=0.0, - 
rotary_interleaved=True, - scheduler_metadata=None, - num_splits=0 if not self.batch_invariant_mode else 1, - pack_gqa=None, - sm_margin=0, + output_total = self._flash_attention_3_forward_wrapper( + q, + k, + v, + max_seqlen_q, + max_seqlen_k, + cu_seqlens_q, + seqlens_k, + block_table, + softmax_scale, ) else: assert ( @@ -1496,4 +1541,4 @@ def get_query_key_value_tensors( ) query = query.view(*new_tensor_shape) - return query, key, value + return query, key, value \ No newline at end of file diff --git a/megatron/core/transformer/moe/moe_layer_inference.py b/megatron/core/transformer/moe/moe_layer_inference.py new file mode 100644 index 00000000000..8cbb2c1d4f7 --- /dev/null +++ b/megatron/core/transformer/moe/moe_layer_inference.py @@ -0,0 +1,176 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +""" +Inference-optimized MoE Layer with AlltoAll Token Dispatcher. + +This implementation inherits from MoELayer to ensure checkpoint compatibility, +while providing a simplified forward pass optimized for inference: +1. Strips out training-specific code (aux losses, recomputation, backward) +2. Uses a simple, linear forward flow +3. 
Is designed to be CUDA graph compatible (future work) + +The forward pass follows this flow: + Input [S, B, H] + ↓ Route (router gate → topk selection) + probs, routing_map + ↓ Permute (group tokens by expert) + permuted_tokens [num_selected_tokens, H] + ↓ EP AlltoAll (distribute to expert owners) + global_tokens [tokens_on_this_rank, H] + ↓ TP AllGather (if tp_size > 1) + gathered_tokens + ↓ Sort by local expert (if num_local_experts > 1) + sorted_tokens + ↓ Expert FFN (GroupedGEMM) + expert_output + ↓ Unsort by local expert + unsorted_output + ↓ TP ReduceScatter (if tp_size > 1) + scattered_output + ↓ EP AlltoAll (return to original ranks) + combined_output + ↓ Unpermute (restore original order) + Output [S, B, H] + +Usage: + # Load a trained MoELayer checkpoint directly: + inference_layer = InferenceMoELayer(config, submodules, layer_number, pg_collection) + inference_layer.load_state_dict(trained_moe_layer.state_dict()) +""" + +from typing import Optional + +import torch + +from megatron.core import utils +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules +from megatron.core.transformer.transformer_config import TransformerConfig + + +class InferenceMoELayer(MoELayer): + """ + Inference-optimized MoE layer that inherits from MoELayer for checkpoint compatibility. + + This implementation: + - Inherits all weights/submodules from MoELayer (router, experts, token_dispatcher) + - Provides a simplified forward() optimized for inference + - Removes training overhead (aux losses, recomputation, gradient computation) + - Only supports AlltoAll dispatcher (most common for inference) + + Checkpoints trained with MoELayer can be loaded directly. 
+ """ + + def __init__( + self, + config: TransformerConfig, + submodules: Optional[MoESubmodules] = None, + layer_number: Optional[int] = None, + pg_collection: Optional[ProcessGroupCollection] = None, + ): + """ + Initialize the inference MoE layer. + + Args are identical to MoELayer for checkpoint compatibility. + """ + # Initialize parent MoELayer (creates router, experts, token_dispatcher) + super().__init__( + config=config, + submodules=submodules, + layer_number=layer_number, + pg_collection=pg_collection, + ) + + # Validate dispatcher type + if config.moe_token_dispatcher_type != "alltoall": + raise ValueError( + f"InferenceMoELayer only supports 'alltoall' dispatcher, " + f"got '{config.moe_token_dispatcher_type}'" + ) + + # Cache frequently used values + self.hidden_size = config.hidden_size + self.topk = config.moe_router_topk + + # Get process group info from token_dispatcher + self.ep_size = self.token_dispatcher.ep_size + self.ep_rank = utils.get_pg_rank(self.token_dispatcher.ep_group) + self.tp_size = self.token_dispatcher.tp_size + self.tp_rank = self.token_dispatcher.tp_rank + + # Precompute sort indices for multi-expert case + if self.num_local_experts > 1: + input_chunk_idxs = torch.arange( + self.config.num_moe_experts * self.tp_size, device="cuda" + ) + self.sort_input_by_local_experts = input_chunk_idxs.reshape( + -1, self.num_local_experts + ).T.ravel() + self.restore_output_by_local_experts = input_chunk_idxs.reshape( + self.num_local_experts, -1 + ).T.ravel() + + # ==================== Simplified Forward Pass ==================== + def forward(self, hidden_states: torch.Tensor): + """ + Simplified forward pass optimized for inference. 
+ + This overrides MoELayer.forward() with a streamlined version that: + - Removes training overhead (aux losses, recomputation) + - Uses a linear, easy-to-follow flow + - Reuses inherited router, token_dispatcher, and experts + + Args: + hidden_states: [S, B, H] input tensor + + Returns: + Tuple of (output, None) for compatibility with MoELayer interface + """ + print("USED INFERENCE MOE LAYER....") + # Store original shape for restoration + hidden_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_shape[-1]) + num_tokens = hidden_states.shape[0] + + # ===== Step 1: Routing (using inherited router) ===== + # The router returns probs and routing_map + probs, routing_map = self.router(hidden_states) + + # ===== Step 2: Dispatch Preprocess ===== + # Compute metadata and permute tokens by expert assignment + permuted_tokens, permuted_probs = self.token_dispatcher.dispatch_preprocess( + hidden_states, routing_map, probs + ) + + # ===== Step 3: Token Dispatch (EP AlltoAll) ===== + dispatched_tokens, dispatched_probs = self.token_dispatcher.token_dispatch( + permuted_tokens, permuted_probs + ) + + # ===== Step 4: Dispatch Postprocess (TP AllGather + sort by expert) ===== + expert_input, tokens_per_expert, expert_probs = self.token_dispatcher.dispatch_postprocess( + dispatched_tokens, dispatched_probs + ) + + # ===== Step 5: Expert Computation (using inherited experts) ===== + expert_output, mlp_bias = self.experts(expert_input, tokens_per_expert, expert_probs) + + # ===== Step 6: Combine Preprocess (unsort + TP ReduceScatter) ===== + combine_input = self.token_dispatcher.combine_preprocess(expert_output) + + # ===== Step 7: Token Combine (EP AlltoAll reverse) ===== + combined_output = self.token_dispatcher.token_combine(combine_input) + + # ===== Step 8: Combine Postprocess (unpermute to original order) ===== + output = self.token_dispatcher.combine_postprocess(combined_output) + + # Restore original shape + output = output.view(hidden_shape) + + 
# Handle shared experts (if configured, computed separately) + if self.use_shared_expert and not self.shared_expert_overlap: + shared_output = self.shared_experts(hidden_states.view(hidden_shape)) + output = output + shared_output + + return output, mlp_bias + diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 138711506f3..34fd23cfb23 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1199,6 +1199,19 @@ def validate_args(args, defaults={}): assert args.moe_pad_experts_for_cuda_graph_inference, \ "--moe-pad-experts-for-cuda-graph-inference must be set when using CUDA graphs with expert parallelism" + # Temporary restrictions for inference_optimized MoE with expert parallelism + # until CUDA graph support is properly implemented + if args.transformer_impl == "inference_optimized" and args.expert_model_parallel_size > 1: + assert args.cuda_graph_impl == "none", ( + "CUDA graphs are not yet supported with --transformer-impl inference_optimized " + "and expert parallelism. Please set --cuda-graph-impl none." + ) + assert not args.moe_pad_experts_for_cuda_graph_inference, ( + "--moe-pad-experts-for-cuda-graph-inference is not yet supported with " + "--transformer-impl inference_optimized. This will be enabled once CUDA graph " + "support is properly implemented." + ) + # MoE upcycling check if args.moe_use_upcycling: assert args.save is not None, "When using upcycling, the --save option must be specified." 
From da29281e53ad8173c634bd6381ba5c2cd0d3b73f Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 23 Jan 2026 08:39:00 -0800 Subject: [PATCH 02/92] restore --- megatron/training/arguments.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 34fd23cfb23..138711506f3 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1199,19 +1199,6 @@ def validate_args(args, defaults={}): assert args.moe_pad_experts_for_cuda_graph_inference, \ "--moe-pad-experts-for-cuda-graph-inference must be set when using CUDA graphs with expert parallelism" - # Temporary restrictions for inference_optimized MoE with expert parallelism - # until CUDA graph support is properly implemented - if args.transformer_impl == "inference_optimized" and args.expert_model_parallel_size > 1: - assert args.cuda_graph_impl == "none", ( - "CUDA graphs are not yet supported with --transformer-impl inference_optimized " - "and expert parallelism. Please set --cuda-graph-impl none." - ) - assert not args.moe_pad_experts_for_cuda_graph_inference, ( - "--moe-pad-experts-for-cuda-graph-inference is not yet supported with " - "--transformer-impl inference_optimized. This will be enabled once CUDA graph " - "support is properly implemented." - ) - # MoE upcycling check if args.moe_use_upcycling: assert args.save is not None, "When using upcycling, the --save option must be specified." 
From 6e01116ed95d2273d11a1c0c67195d54584fea1d Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 23 Jan 2026 08:54:44 -0800 Subject: [PATCH 03/92] match argument signature with training --- megatron/core/models/gpt/moe_module_specs.py | 2 +- .../core/transformer/moe/moe_layer_inference.py | 14 ++++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/megatron/core/models/gpt/moe_module_specs.py b/megatron/core/models/gpt/moe_module_specs.py index 9d98ab11eb7..648c51d03c1 100755 --- a/megatron/core/models/gpt/moe_module_specs.py +++ b/megatron/core/models/gpt/moe_module_specs.py @@ -88,7 +88,7 @@ def get_moe_module_spec_for_backend( # MoE module spec moe_module_spec = ModuleSpec( - module=moe_layer_class + module=moe_layer_class, submodules=MoESubmodules(experts=experts, shared_experts=shared_experts), metainfo={"fuse_pre_mlp_layernorm": False}, ) diff --git a/megatron/core/transformer/moe/moe_layer_inference.py b/megatron/core/transformer/moe/moe_layer_inference.py index 8cbb2c1d4f7..810939431a9 100644 --- a/megatron/core/transformer/moe/moe_layer_inference.py +++ b/megatron/core/transformer/moe/moe_layer_inference.py @@ -36,6 +36,9 @@ # Load a trained MoELayer checkpoint directly: inference_layer = InferenceMoELayer(config, submodules, layer_number, pg_collection) inference_layer.load_state_dict(trained_moe_layer.state_dict()) + +TODO: Add unit test to verify that InferenceMoELayer.forward() and MoELayer.forward() + have aligned argument signatures (use inspect.signature to compare). """ from typing import Optional @@ -111,7 +114,7 @@ def __init__( ).T.ravel() # ==================== Simplified Forward Pass ==================== - def forward(self, hidden_states: torch.Tensor): + def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): """ Simplified forward pass optimized for inference. 
@@ -122,19 +125,22 @@ def forward(self, hidden_states: torch.Tensor): Args: hidden_states: [S, B, H] input tensor + padding_mask: Optional [B, S] boolean mask. True for valid tokens, False for padding. Returns: Tuple of (output, None) for compatibility with MoELayer interface """ - print("USED INFERENCE MOE LAYER....") # Store original shape for restoration hidden_shape = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_shape[-1]) - num_tokens = hidden_states.shape[0] + + # Transpose padding_mask from [bsz, seq_length] to [seq_length, bsz] to align with hidden_states + if padding_mask is not None: + padding_mask = padding_mask.transpose(0, 1).bool() # ===== Step 1: Routing (using inherited router) ===== # The router returns probs and routing_map - probs, routing_map = self.router(hidden_states) + probs, routing_map = self.router(hidden_states, padding_mask) # ===== Step 2: Dispatch Preprocess ===== # Compute metadata and permute tokens by expert assignment From 153265b42387da35256aafbbab6a6c65d662c948 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Sun, 25 Jan 2026 16:47:57 -0800 Subject: [PATCH 04/92] support gpt models like qwen --- gpt_builders.py | 2 +- megatron/core/models/backends.py | 43 +++++++++++++++++-- megatron/core/models/gpt/gpt_layer_specs.py | 28 ++++++++++-- megatron/core/models/gpt/moe_module_specs.py | 14 +++--- .../core/models/mamba/mamba_layer_specs.py | 2 +- 5 files changed, 75 insertions(+), 14 deletions(-) diff --git a/gpt_builders.py b/gpt_builders.py index dfe41f7b88e..bb273211080 100644 --- a/gpt_builders.py +++ b/gpt_builders.py @@ -53,7 +53,7 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_ ) ) elif args.num_experts: - assert not (config.transformer_impl == "inference_optimized") + #assert not (config.transformer_impl == "inference_optimized") # Define the decoder block spec transformer_layer_spec = get_gpt_decoder_block_spec( config, diff --git a/megatron/core/models/backends.py 
b/megatron/core/models/backends.py index 7f84599a04c..94bfde2cd4a 100644 --- a/megatron/core/models/backends.py +++ b/megatron/core/models/backends.py @@ -9,6 +9,18 @@ from megatron.core.transformer.mlp import MLPSubmodules from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP from megatron.core.transformer.torch_norm import WrappedTorchNorm +from megatron.core.extensions.transformer_engine import ( + TEActivationOp, + TEColumnParallelGroupedLinear, + TEColumnParallelLinear, + TEDotProductAttention, + TELinear, + TENorm, + TERowParallelGroupedLinear, + TERowParallelLinear, +) +from megatron.core.utils import get_te_version, is_te_min_version +from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP try: import apex # pylint: disable=unused-import @@ -177,6 +189,31 @@ def activation_func(self) -> type: def grouped_mlp_modules( self, moe_use_grouped_gemm: bool, moe_use_legacy_grouped_gemm: bool ) -> Tuple[type, Optional[MLPSubmodules]]: - raise NotImplementedError( - "MOE is not supported with inference optimized transformer implementation." - ) + """Which module and submodules to use for grouped mlp""" + if ( + moe_use_grouped_gemm + and TEColumnParallelGroupedLinear is not None + and not moe_use_legacy_grouped_gemm + ): + return TEGroupedMLP, MLPSubmodules( + linear_fc1=TEColumnParallelGroupedLinear, linear_fc2=TERowParallelGroupedLinear + ) + elif moe_use_grouped_gemm: + warnings.warn( + 'The legacy GroupedMLP will be deprecated in Megatron-Core v0.12.0. ' + 'Please update the TransformerEngine to version>=1.7.0 and use TEGroupedMLP.' + ) + return GroupedMLP, None + else: + if not is_te_min_version("1.7.0.dev0"): + warnings.warn( + "Only transformer-engine>=1.7.0 supports MoE experts, " + f"but your version is {get_te_version()}. " + "Use local linear implementation instead." 
+ ) + return SequentialMLP, MLPSubmodules( + linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear + ) + return SequentialMLP, MLPSubmodules( + linear_fc1=TEColumnParallelLinear, linear_fc2=TERowParallelLinear + ) diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index 974e33f88e8..a895d5534bb 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -79,6 +79,9 @@ def get_gpt_layer_with_inference_spec( qk_layernorm: Optional[bool] = False, multi_latent_attention: Optional[bool] = False, qk_l2_norm: Optional[bool] = False, + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + moe_use_legacy_grouped_gemm: Optional[bool] = False, ) -> ModuleSpec: """Use this spec to use inference optimized linear layers. Args: @@ -91,9 +94,9 @@ def get_gpt_layer_with_inference_spec( mlp = get_mlp_module_spec_for_backend( backend=backend, - num_experts=None, - moe_grouped_gemm=False, - moe_use_legacy_grouped_gemm=False, + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, use_te_op_fuser=False, use_te_activation_func=False, ) @@ -156,7 +159,7 @@ def get_gpt_layer_with_inference_spec( ), ), self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=IdentityOp, + pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, mlp=mlp, mlp_bda=get_bias_dropout_add, sharded_state_dict_keys_map={ @@ -552,6 +555,21 @@ def get_gpt_decoder_layer_specs( use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, ) + elif config.transformer_impl == "inference_optimized": + layer_norm_impl = TENorm + dense_layer_spec = get_gpt_layer_with_inference_spec( + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + qk_l2_norm=qk_l2_norm, + ) + moe_layer_spec = get_gpt_layer_with_inference_spec( + 
qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + qk_l2_norm=qk_l2_norm, + num_experts=config.num_moe_experts, + moe_grouped_gemm=config.moe_grouped_gemm, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm + ) else: layer_norm_impl = LNImpl dense_layer_spec = get_gpt_layer_local_spec( @@ -643,6 +661,8 @@ def get_gpt_decoder_block_spec( if use_transformer_engine: layer_norm_impl = TENorm + elif config.transformer_impl == "inference_optimized": + layer_norm_impl = TENorm else: layer_norm_impl = LNImpl # Block spec. diff --git a/megatron/core/models/gpt/moe_module_specs.py b/megatron/core/models/gpt/moe_module_specs.py index 648c51d03c1..7f9fc211552 100755 --- a/megatron/core/models/gpt/moe_module_specs.py +++ b/megatron/core/models/gpt/moe_module_specs.py @@ -3,9 +3,10 @@ from typing import Optional from megatron.core.extensions.transformer_engine_spec_provider import TESpecProvider -from megatron.core.models.backends import BackendSpecProvider, LocalSpecProvider +from megatron.core.models.backends import BackendSpecProvider, LocalSpecProvider, InferenceSpecProvider from megatron.core.transformer.mlp import MLPSubmodules from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules +from megatron.core.transformer.moe.moe_layer_inference import InferenceMoELayer from megatron.core.transformer.moe.shared_experts import SharedExpertMLP from megatron.core.transformer.spec_utils import ModuleSpec @@ -26,8 +27,13 @@ def get_moe_module_spec( moe_use_legacy_grouped_gemm: Whether to use legacy grouped GEMM. inference_optimized: If True, use InferenceMoELayer for optimized inference. 
""" + # This function is called my mamba_layer_specs.py + # The GPT layer specs directly calls get_moe_module_spec_for_backend + if use_te is not None and use_te: backend: BackendSpecProvider = TESpecProvider() + elif inference_optimized: + backend = InferenceSpecProvider() else: backend = LocalSpecProvider() return get_moe_module_spec_for_backend( @@ -35,7 +41,6 @@ def get_moe_module_spec( num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, - inference_optimized=inference_optimized, ) @@ -45,7 +50,6 @@ def get_moe_module_spec_for_backend( moe_grouped_gemm: Optional[bool] = False, moe_use_legacy_grouped_gemm: Optional[bool] = False, use_te_activation_func: bool = False, - inference_optimized: bool = False, ) -> ModuleSpec: """Helper function to get module spec for MoE @@ -58,7 +62,8 @@ def get_moe_module_spec_for_backend( inference_optimized: If True, use InferenceMoELayer for optimized inference. """ assert num_experts is not None - + inference_optimized: bool = isinstance(backend, InferenceSpecProvider) + linear_fc1 = backend.column_parallel_linear() linear_fc2 = backend.row_parallel_linear() activation_func = backend.activation_func() @@ -81,7 +86,6 @@ def get_moe_module_spec_for_backend( # Select MoE layer class based on inference_optimized flag if inference_optimized: - from megatron.core.transformer.moe.moe_layer_inference import InferenceMoELayer moe_layer_class = InferenceMoELayer else: moe_layer_class = MoELayer diff --git a/megatron/core/models/mamba/mamba_layer_specs.py b/megatron/core/models/mamba/mamba_layer_specs.py index e5d5ecffb33..8cd588840a9 100755 --- a/megatron/core/models/mamba/mamba_layer_specs.py +++ b/megatron/core/models/mamba/mamba_layer_specs.py @@ -32,7 +32,7 @@ # Inference-optimized MoE spec moe_inference = get_moe_module_spec( - use_te=True, + use_te=False, num_experts=8, # Can be any positive integer (must not be None). 
moe_grouped_gemm=True, moe_use_legacy_grouped_gemm=False, From 7915cff7a8aca5cc67ffb4134b63e93cbed3e222 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 27 Jan 2026 16:27:29 -0800 Subject: [PATCH 05/92] make torch grouped gemm work --- megatron/core/models/backends.py | 4 +- megatron/core/transformer/moe/experts.py | 149 ++++++++++++++++++ .../core/transformer/moe/gpu_resident_ops.py | 110 +++++++++++++ .../transformer/moe/moe_layer_inference.py | 16 +- .../moe/token_dispatcher_inference.py | 139 ++++++++++++++++ 5 files changed, 415 insertions(+), 3 deletions(-) create mode 100644 megatron/core/transformer/moe/gpu_resident_ops.py create mode 100644 megatron/core/transformer/moe/token_dispatcher_inference.py diff --git a/megatron/core/models/backends.py b/megatron/core/models/backends.py index 94bfde2cd4a..4d0eb1f75ed 100644 --- a/megatron/core/models/backends.py +++ b/megatron/core/models/backends.py @@ -20,7 +20,7 @@ TERowParallelLinear, ) from megatron.core.utils import get_te_version, is_te_min_version -from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP +from megatron.core.transformer.moe.experts import GroupedMLP, InferenceGroupedMLP, SequentialMLP, TEGroupedMLP try: import apex # pylint: disable=unused-import @@ -195,7 +195,7 @@ def grouped_mlp_modules( and TEColumnParallelGroupedLinear is not None and not moe_use_legacy_grouped_gemm ): - return TEGroupedMLP, MLPSubmodules( + return InferenceGroupedMLP, MLPSubmodules( linear_fc1=TEColumnParallelGroupedLinear, linear_fc2=TERowParallelGroupedLinear ) elif moe_use_grouped_gemm: diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 62fb7a148c8..a886e91abce 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -817,6 +817,155 @@ def backward_dw(self): self.linear_fc1.backward_dw() +class InferenceGroupedMLP(TEGroupedMLP): + """Inference-optimized GroupedMLP using 
torch._grouped_mm with GPU-resident offsets. + + Inherits from TEGroupedMLP to reuse weight initialization and checkpoint compatibility. + Overrides forward() to use torch._grouped_mm instead of TE's grouped linear, + keeping tokens_per_expert on GPU to avoid host synchronization. + """ + + def __init__( + self, + num_local_experts: int, + config: TransformerConfig, + submodules: MLPSubmodules, + pg_collection: Optional[ProcessGroupCollection] = None, + ): + # Initialize parent TEGroupedMLP (creates linear_fc1, linear_fc2) + super().__init__( + num_local_experts=num_local_experts, + config=config, + submodules=submodules, + pg_collection=pg_collection, + ) + + # Concatenate TE's per-expert weights into single tensors for torch._grouped_mm + # TE GroupedLinear stores weights as weight0, weight1, ..., weight{n-1} + # torch._grouped_mm expects shape [num_experts, out_features, in_features] + self._build_concatenated_weights() + + # Register hook to rebuild concatenated weights after load_state_dict + # self._register_load_state_dict_post_hook(self._rebuild_weights_hook) + + # Set up activation function for inference (simplified, no recompute) + if self.config.gated_linear_unit: + @jit_fuser + def glu(x): + x = torch.chunk(x, 2, dim=-1) + return self.config.activation_func(x[0]) * x[1] + self._inference_activation_func = glu + else: + self._inference_activation_func = self.config.activation_func + + @jit_fuser + def activation_func_with_probs(x, probs): + dtype = x.dtype + res = self._inference_activation_func(x) * probs + return res.to(dtype) + + self._activation_func_with_probs = activation_func_with_probs + + def _build_concatenated_weights(self): + """Create big contiguous weight tensors with per-expert views for checkpoint compatibility. + + Creates _fc1_weight and _fc2_weight as contiguous tensors of shape + [num_experts, out_features, in_features]. Replaces TE's individual weight{i} + parameters with views into these tensors. 
+ + This allows: + - load_state_dict to load into weight{i} views -> writes into big tensor + - forward() to use big tensor directly with torch._grouped_mm + - No post-load hooks needed + """ + # Get device/dtype from existing TE weights + device = self.linear_fc1.weight0.device + dtype = self.linear_fc1.weight0.dtype + + fc1_shape = self.linear_fc1.weight0.shape # [out_features, in_features] + fc2_shape = self.linear_fc2.weight0.shape + + # Create big contiguous tensors + _fc1_weight = torch.empty( + self.num_local_experts, *fc1_shape, device=device, dtype=dtype + ) + _fc2_weight = torch.empty( + self.num_local_experts, *fc2_shape, device=device, dtype=dtype + ) + + # Copy existing TE weights into big tensors, then replace with views + for i in range(self.num_local_experts): + # Copy initialized data + _fc1_weight[i].copy_(getattr(self.linear_fc1, f'weight{i}').data) + _fc2_weight[i].copy_(getattr(self.linear_fc2, f'weight{i}').data) + + # Delete TE's original parameters + delattr(self.linear_fc1, f'weight{i}') + delattr(self.linear_fc2, f'weight{i}') + + # Register views as parameters (checkpoint loads will write into big tensor) + self.linear_fc1.register_parameter( + f'weight{i}', torch.nn.Parameter(_fc1_weight[i]) + ) + self.linear_fc2.register_parameter( + f'weight{i}', torch.nn.Parameter(_fc2_weight[i]) + ) + + # Register big tensors as non-persistent buffers (for .to() device movement, not saved) + self.register_buffer('_fc1_weight', _fc1_weight, persistent=False) + self.register_buffer('_fc2_weight', _fc2_weight, persistent=False) + + def forward( + self, + permuted_local_hidden_states: torch.Tensor, + tokens_per_expert: torch.Tensor, + permuted_probs: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Forward pass using torch._grouped_mm with GPU-resident offsets. 
+ + Args: + permuted_local_hidden_states: [total_tokens, hidden_size] input tensor + tokens_per_expert: [num_local_experts] GPU tensor with token counts per expert + permuted_probs: [total_tokens] routing probabilities + + Returns: + Tuple of (output, None) for interface compatibility + """ + permuted_probs = permuted_probs.unsqueeze(-1) + assert tokens_per_expert.is_cuda, "tokens_per_expert must be on GPU" + + if self.config.moe_apply_probs_on_input: + assert ( + self.config.moe_router_topk == 1 + ), "`moe_apply_probs_on_input` only works with `moe_router_topk`=1." + original_dtype = permuted_local_hidden_states.dtype + permuted_local_hidden_states = permuted_probs * permuted_local_hidden_states + permuted_local_hidden_states = permuted_local_hidden_states.to(original_dtype) + permuted_probs = torch.ones_like(permuted_probs) + + if permuted_local_hidden_states.nelement() != 0: + # Use pre-concatenated weights (built during init/load) + # _fc1_weight shape: [num_experts, ffn_hidden * (2 if gated else 1), hidden_size] + # _fc2_weight shape: [num_experts, hidden_size, ffn_hidden] + # Compute cumulative offsets on GPU (no host sync!) 
+ # offs[i] = end index of expert i's tokens + offs = tokens_per_expert.cumsum(0).to(torch.int32) + + # FC1: [total_tokens, hidden] @ [num_experts, ffn_hidden, hidden] -> [total_tokens, ffn_hidden] + fc1_output = torch._grouped_mm(permuted_local_hidden_states, self._fc1_weight.transpose(1, 2), offs=offs) + + # Activation with routing probabilities + intermediate_parallel = self._activation_func_with_probs(fc1_output, permuted_probs) + + # FC2: [total_tokens, ffn_hidden] @ [num_experts, hidden, ffn_hidden] -> [total_tokens, hidden] + fc2_output = torch._grouped_mm(intermediate_parallel, self._fc2_weight.transpose(1, 2), offs=offs) + else: + # No tokens allocated - return empty tensor with correct shape + fc2_output = permuted_local_hidden_states + + return fc2_output, None + + class SequentialMLP(MegatronModule): """An implementation of the Experts layer using a sequence of MLP layers. diff --git a/megatron/core/transformer/moe/gpu_resident_ops.py b/megatron/core/transformer/moe/gpu_resident_ops.py new file mode 100644 index 00000000000..a137a47c7e0 --- /dev/null +++ b/megatron/core/transformer/moe/gpu_resident_ops.py @@ -0,0 +1,110 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +""" +GPU-resident operations for CUDA-graph compatible MoE inference. + +This module provides GPU-resident implementations of AlltoAll and GroupedGEMM +operations that accept device tensors for split sizes, eliminating host +synchronization points required for CUDA graph compatibility. +""" + +from typing import Optional + +import torch + + +def gpu_resident_all_to_all( + process_group, + input_tensor: torch.Tensor, + output_split_sizes: torch.Tensor, + input_split_sizes: torch.Tensor, +) -> torch.Tensor: + """ + GPU-resident AlltoAll that accepts device tensors for split sizes. 
+ + This function eliminates the host synchronization bottleneck present in + the standard torch.distributed.all_to_all by accepting split sizes as + GPU tensors instead of CPU lists. + + Args: + process_group: The process group for communication + input_tensor: [sum(input_split_sizes), ...] tensor to send + output_split_sizes: [world_size] GPU tensor - number of elements to receive from each rank + input_split_sizes: [world_size] GPU tensor - number of elements to send to each rank + + Returns: + output_tensor: [sum(output_split_sizes), ...] received tensor + + Example: + >>> # Instead of CPU lists: + >>> # output_splits = [100, 200, 150] # CPU list + >>> # input_splits = [80, 120, 200] # CPU list + >>> # output = all_to_all(group, input, output_splits, input_splits) + >>> + >>> # Use GPU tensors: + >>> output_splits = torch.tensor([100, 200, 150], device='cuda') # GPU + >>> input_splits = torch.tensor([80, 120, 200], device='cuda') # GPU + >>> output = gpu_resident_all_to_all(group, input, output_splits, input_splits) + + Implementation notes: + - This is a placeholder for your GPU-resident AlltoAll implementation + - The actual implementation should avoid any .item(), .tolist(), or .cpu() calls + - Split sizes must remain on GPU throughout the operation + - Should support CUDA graph capture + """ + # TODO: Replace with actual GPU-resident AlltoAll implementation + # For now, this is a placeholder showing the expected interface + raise NotImplementedError( + "gpu_resident_all_to_all requires a custom implementation. " + "This placeholder shows the expected API: accepts GPU tensors for split sizes." + ) + + +def gpu_resident_grouped_gemm( + input: torch.Tensor, + weights: torch.Tensor, + tokens_per_expert: torch.Tensor, + use_fp8: bool = False, +) -> torch.Tensor: + """ + GPU-resident GroupedGEMM that accepts device tensor for expert splits. 
+ + This function provides a CUDA-graph compatible grouped GEMM by accepting + tokens_per_expert as a GPU tensor and computing offsets on-device. + + Args: + input: [total_tokens, K] input tensor + weights: [num_experts, K, N] or [num_experts*K, N] weight tensor + tokens_per_expert: [num_experts] GPU tensor - token count per expert + use_fp8: Whether to use FP8 computation (if available) + + Returns: + output: [total_tokens, N] output tensor + + Example: + >>> # Instead of CPU tokens_per_expert: + >>> # tokens_per_expert_cpu = tokens_per_expert.cpu() # Sync! + >>> # offs = tokens_per_expert_cpu.cumsum(0).cuda() # Another sync! + >>> # output = torch._grouped_mm(input, weights, offs=offs) + >>> + >>> # Use GPU-resident version: + >>> output = gpu_resident_grouped_gemm(input, weights, tokens_per_expert) + + Implementation notes: + - This is a placeholder for your GPU-resident GroupedGEMM implementation + - Should compute cumsum(tokens_per_expert) on GPU without host sync + - Must keep all tensors GPU-resident throughout + - Should support CUDA graph capture + - Can wrap torch._grouped_mm or use custom kernel + """ + # TODO: Replace with actual GPU-resident GroupedGEMM implementation + # For now, this is a placeholder showing the expected interface + + # Example of what the implementation might look like: + # offs = tokens_per_expert.cumsum(0).to(torch.int32) # No .cuda() needed! + # return torch._grouped_mm(input, weights, offs=offs) + + raise NotImplementedError( + "gpu_resident_grouped_gemm requires a custom implementation. " + "This placeholder shows the expected API: accepts GPU tensor for tokens_per_expert." 
+ ) diff --git a/megatron/core/transformer/moe/moe_layer_inference.py b/megatron/core/transformer/moe/moe_layer_inference.py index 810939431a9..29705fabc3e 100644 --- a/megatron/core/transformer/moe/moe_layer_inference.py +++ b/megatron/core/transformer/moe/moe_layer_inference.py @@ -49,6 +49,10 @@ from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.moe.moe_utils import get_default_pg_collection +from megatron.core.transformer.moe.experts import InferenceGroupedMLP + +from .token_dispatcher_inference import InferenceAlltoAllTokenDispatcher class InferenceMoELayer(MoELayer): @@ -77,6 +81,9 @@ def __init__( Args are identical to MoELayer for checkpoint compatibility. """ # Initialize parent MoELayer (creates router, experts, token_dispatcher) + if pg_collection is None: + pg_collection = get_default_pg_collection() + super().__init__( config=config, submodules=submodules, @@ -84,6 +91,13 @@ def __init__( pg_collection=pg_collection, ) + self.token_dispatcher = InferenceAlltoAllTokenDispatcher( + self.num_local_experts, + self.local_expert_indices, + config=self.config, + pg_collection=pg_collection, + ) + # Validate dispatcher type if config.moe_token_dispatcher_type != "alltoall": raise ValueError( @@ -160,7 +174,7 @@ def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tens # ===== Step 5: Expert Computation (using inherited experts) ===== expert_output, mlp_bias = self.experts(expert_input, tokens_per_expert, expert_probs) - + # ===== Step 6: Combine Preprocess (unsort + TP ReduceScatter) ===== combine_input = self.token_dispatcher.combine_preprocess(expert_output) diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py new file mode 100644 index 
00000000000..d34b6b67158 --- /dev/null +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -0,0 +1,139 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +""" +Inference-optimized AlltoAll Token Dispatcher with GPU-resident metadata. + +This implementation keeps tokens_per_expert GPU-resident to enable use of +torch._grouped_mm without host synchronization. +""" + +import torch +from typing import List, Optional, Tuple + +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.moe.moe_utils import sort_chunks_by_idxs +from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher +from megatron.core.transformer.transformer_config import TransformerConfig + + +class InferenceAlltoAllTokenDispatcher(MoEAlltoAllTokenDispatcher): + """ + Inference-optimized AlltoAll token dispatcher. + + Key optimization: Returns tokens_per_expert as a GPU tensor (not moved to CPU) + to enable torch._grouped_mm without host synchronization. + + Assumes tp_size == 1 (no tensor parallelism within experts). + """ + + def __init__( + self, + num_local_experts: int, + local_expert_indices: List[int], + config: TransformerConfig, + pg_collection: Optional[ProcessGroupCollection] = None, + ) -> None: + """ + Initialize the inference AlltoAll token dispatcher. + + Args are identical to MoEAlltoAllTokenDispatcher for compatibility. + """ + super().__init__( + num_local_experts=num_local_experts, + local_expert_indices=local_expert_indices, + config=config, + pg_collection=pg_collection, + ) + + def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor: + """Preprocess routing map, ensuring tokens_per_expert is created on GPU. + + For drop_and_pad mode, the parent creates tokens_per_expert on CPU via + torch.full() without a device argument. We override to create it directly + on GPU to avoid any host synchronization. 
+ + For non-drop_and_pad mode, the parent creates it on GPU via routing_map.sum(), + so we just call the parent. + """ + if self.drop_and_pad: + # Replicate parent's drop_and_pad logic but create tensor on GPU + from megatron.core.transformer.moe.moe_utils import get_capacity + + num_tokens = routing_map.size(0) * self.config.moe_router_topk + self.capacity = get_capacity( + num_tokens=num_tokens, + num_experts=self.num_experts, + capacity_factor=self.moe_expert_capacity_factor, + ) + self.num_out_tokens = self.capacity * self.num_experts + + # Create on GPU (parent creates on CPU) + num_tokens_per_local_expert = torch.full( + (self.num_local_experts,), + self.capacity * self.tp_size * self.ep_size, + dtype=torch.long, + device=routing_map.device, # Same device as input (GPU) + ) + + self.num_global_tokens_per_local_expert = torch.full( + (self.num_experts * self.tp_size,), + self.capacity, + dtype=torch.long, + device=self.permute_idx_device, + ) + return num_tokens_per_local_expert + else: + # Non-drop_and_pad: parent creates on GPU via routing_map.sum() + return super().preprocess(routing_map) + + def _maybe_dtoh_and_synchronize( + self, point: str, tokens_per_expert: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Move splits to CPU for AlltoAll, but keep tokens_per_expert on GPU. + + The parent class moves all tensors to CPU including tokens_per_expert. + For inference with torch._grouped_mm, we need tokens_per_expert to stay + on GPU to avoid host synchronization. + + This override: + - Still moves input_splits, output_splits, etc. 
to CPU (required by AlltoAll) + - Still does stream synchronization + - But keeps tokens_per_expert on GPU (for torch._grouped_mm) + """ + from megatron.core.transformer.moe.token_dispatcher import maybe_move_tensor_to_cpu + + if not self.drop_and_pad: + if point == self.cuda_dtoh_point: + # Move splits to CPU (required by torch.distributed.all_to_all_single) + on_side_stream = torch.cuda.current_stream() != self.cuda_dtoh_stream + if on_side_stream: + self.cuda_dtoh_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.cuda_dtoh_stream): + # Move AlltoAll splits to CPU (required) + self.input_splits = maybe_move_tensor_to_cpu( + self.input_splits, as_numpy=True, record_stream=on_side_stream + ) + self.output_splits = maybe_move_tensor_to_cpu( + self.output_splits, as_numpy=True, record_stream=on_side_stream + ) + self.output_splits_tp = maybe_move_tensor_to_cpu( + self.output_splits_tp, as_numpy=True, record_stream=on_side_stream + ) + self.num_out_tokens = maybe_move_tensor_to_cpu( + self.num_out_tokens, record_stream=on_side_stream + ) + if self.num_local_experts > 1 and not self.config.moe_permute_fusion: + self.num_global_tokens_per_local_expert = maybe_move_tensor_to_cpu( + self.num_global_tokens_per_local_expert, record_stream=on_side_stream + ) + # NOTE: We intentionally do NOT move tokens_per_expert to CPU here. + # It stays on GPU for use with torch._grouped_mm. + self.d2h_event = self.cuda_dtoh_stream.record_event() + + if point == self.cuda_sync_point: + # Synchronize with the DtoH stream + self.d2h_event.synchronize() + + # Return tokens_per_expert unchanged (stays on GPU!) 
+ return tokens_per_expert + From 8dd410d19bb5c43034494cf8a8c56f55e8f022ee Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 28 Jan 2026 13:32:14 -0800 Subject: [PATCH 06/92] add config restraints for single GPU only and make dtoh and sync a null op --- .../moe/token_dispatcher_inference.py | 79 ++++++++----------- .../core/transformer/transformer_config.py | 30 +++++++ 2 files changed, 64 insertions(+), 45 deletions(-) diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index d34b6b67158..d2275848aa1 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -8,10 +8,9 @@ """ import torch -from typing import List, Optional, Tuple +from typing import List, Optional from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.transformer.moe.moe_utils import sort_chunks_by_idxs from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher from megatron.core.transformer.transformer_config import TransformerConfig @@ -89,51 +88,41 @@ def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor: def _maybe_dtoh_and_synchronize( self, point: str, tokens_per_expert: Optional[torch.Tensor] = None ) -> torch.Tensor: - """Move splits to CPU for AlltoAll, but keep tokens_per_expert on GPU. + """No-op for single GPU inference - all metadata stays on GPU. - The parent class moves all tensors to CPU including tokens_per_expert. - For inference with torch._grouped_mm, we need tokens_per_expert to stay - on GPU to avoid host synchronization. + For single GPU (ep_size=1, tp_size=1): + - input_splits, output_splits, output_splits_tp are all None (no AlltoAll needed) + - tokens_per_expert stays on GPU for torch._grouped_mm + - No DtoH transfers or synchronization required - This override: - - Still moves input_splits, output_splits, etc. 
to CPU (required by AlltoAll) - - Still does stream synchronization - - But keeps tokens_per_expert on GPU (for torch._grouped_mm) + This enables fully CUDA-graphable MoE forward pass. """ - from megatron.core.transformer.moe.token_dispatcher import maybe_move_tensor_to_cpu - - if not self.drop_and_pad: - if point == self.cuda_dtoh_point: - # Move splits to CPU (required by torch.distributed.all_to_all_single) - on_side_stream = torch.cuda.current_stream() != self.cuda_dtoh_stream - if on_side_stream: - self.cuda_dtoh_stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(self.cuda_dtoh_stream): - # Move AlltoAll splits to CPU (required) - self.input_splits = maybe_move_tensor_to_cpu( - self.input_splits, as_numpy=True, record_stream=on_side_stream - ) - self.output_splits = maybe_move_tensor_to_cpu( - self.output_splits, as_numpy=True, record_stream=on_side_stream - ) - self.output_splits_tp = maybe_move_tensor_to_cpu( - self.output_splits_tp, as_numpy=True, record_stream=on_side_stream - ) - self.num_out_tokens = maybe_move_tensor_to_cpu( - self.num_out_tokens, record_stream=on_side_stream - ) - if self.num_local_experts > 1 and not self.config.moe_permute_fusion: - self.num_global_tokens_per_local_expert = maybe_move_tensor_to_cpu( - self.num_global_tokens_per_local_expert, record_stream=on_side_stream - ) - # NOTE: We intentionally do NOT move tokens_per_expert to CPU here. - # It stays on GPU for use with torch._grouped_mm. - self.d2h_event = self.cuda_dtoh_stream.record_event() - - if point == self.cuda_sync_point: - # Synchronize with the DtoH stream - self.d2h_event.synchronize() - - # Return tokens_per_expert unchanged (stays on GPU!) 
+ # Validate single GPU assumptions + assert self.ep_size == 1, ( + f"InferenceAlltoAllTokenDispatcher requires ep_size=1, got {self.ep_size}" + ) + assert self.tp_size == 1, ( + f"InferenceAlltoAllTokenDispatcher requires tp_size=1, got {self.tp_size}" + ) + assert self.input_splits is None, ( + "input_splits should be None for single GPU inference" + ) + assert self.output_splits is None, ( + "output_splits should be None for single GPU inference" + ) + assert self.output_splits_tp is None, ( + "output_splits_tp should be None for single GPU inference" + ) + assert not isinstance(self.num_out_tokens, torch.Tensor), ( + "num_out_tokens should be a Python int for dropless single GPU inference, " + f"got {type(self.num_out_tokens)}. Ensure moe_expert_capacity_factor is None " + "and moe_router_padding_for_quantization is False." + ) + assert tokens_per_expert.is_cuda, ( + "tokens_per_expert should be on GPU for single GPU inference" + ) + + + # No DtoH transfers needed - return tokens_per_expert unchanged (stays on GPU!) return tokens_per_expert diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index cabad4e15d7..b33597c887b 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1008,6 +1008,36 @@ def __post_init__(self): if self.expert_model_parallel_size > 1 and self.num_moe_experts is None: raise ValueError("num_moe_experts must be non None to use expert-parallel.") + if self.transformer_impl == "inference_optimized" and ( + self.expert_model_parallel_size * self.expert_tensor_parallel_size > 1 + ): + raise ValueError( + "Inference-optimized MoE layers currently only support data parallelism " + "(expert_model_parallel_size=1 and expert_tensor_parallel_size=1). " + "Multi-GPU support is planned for future work." 
+ ) + + if self.transformer_impl == "inference_optimized" and ( + self.moe_expert_capacity_factor is not None + or self.moe_router_padding_for_quantization + ): + raise ValueError( + "Inference-optimized MoE layers only support dropless MoE " + "(moe_expert_capacity_factor=None and moe_router_padding_for_quantization=False). " + ) + + if self.transformer_impl == "inference_optimized" and self.num_moe_experts is not None: + if not self.moe_permute_fusion: + raise ValueError( + "Inference-optimized MoE layers require moe_permute_fusion=True " + "to use TE fused kernels that support GPU-resident metadata." + ) + if not self.moe_router_fusion: + raise ValueError( + "Inference-optimized MoE layers require moe_router_fusion=True " + "to use TE fused router kernels." + ) + if self.num_moe_experts is not None and self.num_moe_experts <= 0: raise ValueError("num_moe_experts must be non-negative.") From b8f5fe5ab301b4ef1fcf51d21aeed3b48b011f5b Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Thu, 29 Jan 2026 15:22:22 -0800 Subject: [PATCH 07/92] remove requirement for router fusion --- megatron/core/transformer/transformer_config.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index b33597c887b..329e4b2b65a 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1032,11 +1032,11 @@ def __post_init__(self): "Inference-optimized MoE layers require moe_permute_fusion=True " "to use TE fused kernels that support GPU-resident metadata." ) - if not self.moe_router_fusion: - raise ValueError( - "Inference-optimized MoE layers require moe_router_fusion=True " - "to use TE fused router kernels." - ) + # if not self.moe_router_fusion: + # raise ValueError( + # "Inference-optimized MoE layers require moe_router_fusion=True " + # "to use TE fused router kernels." 
+ # ) if self.num_moe_experts is not None and self.num_moe_experts <= 0: raise ValueError("num_moe_experts must be non-negative.") From 5063fb20e30c2f19607544b3997ff01aa3b71c50 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 30 Jan 2026 14:32:43 -0800 Subject: [PATCH 08/92] confirm that this works with nccl all to alls --- megatron/core/transformer/moe/token_dispatcher_inference.py | 6 +++--- megatron/core/transformer/transformer_config.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index d2275848aa1..24778df7c77 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -98,9 +98,9 @@ def _maybe_dtoh_and_synchronize( This enables fully CUDA-graphable MoE forward pass. """ # Validate single GPU assumptions - assert self.ep_size == 1, ( - f"InferenceAlltoAllTokenDispatcher requires ep_size=1, got {self.ep_size}" - ) + # assert self.ep_size == 1, ( + # f"InferenceAlltoAllTokenDispatcher requires ep_size=1, got {self.ep_size}" + # ) assert self.tp_size == 1, ( f"InferenceAlltoAllTokenDispatcher requires tp_size=1, got {self.tp_size}" ) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 329e4b2b65a..e44ec506fc0 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1009,7 +1009,7 @@ def __post_init__(self): raise ValueError("num_moe_experts must be non None to use expert-parallel.") if self.transformer_impl == "inference_optimized" and ( - self.expert_model_parallel_size * self.expert_tensor_parallel_size > 1 + self.expert_tensor_parallel_size > 1 ): raise ValueError( "Inference-optimized MoE layers currently only support data parallelism " From 297f92673b096ed1c6520e6489e4ee4717725c58 Mon Sep 17 00:00:00 
2001 From: Siddharth Singh Date: Sun, 1 Feb 2026 17:28:53 -0800 Subject: [PATCH 09/92] disable drop and pad for inference optimized, and propagate cuda graphed signal downwards --- .../text_generation_controller.py | 11 ++++- megatron/core/inference/utils.py | 16 +++++++ .../transformer/moe/moe_layer_inference.py | 8 ++++ .../moe/token_dispatcher_inference.py | 44 ++++++++++++++++++- megatron/training/arguments.py | 2 +- 5 files changed, 77 insertions(+), 4 deletions(-) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 5d2b93c4a5e..935cdf8d269 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -25,7 +25,7 @@ AbstractModelInferenceWrapper, ) from megatron.core.inference.sampling_params import SamplingParams -from megatron.core.inference.utils import get_attention_mask, set_decode_expert_padding +from megatron.core.inference.utils import get_attention_mask, set_decode_expert_padding, set_is_cuda_graphed_iteration_for_ep_inference from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.moe.moe_layer import BaseMoELayer from megatron.core.transformer.utils import set_model_to_sequence_parallel @@ -523,12 +523,20 @@ def _dynamic_step_context_init( moe_pad_experts_for_cuda_graph_inference = ( inference_wrapper_config.moe_pad_experts_for_cuda_graph_inference ) + is_inference_optimized = inference_wrapper_config.transformer_impl == "inference_optimized" + if is_inference_optimized: + assert not moe_pad_experts_for_cuda_graph_inference, ( + "moe_pad_experts_for_cuda_graph_inference cannot be True when " + "transformer_impl is 'inference_optimized'" + ) if moe_pad_experts_for_cuda_graph_inference: if context.using_cuda_graph_this_step(): capacity_factor = 
model_config.num_moe_experts / model_config.moe_router_topk set_decode_expert_padding(unwrapped_model, True, capacity_factor=capacity_factor) else: set_decode_expert_padding(unwrapped_model, False) + if is_inference_optimized and model_config.expert_model_parallel_size > 1: + set_is_cuda_graphed_iteration_for_ep_inference(unwrapped_model, context.using_cuda_graph_this_step()) # initialize symmetric memory if needed if model_config.transformer_impl == "inference_optimized": @@ -1146,6 +1154,7 @@ def generate_all_output_tokens_static_batch( moe_pad_experts_for_cuda_graph_inference = ( inference_wrapper_config.moe_pad_experts_for_cuda_graph_inference ) + if moe_pad_experts_for_cuda_graph_inference: set_decode_expert_padding(unwrapped_model, False) diff --git a/megatron/core/inference/utils.py b/megatron/core/inference/utils.py index 55536a52088..947da7c4830 100644 --- a/megatron/core/inference/utils.py +++ b/megatron/core/inference/utils.py @@ -131,6 +131,21 @@ def set_decode_expert_padding(model, set_to: bool = False, capacity_factor: int router.config.moe_expert_capacity_factor = capacity_factor router.config.moe_pad_expert_input_to_capacity = bool(set_to) +def set_is_cuda_graphed_iteration_for_ep_inference(model, set_to: bool): + """ + Toggle CUDA graph compatibility for expert parallel inference. + This sets a boolean flag in all InferenceMoELayers to indicate whether + the current iteration is being captured/executed in a CUDA graph. + This allows the dispatcher to adjust its behavior for CUDA graph compatibility, + Args: + - set_to: Enable (True) or disable (False) CUDA graph compatibility. 
+ """ + global moe_layer_cache + if moe_layer_cache is None: + _init_moe_expert_cache(model) + + for moe_layer in moe_layer_cache: + moe_layer.set_is_cuda_graphed_iteration(set_to) def tensor_swap(x, src_idxs, dst_idxs): """ @@ -216,3 +231,4 @@ def shutdown(self): else: asyncio_QueueShutDown = asyncio.QueueShutDown asyncio_Queue = asyncio.Queue + diff --git a/megatron/core/transformer/moe/moe_layer_inference.py b/megatron/core/transformer/moe/moe_layer_inference.py index 29705fabc3e..e606a220f5c 100644 --- a/megatron/core/transformer/moe/moe_layer_inference.py +++ b/megatron/core/transformer/moe/moe_layer_inference.py @@ -54,6 +54,7 @@ from .token_dispatcher_inference import InferenceAlltoAllTokenDispatcher +import logging class InferenceMoELayer(MoELayer): """ @@ -127,6 +128,13 @@ def __init__( self.num_local_experts, -1 ).T.ravel() + self.is_cuda_graphed_iteration = False + + def set_is_cuda_graphed_iteration(self, set_to): + self.is_cuda_graphed_iteration = set_to + logging.info("set is cuda graphed iteration to %s", set_to) + exit() + # ==================== Simplified Forward Pass ==================== def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): """ diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 24778df7c77..c147ec39765 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -11,7 +11,10 @@ from typing import List, Optional from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.transformer.moe.token_dispatcher import MoEAlltoAllTokenDispatcher +from megatron.core.transformer.moe.token_dispatcher import ( + MoEAlltoAllTokenDispatcher, + MoEAllGatherTokenDispatcher, +) from megatron.core.transformer.transformer_config import TransformerConfig @@ -56,7 +59,7 @@ def preprocess(self, routing_map: 
torch.Tensor) -> torch.Tensor: """ if self.drop_and_pad: # Replicate parent's drop_and_pad logic but create tensor on GPU - from megatron.core.transformer.moe.moe_utils import get_capacity + num_tokens = routing_map.size(0) * self.config.moe_router_topk self.capacity = get_capacity( @@ -126,3 +129,40 @@ def _maybe_dtoh_and_synchronize( # No DtoH transfers needed - return tokens_per_expert unchanged (stays on GPU!) return tokens_per_expert + +class InferenceAllGatherTokenDispatcher(MoEAllGatherTokenDispatcher): + """ + Inference-optimized AllGather token dispatcher. + + This dispatcher uses AllGather instead of AlltoAll for token exchange, + which can be simpler and more efficient for certain configurations. + + Key features: + - Simpler communication pattern (AllGather vs AlltoAll) + - GPU-resident metadata for CUDA graph compatibility + - Assumes tp_size == 1 (no tensor parallelism within experts) + """ + + def __init__( + self, + num_local_experts: int, + local_expert_indices: List[int], + config: TransformerConfig, + pg_collection: Optional[ProcessGroupCollection] = None, + ) -> None: + """ + Initialize the inference AllGather token dispatcher. + + Args: + num_local_experts: Number of experts on this rank. + local_expert_indices: Global indices of experts on this rank. + config: Transformer configuration. + pg_collection: Process group collection for distributed ops. 
+ """ + super().__init__( + num_local_experts=num_local_experts, + local_expert_indices=local_expert_indices, + config=config, + pg_collection=pg_collection, + ) + diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 9fadb3f9900..b9309a5c516 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1260,7 +1260,7 @@ def validate_args(args, defaults={}): assert args.inference_dynamic_batching_buffer_size_gb is not None assert args.inference_dynamic_batching_block_size % 256 == 0, "block size should be a multiple of 256" - if args.cuda_graph_impl == "local" and args.expert_model_parallel_size > 1: + if args.cuda_graph_impl == "local" and args.expert_model_parallel_size > 1 and args.transformer_impl != "inference_optimized": assert args.moe_pad_experts_for_cuda_graph_inference, \ "--moe-pad-experts-for-cuda-graph-inference must be set when using CUDA graphs with expert parallelism" From 629dc1fd577fb3464669432b49e68295dd25c4b1 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Sun, 1 Feb 2026 18:06:28 -0800 Subject: [PATCH 10/92] confirm that all-gather dispatch runs within cuda graphs --- megatron/core/transformer/moe/experts.py | 4 +- .../transformer/moe/moe_layer_inference.py | 132 ++++--------- .../moe/token_dispatcher_inference.py | 174 +++++++----------- 3 files changed, 100 insertions(+), 210 deletions(-) diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index a886e91abce..a876b94871a 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -932,7 +932,9 @@ def forward( Tuple of (output, None) for interface compatibility """ permuted_probs = permuted_probs.unsqueeze(-1) - assert tokens_per_expert.is_cuda, "tokens_per_expert must be on GPU" + #assert tokens_per_expert.is_cuda, "tokens_per_expert must be on GPU" + if not tokens_per_expert.is_cuda: + tokens_per_expert = tokens_per_expert.to('cuda') if 
self.config.moe_apply_probs_on_input: assert ( diff --git a/megatron/core/transformer/moe/moe_layer_inference.py b/megatron/core/transformer/moe/moe_layer_inference.py index e606a220f5c..7d80a8a4793 100644 --- a/megatron/core/transformer/moe/moe_layer_inference.py +++ b/megatron/core/transformer/moe/moe_layer_inference.py @@ -50,9 +50,7 @@ from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.moe.moe_utils import get_default_pg_collection -from megatron.core.transformer.moe.experts import InferenceGroupedMLP - -from .token_dispatcher_inference import InferenceAlltoAllTokenDispatcher +from megatron.core.transformer.moe.token_dispatcher_inference import InferenceAllGatherTokenDispatcher import logging @@ -92,113 +90,53 @@ def __init__( pg_collection=pg_collection, ) - self.token_dispatcher = InferenceAlltoAllTokenDispatcher( - self.num_local_experts, - self.local_expert_indices, - config=self.config, - pg_collection=pg_collection, - ) - # Validate dispatcher type + # todo: move this assert to arguments.py or transformer_config.py if config.moe_token_dispatcher_type != "alltoall": raise ValueError( f"InferenceMoELayer only supports 'alltoall' dispatcher, " f"got '{config.moe_token_dispatcher_type}'" ) - # Cache frequently used values - self.hidden_size = config.hidden_size - self.topk = config.moe_router_topk - - # Get process group info from token_dispatcher - self.ep_size = self.token_dispatcher.ep_size - self.ep_rank = utils.get_pg_rank(self.token_dispatcher.ep_group) - self.tp_size = self.token_dispatcher.tp_size - self.tp_rank = self.token_dispatcher.tp_rank - - # Precompute sort indices for multi-expert case - if self.num_local_experts > 1: - input_chunk_idxs = torch.arange( - self.config.num_moe_experts * self.tp_size, device="cuda" - ) - self.sort_input_by_local_experts = input_chunk_idxs.reshape( - -1, self.num_local_experts - 
).T.ravel() - self.restore_output_by_local_experts = input_chunk_idxs.reshape( - self.num_local_experts, -1 - ).T.ravel() - self.is_cuda_graphed_iteration = False - + self.inference_token_dispatcher = InferenceAllGatherTokenDispatcher( + self.num_local_experts, + self.local_expert_indices, + config=self.config, + pg_collection=pg_collection, + ) def set_is_cuda_graphed_iteration(self, set_to): self.is_cuda_graphed_iteration = set_to - logging.info("set is cuda graphed iteration to %s", set_to) - exit() - # ==================== Simplified Forward Pass ==================== - def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): - """ - Simplified forward pass optimized for inference. + def activate_inference_token_dispatcher(self): + self.old_token_dispatcher = self.token_dispatcher + self.old_expert_overlap = self.shared_expert_overlap + self.token_dispatcher = self.inference_token_dispatcher + self.shared_expert_overlap = False - This overrides MoELayer.forward() with a streamlined version that: - - Removes training overhead (aux losses, recomputation) - - Uses a linear, easy-to-follow flow - - Reuses inherited router, token_dispatcher, and experts + def deactivate_inference_token_dispatcher(self): + self.token_dispatcher = self.old_token_dispatcher + self.shared_expert_overlap = self.old_expert_overlap - Args: - hidden_states: [S, B, H] input tensor - padding_mask: Optional [B, S] boolean mask. True for valid tokens, False for padding. 
- - Returns: - Tuple of (output, None) for compatibility with MoELayer interface + # ==================== Simplified Forward Pass ==================== + def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): + """ """ - # Store original shape for restoration - hidden_shape = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_shape[-1]) - - # Transpose padding_mask from [bsz, seq_length] to [seq_length, bsz] to align with hidden_states - if padding_mask is not None: - padding_mask = padding_mask.transpose(0, 1).bool() - - # ===== Step 1: Routing (using inherited router) ===== - # The router returns probs and routing_map - probs, routing_map = self.router(hidden_states, padding_mask) - - # ===== Step 2: Dispatch Preprocess ===== - # Compute metadata and permute tokens by expert assignment - permuted_tokens, permuted_probs = self.token_dispatcher.dispatch_preprocess( - hidden_states, routing_map, probs - ) - - # ===== Step 3: Token Dispatch (EP AlltoAll) ===== - dispatched_tokens, dispatched_probs = self.token_dispatcher.token_dispatch( - permuted_tokens, permuted_probs - ) - - # ===== Step 4: Dispatch Postprocess (TP AllGather + sort by expert) ===== - expert_input, tokens_per_expert, expert_probs = self.token_dispatcher.dispatch_postprocess( - dispatched_tokens, dispatched_probs - ) - - # ===== Step 5: Expert Computation (using inherited experts) ===== - expert_output, mlp_bias = self.experts(expert_input, tokens_per_expert, expert_probs) + if not self.is_cuda_graphed_iteration: + # Note: this will still call InferenceGroupedMLP.forward() + # and therefore optimized cutlass grouped gemms. 
+ return super().forward(hidden_states, padding_mask) - # ===== Step 6: Combine Preprocess (unsort + TP ReduceScatter) ===== - combine_input = self.token_dispatcher.combine_preprocess(expert_output) - - # ===== Step 7: Token Combine (EP AlltoAll reverse) ===== - combined_output = self.token_dispatcher.token_combine(combine_input) - - # ===== Step 8: Combine Postprocess (unpermute to original order) ===== - output = self.token_dispatcher.combine_postprocess(combined_output) - - # Restore original shape - output = output.view(hidden_shape) - - # Handle shared experts (if configured, computed separately) - if self.use_shared_expert and not self.shared_expert_overlap: - shared_output = self.shared_experts(hidden_states.view(hidden_shape)) - output = output + shared_output - - return output, mlp_bias + self.activate_inference_token_dispatcher() + assert self.token_dispatcher is self.inference_token_dispatcher + logging.info("activated inference token dispatcher") + + forward_pass_output = super().forward(hidden_states, padding_mask) + + self.deactivate_inference_token_dispatcher() + assert self.token_dispatcher is not self.inference_token_dispatcher + logging.info("deactivated inference token dispatcher") + + return forward_pass_output + diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index c147ec39765..23e0896108f 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -12,123 +12,13 @@ from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.moe.token_dispatcher import ( - MoEAlltoAllTokenDispatcher, MoEAllGatherTokenDispatcher, ) from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.tensor_parallel import gather_from_sequence_parallel_region -class InferenceAlltoAllTokenDispatcher(MoEAlltoAllTokenDispatcher): - 
""" - Inference-optimized AlltoAll token dispatcher. - - Key optimization: Returns tokens_per_expert as a GPU tensor (not moved to CPU) - to enable torch._grouped_mm without host synchronization. - - Assumes tp_size == 1 (no tensor parallelism within experts). - """ - - def __init__( - self, - num_local_experts: int, - local_expert_indices: List[int], - config: TransformerConfig, - pg_collection: Optional[ProcessGroupCollection] = None, - ) -> None: - """ - Initialize the inference AlltoAll token dispatcher. - - Args are identical to MoEAlltoAllTokenDispatcher for compatibility. - """ - super().__init__( - num_local_experts=num_local_experts, - local_expert_indices=local_expert_indices, - config=config, - pg_collection=pg_collection, - ) - - def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor: - """Preprocess routing map, ensuring tokens_per_expert is created on GPU. - - For drop_and_pad mode, the parent creates tokens_per_expert on CPU via - torch.full() without a device argument. We override to create it directly - on GPU to avoid any host synchronization. - - For non-drop_and_pad mode, the parent creates it on GPU via routing_map.sum(), - so we just call the parent. 
- """ - if self.drop_and_pad: - # Replicate parent's drop_and_pad logic but create tensor on GPU - - - num_tokens = routing_map.size(0) * self.config.moe_router_topk - self.capacity = get_capacity( - num_tokens=num_tokens, - num_experts=self.num_experts, - capacity_factor=self.moe_expert_capacity_factor, - ) - self.num_out_tokens = self.capacity * self.num_experts - - # Create on GPU (parent creates on CPU) - num_tokens_per_local_expert = torch.full( - (self.num_local_experts,), - self.capacity * self.tp_size * self.ep_size, - dtype=torch.long, - device=routing_map.device, # Same device as input (GPU) - ) - - self.num_global_tokens_per_local_expert = torch.full( - (self.num_experts * self.tp_size,), - self.capacity, - dtype=torch.long, - device=self.permute_idx_device, - ) - return num_tokens_per_local_expert - else: - # Non-drop_and_pad: parent creates on GPU via routing_map.sum() - return super().preprocess(routing_map) - - def _maybe_dtoh_and_synchronize( - self, point: str, tokens_per_expert: Optional[torch.Tensor] = None - ) -> torch.Tensor: - """No-op for single GPU inference - all metadata stays on GPU. - - For single GPU (ep_size=1, tp_size=1): - - input_splits, output_splits, output_splits_tp are all None (no AlltoAll needed) - - tokens_per_expert stays on GPU for torch._grouped_mm - - No DtoH transfers or synchronization required - - This enables fully CUDA-graphable MoE forward pass. 
- """ - # Validate single GPU assumptions - # assert self.ep_size == 1, ( - # f"InferenceAlltoAllTokenDispatcher requires ep_size=1, got {self.ep_size}" - # ) - assert self.tp_size == 1, ( - f"InferenceAlltoAllTokenDispatcher requires tp_size=1, got {self.tp_size}" - ) - assert self.input_splits is None, ( - "input_splits should be None for single GPU inference" - ) - assert self.output_splits is None, ( - "output_splits should be None for single GPU inference" - ) - assert self.output_splits_tp is None, ( - "output_splits_tp should be None for single GPU inference" - ) - assert not isinstance(self.num_out_tokens, torch.Tensor), ( - "num_out_tokens should be a Python int for dropless single GPU inference, " - f"got {type(self.num_out_tokens)}. Ensure moe_expert_capacity_factor is None " - "and moe_router_padding_for_quantization is False." - ) - assert tokens_per_expert.is_cuda, ( - "tokens_per_expert should be on GPU for single GPU inference" - ) - - - # No DtoH transfers needed - return tokens_per_expert unchanged (stays on GPU!) - return tokens_per_expert - +import logging class InferenceAllGatherTokenDispatcher(MoEAllGatherTokenDispatcher): """ @@ -166,3 +56,63 @@ def __init__( pg_collection=pg_collection, ) + + def token_dispatch(self, hidden_states, probs): + """Gathers tokens from all TP*EP ranks using AllGather.""" + + # Permute the tokens across the expert parallel devices. + if self.tp_size > 1 or self.ep_size > 1: + ## local_indices calculation + with torch.no_grad(): + # [num_local_tokens, num_experts] -> [num_global_tokens, num_experts], where: + # num_local_tokens=(S/TP)*B, num_global_tokens=S*B*EP + self.routing_map = gather_from_sequence_parallel_region( + self.routing_map, group=self.tp_ep_group + ) + + ## local_probs calculation + # max_prob: [S/TP*B, num_experts] -> global_probs: [S*B*EP, num_experts] + probs = gather_from_sequence_parallel_region(probs, group=self.tp_ep_group) + # Note that this allgather spans the communication domain of TP*EP. 
+ # [(S/TP)*B, H] -> [((S/TP)*B)*(TP*EP), H] = [S*B*EP, H] + hidden_states = gather_from_sequence_parallel_region( + hidden_states, group=self.tp_ep_group, use_global_buffer=True + ) + + logging.info("Completed token dispatch AllGather.") + exit() + + return hidden_states, probs + + def dispatch_postprocess(self, hidden_states, probs): + """After gathering in token_dispatch, this method identifies tokens for local experts and + permutes them for expert processing. + """ + self.hidden_shape_before_permute = hidden_states.shape + + # The routing map and probs that for local experts. + self.local_map = self.routing_map[ + :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1 + ].contiguous() + # probs of global token assignment to local experts. + self.local_probs = probs[ + :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1 + ].contiguous() + + tokens_per_expert = self.local_map.sum(dim=0).long().cpu() + + (permuted_local_hidden_states, _, self.reversed_local_input_permutation_mapping) = permute( + hidden_states, + self.local_map, + num_out_tokens=tokens_per_expert.sum(), + fused=self.config.moe_permute_fusion, + ) + + self.local_probs = self.local_probs.T.contiguous().masked_select( + self.local_map.T.contiguous() + ) + self.routing_map = None + return permuted_local_hidden_states, tokens_per_expert, self.local_probs + + + From 21b9140873bcc57b1ddf7ca4dba1e4106df530ba Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Sun, 1 Feb 2026 21:26:31 -0800 Subject: [PATCH 11/92] working --- megatron/core/transformer/moe/experts.py | 82 ++++++--- .../core/transformer/moe/inference_kernels.py | 170 ++++++++++++++++++ .../transformer/moe/moe_layer_inference.py | 4 +- .../moe/token_dispatcher_inference.py | 138 ++++++++++++-- 4 files changed, 356 insertions(+), 38 deletions(-) create mode 100644 megatron/core/transformer/moe/inference_kernels.py diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py 
index a876b94871a..ce9502c6441 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -848,24 +848,6 @@ def __init__( # Register hook to rebuild concatenated weights after load_state_dict # self._register_load_state_dict_post_hook(self._rebuild_weights_hook) - # Set up activation function for inference (simplified, no recompute) - if self.config.gated_linear_unit: - @jit_fuser - def glu(x): - x = torch.chunk(x, 2, dim=-1) - return self.config.activation_func(x[0]) * x[1] - self._inference_activation_func = glu - else: - self._inference_activation_func = self.config.activation_func - - @jit_fuser - def activation_func_with_probs(x, probs): - dtype = x.dtype - res = self._inference_activation_func(x) * probs - return res.to(dtype) - - self._activation_func_with_probs = activation_func_with_probs - def _build_concatenated_weights(self): """Create big contiguous weight tensors with per-expert views for checkpoint compatibility. @@ -945,6 +927,65 @@ def forward( permuted_local_hidden_states = permuted_local_hidden_states.to(original_dtype) permuted_probs = torch.ones_like(permuted_probs) + def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs): + if self.config.use_te_activation_func: + if bias_parallel is not None: + intermediate_parallel = intermediate_parallel + bias_parallel + intermediate_parallel = self.activation_func(intermediate_parallel) + if permuted_probs is not None: + original_dtype = intermediate_parallel.dtype + intermediate_parallel = intermediate_parallel * permuted_probs + intermediate_parallel = intermediate_parallel.to(original_dtype) + elif self.config.bias_activation_fusion: + if self.activation_func == F.silu and self.config.gated_linear_unit: + # dtype is handled inside the fused kernel + intermediate_parallel = weighted_bias_swiglu_impl( + intermediate_parallel, + bias_parallel, + permuted_probs, + self.config.activation_func_fp8_input_store, + ) + elif self.activation_func == 
quick_gelu and self.config.gated_linear_unit: + intermediate_parallel = weighted_bias_quick_geglu_impl( + intermediate_parallel, + bias_parallel, + permuted_probs, + self.config.activation_func_fp8_input_store, + self.config.glu_linear_offset, + self.config.activation_func_clamp_value, + ) + else: + raise ValueError( + "Only support fusion of swiglu and quick_gelu in TEGroupedMLP." + ) + elif ( + self.activation_func == squared_relu and self.config.use_fused_weighted_squared_relu + ): + assert bias_parallel is None + intermediate_parallel = weighted_squared_relu_impl( + intermediate_parallel, permuted_probs + ) + else: + if self.config.gated_linear_unit: + + def glu(x): + x_glu, x_linear = torch.chunk(x, 2, dim=-1) + if (val := self.config.activation_func_clamp_value) is not None: + x_glu = x_glu.clamp(min=None, max=val) + x_linear = x_linear.clamp(min=-val, max=val) + return self.config.activation_func(x_glu) * ( + x_linear + self.config.glu_linear_offset + ) + + intermediate_parallel = glu(intermediate_parallel) + else: + intermediate_parallel = self.activation_func(intermediate_parallel) + original_dtype = intermediate_parallel.dtype + intermediate_parallel = intermediate_parallel * permuted_probs + intermediate_parallel = intermediate_parallel.to(original_dtype) + return intermediate_parallel + + if permuted_local_hidden_states.nelement() != 0: # Use pre-concatenated weights (built during init/load) # _fc1_weight shape: [num_experts, ffn_hidden * (2 if gated else 1), hidden_size] @@ -957,10 +998,11 @@ def forward( fc1_output = torch._grouped_mm(permuted_local_hidden_states, self._fc1_weight.transpose(1, 2), offs=offs) # Activation with routing probabilities - intermediate_parallel = self._activation_func_with_probs(fc1_output, permuted_probs) + # intermediate_parallel = self._activation_func_with_probs(fc1_output, permuted_probs) + bias_act_output = bias_act_func(fc1_output, None, permuted_probs) # FC2: [total_tokens, ffn_hidden] @ [num_experts, hidden, 
ffn_hidden] -> [total_tokens, hidden] - fc2_output = torch._grouped_mm(intermediate_parallel, self._fc2_weight.transpose(1, 2), offs=offs) + fc2_output = torch._grouped_mm(bias_act_output, self._fc2_weight.transpose(1, 2), offs=offs) else: # No tokens allocated - return empty tensor with correct shape fc2_output = permuted_local_hidden_states diff --git a/megatron/core/transformer/moe/inference_kernels.py b/megatron/core/transformer/moe/inference_kernels.py new file mode 100644 index 00000000000..70a7d1d97e0 --- /dev/null +++ b/megatron/core/transformer/moe/inference_kernels.py @@ -0,0 +1,170 @@ +import torch +import triton +import triton.language as tl +import pytest +import torch + +@triton.jit +def moe_permute_kernel( + hidden_ptr, mask_ptr_T, dest_idx_ptr, output_ptr, + stride_h_t, num_tokens, hidden_size, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + if not tl.load(mask_ptr_T + pid): return + + out_row_idx = tl.load(dest_idx_ptr + pid) - 1 + token_idx = pid % num_tokens + + offsets = tl.arange(0, BLOCK_SIZE) + mask_h = offsets < hidden_size + + row_data = tl.load(hidden_ptr + (token_idx * stride_h_t) + offsets, mask=mask_h) + tl.store(output_ptr + (out_row_idx * hidden_size) + offsets, row_data, mask=mask_h) + +@triton.jit +def moe_unpermute_kernel( + permuted_ptr, mask_ptr_T, dest_idx_ptr, output_ptr, + stride_out_t, num_tokens, hidden_size, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + if not tl.load(mask_ptr_T + pid): return + + src_row_idx = tl.load(dest_idx_ptr + pid) - 1 + token_idx = pid % num_tokens + + offsets = tl.arange(0, BLOCK_SIZE) + mask_h = offsets < hidden_size + + # Load as current dtype + row_data = tl.load(permuted_ptr + (src_row_idx * hidden_size) + offsets, mask=mask_h) + + # Cast to float32 for the accumulation to avoid BF16 rounding errors + row_data_fp32 = row_data.to(tl.float32) + + # Atomic add in FP32 (Triton handles the casting/locking) + tl.atomic_add(output_ptr + (token_idx * stride_out_t) + offsets, 
row_data_fp32, mask=mask_h) + +def launch_moe_kernels(hidden_states, mask, static_buffer, unpermute=False): + T, H = hidden_states.shape + E = mask.size(1) + mask_T = mask.t().contiguous() + dest_indices = torch.cumsum(mask_T.view(-1).long(), dim=0).to(torch.int32) + + grid = (E * T,) + BLOCK_SIZE = triton.next_power_of_2(H) + + if not unpermute: + moe_permute_kernel[grid]( + hidden_states, mask_T, dest_indices, static_buffer, + hidden_states.stride(0), T, H, BLOCK_SIZE=BLOCK_SIZE + ) + else: + # For unpermute, hidden_states is the 'output' we write back into + # ensure that hidden states is zeroed out before accumulation + moe_unpermute_kernel[grid]( + static_buffer, mask_T, dest_indices, hidden_states, + hidden_states.stride(0), T, H, BLOCK_SIZE=BLOCK_SIZE + ) + + +import triton +import triton.language as tl + +@triton.jit +def moe_extract_probs_kernel( + probs_ptr_T, # [Experts, Tokens] (Transposed & Contiguous) + mask_ptr_T, # [Experts, Tokens] (Transposed & Contiguous) + dest_idx_ptr, # [Experts * Tokens] (Cumsum of mask_ptr_T) + out_probs_ptr, # [MAX_OUT] (Static 1D Buffer) + num_tokens, +): + # pid follows Expert-major order: expert_idx * num_tokens + token_idx + pid = tl.program_id(0) + + # 1. Check if this expert-token pair is active + mask_val = tl.load(mask_ptr_T + pid) + if not mask_val: + return + + # 2. Get the destination index in the 1D static buffer + # out_row_idx corresponds to the row index in the permuted hidden states + out_idx = tl.load(dest_idx_ptr + pid) - 1 + + # 3. 
Load the probability and store it in the static output buffer + prob = tl.load(probs_ptr_T + pid) + tl.store(out_probs_ptr + out_idx, prob) + +def launch_extract_probs(probs, mask, static_prob_buffer): + T, E = probs.shape + + # Match the permutation layout: Experts first + probs_T = probs.t().contiguous() + mask_T = mask.t().contiguous() + + # Reuse the same cumsum logic from your permutation step + dest_indices = torch.cumsum(mask_T.view(-1).long(), dim=0).to(torch.int32) + + grid = (E * T,) + + moe_extract_probs_kernel[grid]( + probs_T, + mask_T, + dest_indices, + static_prob_buffer, + T + ) + return static_prob_buffer + +@pytest.mark.parametrize("T, E, H", [ + (1, 1, 128), # Minimal case + (64, 8, 512), # Standard small + (128, 16, 1024), # Medium + (256, 32, 2048), # Large (LLM Scale) + (1024, 1, 128), # Single Expert + (32, 64, 64), # High expert count +]) +@pytest.mark.parametrize("sparsity", [0.1, 0.5, 0.9]) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +def test_moe_cycle(T, E, H, sparsity, dtype): + device = "cuda" + MAX_OUT = T * E + + # Setup inputs + hidden_states = torch.randn(T, H, device=device, dtype=dtype) * 1e-3 + mask = torch.rand(T, E, device=device) > sparsity + + # We need a small prob scaling for unpermute to be realistic, + # but for pure permutation test, we'll stick to raw values. + static_buffer = torch.zeros((MAX_OUT, H), device=device, dtype=dtype) + + # 1. Test Permute + launch_moe_kernels(hidden_states, mask, static_buffer, unpermute=False) + + # Verification of Grouped-by-Expert layout + buffer_idx = 0 + for e_idx in range(E): + for t_idx in range(T): + if mask[t_idx, e_idx]: + assert torch.allclose(static_buffer[buffer_idx], hidden_states[t_idx], atol=1e-5) + buffer_idx += 1 + + assert static_buffer[buffer_idx:].sum() == 0, "Stale data found in buffer tail" + + # 2. 
Test Un-permute (Gather) + # We'll create a new tensor to receive the gathered data + # (Using zeros because unpermute uses atomic_add to handle Top-K) + output_states = torch.zeros_like(hidden_states) + launch_moe_kernels(output_states, mask, static_buffer, unpermute=True) + + # Validation: If a token went to N experts, it should be N * original_value + expert_counts_per_token = mask.sum(dim=1) + expected_output = hidden_states * expert_counts_per_token.unsqueeze(-1) + + # Instead of allclose, use this for BF16 + if dtype == torch.bfloat16: + # rtol=1.6e-2 is the standard epsilon for bfloat16 + torch.testing.assert_close(output_states, expected_output, rtol=0.016, atol=0.005) + else: + torch.testing.assert_close(output_states, expected_output, rtol=1e-5, atol=1e-5) \ No newline at end of file diff --git a/megatron/core/transformer/moe/moe_layer_inference.py b/megatron/core/transformer/moe/moe_layer_inference.py index 7d80a8a4793..f8fb9115d34 100644 --- a/megatron/core/transformer/moe/moe_layer_inference.py +++ b/megatron/core/transformer/moe/moe_layer_inference.py @@ -129,13 +129,13 @@ def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tens self.activate_inference_token_dispatcher() assert self.token_dispatcher is self.inference_token_dispatcher - logging.info("activated inference token dispatcher") + #logging.info("activated inference token dispatcher") forward_pass_output = super().forward(hidden_states, padding_mask) self.deactivate_inference_token_dispatcher() assert self.token_dispatcher is not self.inference_token_dispatcher - logging.info("deactivated inference token dispatcher") + #logging.info("deactivated inference token dispatcher") return forward_pass_output diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 23e0896108f..30a28a8749b 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ 
b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -17,6 +17,8 @@ from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.tensor_parallel import gather_from_sequence_parallel_region +from megatron.core.transformer.moe.moe_utils import permute +from megatron.core.transformer.moe.inference_kernels import launch_moe_kernels, launch_extract_probs import logging @@ -55,7 +57,7 @@ def __init__( config=config, pg_collection=pg_collection, ) - + self.topk = config.moe_router_topk def token_dispatch(self, hidden_states, probs): """Gathers tokens from all TP*EP ranks using AllGather.""" @@ -76,14 +78,51 @@ def token_dispatch(self, hidden_states, probs): # Note that this allgather spans the communication domain of TP*EP. # [(S/TP)*B, H] -> [((S/TP)*B)*(TP*EP), H] = [S*B*EP, H] hidden_states = gather_from_sequence_parallel_region( - hidden_states, group=self.tp_ep_group, use_global_buffer=True + hidden_states, group=self.tp_ep_group ) - logging.info("Completed token dispatch AllGather.") - exit() - return hidden_states, probs + def test_permute_output(self, hidden_states, permute_output, mask): + # Verification of Grouped-by-Expert layout + E = self.local_map.size(1) + T = hidden_states.size(0) + mask = self.local_map + buffer_idx = 0 + for e_idx in range(E): + for t_idx in range(T): + if mask[t_idx, e_idx]: + assert torch.allclose(permute_output[buffer_idx], hidden_states[t_idx]) + buffer_idx += 1 + + #assert static_buffer[buffer_idx:].sum() == 0, "Stale data found in buffer tail" + + def test_permute_probs_output(self, local_probs, probs_workspace, mask): + """ + Verification of Grouped-by-Expert layout for probabilities. 
+ local_probs: [Tokens, Experts] + probs_workspace: [MAX_OUT, 1] (or [MAX_OUT]) + mask: [Tokens, Experts] boolean mask + """ + T = local_probs.size(0) + E = local_probs.size(1) + + buffer_idx = 0 + # Expert-major traversal (Outer loop: Experts, Inner loop: Tokens) + for e_idx in range(E): + for t_idx in range(T): + if mask[t_idx, e_idx]: + # Extract the expected probability from the source [Tokens, Experts] + expected_prob = local_probs[t_idx, e_idx] + # Using a slightly relaxed atol for BF16 if necessary + actual_prob = probs_workspace[buffer_idx] + assert torch.allclose( + actual_prob, + expected_prob + ), f"Prob mismatch at buffer index {buffer_idx} (Expert {e_idx}, Token {t_idx})" + + buffer_idx += 1 + def dispatch_postprocess(self, hidden_states, probs): """After gathering in token_dispatch, this method identifies tokens for local experts and permutes them for expert processing. @@ -98,21 +137,88 @@ def dispatch_postprocess(self, hidden_states, probs): self.local_probs = probs[ :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1 ].contiguous() + # logging.info(f"Routing map shapre: {self.routing_map.shape}, local_map shape: {self.local_map.shape}, hidden_states shape: {hidden_states.shape}, local_probs shape: {self.local_probs.shape}") + # logging.info(f"Routing map: {self.routing_map}") + # exit() + + # Change 1: Keep tokens_per_expert on GPU for CUDA graph compatibility. 
+ tokens_per_expert = self.local_map.sum(dim=0).long() #.cpu() + #hidden_states = torch.randn_like(hidden_states) # Dummy init for exit() + if False: + (permuted_local_hidden_states, permuted_local_probs, self.reversed_local_input_permutation_mapping) = permute( + hidden_states, + self.local_map, + probs=probs, # Change 2: permute probs as well + num_out_tokens=hidden_states.size(0) * self.topk, # Change 3: accounting for worst case + fused=self.config.moe_permute_fusion, + ) + self.test_permute_output(hidden_states, permuted_local_hidden_states, self.local_map) + self.test_permute_probs_output(self.local_probs, permuted_local_probs, self.local_map) + logging.info("TE: After permute verification for both tokens and probs") + else: + # shape of static_buffer is [hidden_states.shape(0) * min(topk, num_local_experts), hidden_states.shape(1)] + tokens_workspace = torch.zeros( + hidden_states.size(0) * min(self.topk, self.num_local_experts), + hidden_states.size(1), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + launch_moe_kernels(hidden_states, self.local_map, tokens_workspace, unpermute=False) + #self.test_permute_output(hidden_states, tokens_workspace, self.local_map) + #logging.info("Triton: After permute verification in token_dispatcher_inference for tokens") + + probs_workspace = torch.zeros( + self.local_probs.size(0) * min(self.topk, self.num_local_experts), + 1, + dtype=probs.dtype, + device=probs.device, + ) + launch_extract_probs(self.local_probs, self.local_map, probs_workspace) + #self.test_permute_probs_output(self.local_probs, probs_workspace, self.local_map) + #logging.info("Triton: After permute verification in token_dispatcher_inference for probs") - tokens_per_expert = self.local_map.sum(dim=0).long().cpu() + permuted_local_hidden_states = tokens_workspace + permuted_local_probs = probs_workspace.squeeze(-1) + # probs_workspace = torch.zeros( + # local_probs.size(0) * min(self.topk, self.num_local_experts), + # 1, + # 
dtype=probs.dtype, + # device=probs.device, + # ) - (permuted_local_hidden_states, _, self.reversed_local_input_permutation_mapping) = permute( - hidden_states, - self.local_map, - num_out_tokens=tokens_per_expert.sum(), - fused=self.config.moe_permute_fusion, - ) + # print(probs.shape) + # launch_moe_kernels(probs.unsqueeze(-1), self.local_map, probs_workspace, unpermute=False) + # self.test_permute_output(probs.unsqueeze(-1), probs_workspace, self.local_map) - self.local_probs = self.local_probs.T.contiguous().masked_select( - self.local_map.T.contiguous() - ) + + self.local_probs = permuted_local_probs self.routing_map = None return permuted_local_hidden_states, tokens_per_expert, self.local_probs - + def combine_preprocess(self, permuted_expert_outputs): + """ + Reverses token permutation to restore original ordering. + Handles Top-K summation into original hidden state positions. + """ + # 1. Pre-allocate/Fetch static output buffer + # In a real CUDA Graph, this should be a pre-allocated buffer attribute + # to ensure the data_ptr() remains constant. + unpermuted_hidden = torch.empty( + self.hidden_shape_before_permute, + dtype=permuted_expert_outputs.dtype, + device=permuted_expert_outputs.device + ).zero_() + + # 2. Launch the Un-permute kernel + # This kernel uses 'atomic_add' to gather expert outputs. + # It handles the Expert-grouped -> Token-major transition. + # We use the same self.local_map and self.local_probs we cached during dispatch. 
+ launch_moe_kernels( + unpermuted_hidden, # The [Tokens, Hidden] destination + self.local_map, # The boolean mask [Tokens, Experts] + permuted_expert_outputs, # The [MAX_OUT, Hidden] source + unpermute=True + ) + + return unpermuted_hidden From a786eda206e4aaa3f86b7ea2239a104ca46961ac Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Sun, 1 Feb 2026 22:17:37 -0800 Subject: [PATCH 12/92] replace permute/unpermute kernels with triton --- .../moe/token_dispatcher_inference.py | 65 +++++-------------- 1 file changed, 18 insertions(+), 47 deletions(-) diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 30a28a8749b..e5fc76f93a4 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -143,53 +143,24 @@ def dispatch_postprocess(self, hidden_states, probs): # Change 1: Keep tokens_per_expert on GPU for CUDA graph compatibility. 
tokens_per_expert = self.local_map.sum(dim=0).long() #.cpu() - #hidden_states = torch.randn_like(hidden_states) # Dummy init for exit() - if False: - (permuted_local_hidden_states, permuted_local_probs, self.reversed_local_input_permutation_mapping) = permute( - hidden_states, - self.local_map, - probs=probs, # Change 2: permute probs as well - num_out_tokens=hidden_states.size(0) * self.topk, # Change 3: accounting for worst case - fused=self.config.moe_permute_fusion, - ) - self.test_permute_output(hidden_states, permuted_local_hidden_states, self.local_map) - self.test_permute_probs_output(self.local_probs, permuted_local_probs, self.local_map) - logging.info("TE: After permute verification for both tokens and probs") - else: - # shape of static_buffer is [hidden_states.shape(0) * min(topk, num_local_experts), hidden_states.shape(1)] - tokens_workspace = torch.zeros( - hidden_states.size(0) * min(self.topk, self.num_local_experts), - hidden_states.size(1), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - launch_moe_kernels(hidden_states, self.local_map, tokens_workspace, unpermute=False) - #self.test_permute_output(hidden_states, tokens_workspace, self.local_map) - #logging.info("Triton: After permute verification in token_dispatcher_inference for tokens") - - probs_workspace = torch.zeros( - self.local_probs.size(0) * min(self.topk, self.num_local_experts), - 1, - dtype=probs.dtype, - device=probs.device, - ) - launch_extract_probs(self.local_probs, self.local_map, probs_workspace) - #self.test_permute_probs_output(self.local_probs, probs_workspace, self.local_map) - #logging.info("Triton: After permute verification in token_dispatcher_inference for probs") - - permuted_local_hidden_states = tokens_workspace - permuted_local_probs = probs_workspace.squeeze(-1) - # probs_workspace = torch.zeros( - # local_probs.size(0) * min(self.topk, self.num_local_experts), - # 1, - # dtype=probs.dtype, - # device=probs.device, - # ) - - # print(probs.shape) 
- # launch_moe_kernels(probs.unsqueeze(-1), self.local_map, probs_workspace, unpermute=False) - # self.test_permute_output(probs.unsqueeze(-1), probs_workspace, self.local_map) - + tokens_workspace = torch.zeros( + hidden_states.size(0) * min(self.topk, self.num_local_experts), + hidden_states.size(1), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + launch_moe_kernels(hidden_states, self.local_map, tokens_workspace, unpermute=False) + + probs_workspace = torch.zeros( + self.local_probs.size(0) * min(self.topk, self.num_local_experts), + 1, + dtype=probs.dtype, + device=probs.device, + ) + launch_extract_probs(self.local_probs, self.local_map, probs_workspace) + + permuted_local_hidden_states = tokens_workspace + permuted_local_probs = probs_workspace.squeeze(-1) self.local_probs = permuted_local_probs self.routing_map = None From f6ee32c275a993ba1874b82fb6a8bdc38af2d842 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 2 Feb 2026 00:37:27 -0800 Subject: [PATCH 13/92] minor optimizations --- .../moe/token_dispatcher_inference.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index e5fc76f93a4..821a6d2c1ac 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -137,13 +137,10 @@ def dispatch_postprocess(self, hidden_states, probs): self.local_probs = probs[ :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1 ].contiguous() - # logging.info(f"Routing map shapre: {self.routing_map.shape}, local_map shape: {self.local_map.shape}, hidden_states shape: {hidden_states.shape}, local_probs shape: {self.local_probs.shape}") - # logging.info(f"Routing map: {self.routing_map}") - # exit() # Change 1: Keep tokens_per_expert on GPU for CUDA graph compatibility. 
tokens_per_expert = self.local_map.sum(dim=0).long() #.cpu() - tokens_workspace = torch.zeros( + tokens_workspace = torch.empty( hidden_states.size(0) * min(self.topk, self.num_local_experts), hidden_states.size(1), dtype=hidden_states.dtype, @@ -151,7 +148,7 @@ def dispatch_postprocess(self, hidden_states, probs): ) launch_moe_kernels(hidden_states, self.local_map, tokens_workspace, unpermute=False) - probs_workspace = torch.zeros( + probs_workspace = torch.empty( self.local_probs.size(0) * min(self.topk, self.num_local_experts), 1, dtype=probs.dtype, @@ -171,17 +168,14 @@ def combine_preprocess(self, permuted_expert_outputs): Reverses token permutation to restore original ordering. Handles Top-K summation into original hidden state positions. """ - # 1. Pre-allocate/Fetch static output buffer - # In a real CUDA Graph, this should be a pre-allocated buffer attribute - # to ensure the data_ptr() remains constant. - unpermuted_hidden = torch.empty( + # 1. Pre-allocate static output buffer + unpermuted_hidden = torch.zeros( self.hidden_shape_before_permute, dtype=permuted_expert_outputs.dtype, device=permuted_expert_outputs.device - ).zero_() + ) # 2. Launch the Un-permute kernel - # This kernel uses 'atomic_add' to gather expert outputs. # It handles the Expert-grouped -> Token-major transition. # We use the same self.local_map and self.local_probs we cached during dispatch. 
launch_moe_kernels( From 3f7f39dc87bc4c0af44ffadea9e1228a2cf43639 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 2 Feb 2026 01:20:30 -0800 Subject: [PATCH 14/92] one round of optimizations --- .../core/transformer/moe/inference_kernels.py | 222 ++++++++++-------- .../transformer/moe/moe_layer_inference.py | 8 +- .../moe/token_dispatcher_inference.py | 59 +++-- 3 files changed, 162 insertions(+), 127 deletions(-) diff --git a/megatron/core/transformer/moe/inference_kernels.py b/megatron/core/transformer/moe/inference_kernels.py index 70a7d1d97e0..5ef2497b41c 100644 --- a/megatron/core/transformer/moe/inference_kernels.py +++ b/megatron/core/transformer/moe/inference_kernels.py @@ -4,23 +4,6 @@ import pytest import torch -@triton.jit -def moe_permute_kernel( - hidden_ptr, mask_ptr_T, dest_idx_ptr, output_ptr, - stride_h_t, num_tokens, hidden_size, - BLOCK_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - if not tl.load(mask_ptr_T + pid): return - - out_row_idx = tl.load(dest_idx_ptr + pid) - 1 - token_idx = pid % num_tokens - - offsets = tl.arange(0, BLOCK_SIZE) - mask_h = offsets < hidden_size - - row_data = tl.load(hidden_ptr + (token_idx * stride_h_t) + offsets, mask=mask_h) - tl.store(output_ptr + (out_row_idx * hidden_size) + offsets, row_data, mask=mask_h) @triton.jit def moe_unpermute_kernel( @@ -46,76 +29,100 @@ def moe_unpermute_kernel( # Atomic add in FP32 (Triton handles the casting/locking) tl.atomic_add(output_ptr + (token_idx * stride_out_t) + offsets, row_data_fp32, mask=mask_h) -def launch_moe_kernels(hidden_states, mask, static_buffer, unpermute=False): +def launch_unpermute_kernel(hidden_states, mask, static_buffer, + mask_T, dest_indices): + """ + Launch the unpermute kernel. 
+ + Args: + hidden_states: Input tensor for permute, output tensor for unpermute + mask: Boolean mask [T, E] + static_buffer: Output buffer for permute, input buffer for unpermute + unpermute: If True, run unpermute kernel + mask_T: Optional pre-computed transposed mask (avoids recomputation) + dest_indices: Optional pre-computed cumsum indices (avoids recomputation) + """ T, H = hidden_states.shape E = mask.size(1) - mask_T = mask.t().contiguous() - dest_indices = torch.cumsum(mask_T.view(-1).long(), dim=0).to(torch.int32) - + grid = (E * T,) BLOCK_SIZE = triton.next_power_of_2(H) - if not unpermute: - moe_permute_kernel[grid]( - hidden_states, mask_T, dest_indices, static_buffer, - hidden_states.stride(0), T, H, BLOCK_SIZE=BLOCK_SIZE - ) - else: - # For unpermute, hidden_states is the 'output' we write back into - # ensure that hidden states is zeroed out before accumulation - moe_unpermute_kernel[grid]( - static_buffer, mask_T, dest_indices, hidden_states, - hidden_states.stride(0), T, H, BLOCK_SIZE=BLOCK_SIZE - ) + + # For unpermute, hidden_states is the 'output' we write back into + # ensure that hidden states is zeroed out before accumulation + moe_unpermute_kernel[grid]( + static_buffer, mask_T, dest_indices, hidden_states, + hidden_states.stride(0), T, H, BLOCK_SIZE=BLOCK_SIZE + ) -import triton -import triton.language as tl @triton.jit -def moe_extract_probs_kernel( - probs_ptr_T, # [Experts, Tokens] (Transposed & Contiguous) - mask_ptr_T, # [Experts, Tokens] (Transposed & Contiguous) - dest_idx_ptr, # [Experts * Tokens] (Cumsum of mask_ptr_T) - out_probs_ptr, # [MAX_OUT] (Static 1D Buffer) - num_tokens, +def moe_fused_permute_extract_kernel( + hidden_ptr, probs_ptr_T, mask_ptr_T, dest_idx_ptr, + out_hidden_ptr, out_probs_ptr, + stride_h_t, num_tokens, hidden_size, + BLOCK_SIZE: tl.constexpr, ): - # pid follows Expert-major order: expert_idx * num_tokens + token_idx + """ + Fused kernel: permute hidden states AND extract probs in one pass. 
+ + This eliminates the need for separate kernel launches and avoids + recomputing mask_T and dest_indices twice. + """ pid = tl.program_id(0) - - # 1. Check if this expert-token pair is active - mask_val = tl.load(mask_ptr_T + pid) - if not mask_val: + + # Early exit if this expert-token pair is inactive + if not tl.load(mask_ptr_T + pid): return - # 2. Get the destination index in the 1D static buffer - # out_row_idx corresponds to the row index in the permuted hidden states - out_idx = tl.load(dest_idx_ptr + pid) - 1 + out_row_idx = tl.load(dest_idx_ptr + pid) - 1 + token_idx = pid % num_tokens + + # 1. Permute hidden states (vectorized load/store) + offsets = tl.arange(0, BLOCK_SIZE) + mask_h = offsets < hidden_size + row_data = tl.load(hidden_ptr + (token_idx * stride_h_t) + offsets, mask=mask_h) + tl.store(out_hidden_ptr + (out_row_idx * hidden_size) + offsets, row_data, mask=mask_h) - # 3. Load the probability and store it in the static output buffer + # 2. Extract probability (single scalar per program) prob = tl.load(probs_ptr_T + pid) - tl.store(out_probs_ptr + out_idx, prob) + tl.store(out_probs_ptr + out_row_idx, prob) -def launch_extract_probs(probs, mask, static_prob_buffer): - T, E = probs.shape - - # Match the permutation layout: Experts first - probs_T = probs.t().contiguous() + +def launch_fused_permute_and_probs(hidden_states, probs, mask, + hidden_workspace, probs_workspace): + """ + Fused launcher that: + 1. Computes mask_T, probs_T, and dest_indices ONCE + 2. 
Launches a single fused kernel for both permute + prob extraction + + Returns: + mask_T: Transposed mask (cached for potential reuse in unpermute) + dest_indices: Cumsum indices (cached for potential reuse in unpermute) + """ + T, H = hidden_states.shape + E = mask.size(1) + + # Compute shared intermediates once mask_T = mask.t().contiguous() - - # Reuse the same cumsum logic from your permutation step + probs_T = probs.t().contiguous() dest_indices = torch.cumsum(mask_T.view(-1).long(), dim=0).to(torch.int32) - + grid = (E * T,) - - moe_extract_probs_kernel[grid]( - probs_T, - mask_T, - dest_indices, - static_prob_buffer, - T + BLOCK_SIZE = triton.next_power_of_2(H) + + moe_fused_permute_extract_kernel[grid]( + hidden_states, probs_T, mask_T, dest_indices, + hidden_workspace, probs_workspace, + hidden_states.stride(0), T, H, BLOCK_SIZE=BLOCK_SIZE ) - return static_prob_buffer + + # Return cached values for potential reuse in unpermute + return mask_T, dest_indices + + @pytest.mark.parametrize("T, E, H", [ (1, 1, 128), # Minimal case @@ -127,44 +134,55 @@ def launch_extract_probs(probs, mask, static_prob_buffer): ]) @pytest.mark.parametrize("sparsity", [0.1, 0.5, 0.9]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) -def test_moe_cycle(T, E, H, sparsity, dtype): +def test_fused_permute_and_probs(T, E, H, sparsity, dtype): + """ + Test that the fused kernel produces identical results to separate kernels. + """ device = "cuda" MAX_OUT = T * E - + # Setup inputs hidden_states = torch.randn(T, H, device=device, dtype=dtype) * 1e-3 + probs = torch.rand(T, E, device=device, dtype=dtype) mask = torch.rand(T, E, device=device) > sparsity - - # We need a small prob scaling for unpermute to be realistic, - # but for pure permutation test, we'll stick to raw values. - static_buffer = torch.zeros((MAX_OUT, H), device=device, dtype=dtype) - - # 1. 
Test Permute - launch_moe_kernels(hidden_states, mask, static_buffer, unpermute=False) - - # Verification of Grouped-by-Expert layout - buffer_idx = 0 - for e_idx in range(E): - for t_idx in range(T): - if mask[t_idx, e_idx]: - assert torch.allclose(static_buffer[buffer_idx], hidden_states[t_idx], atol=1e-5) - buffer_idx += 1 - - assert static_buffer[buffer_idx:].sum() == 0, "Stale data found in buffer tail" - # 2. Test Un-permute (Gather) - # We'll create a new tensor to receive the gathered data - # (Using zeros because unpermute uses atomic_add to handle Top-K) - output_states = torch.zeros_like(hidden_states) - launch_moe_kernels(output_states, mask, static_buffer, unpermute=True) - - # Validation: If a token went to N experts, it should be N * original_value - expert_counts_per_token = mask.sum(dim=1) - expected_output = hidden_states * expert_counts_per_token.unsqueeze(-1) - - # Instead of allclose, use this for BF16 - if dtype == torch.bfloat16: - # rtol=1.6e-2 is the standard epsilon for bfloat16 - torch.testing.assert_close(output_states, expected_output, rtol=0.016, atol=0.005) - else: - torch.testing.assert_close(output_states, expected_output, rtol=1e-5, atol=1e-5) \ No newline at end of file + # Ensure at least one active token-expert pair + if not mask.any(): + mask[0, 0] = True + + # --- Reference: Separate kernel launches --- + ref_hidden_buffer = torch.zeros((MAX_OUT, H), device=device, dtype=dtype) + ref_probs_buffer = torch.zeros(MAX_OUT, device=device, dtype=dtype) + + launch_moe_kernels(hidden_states, mask, ref_hidden_buffer, unpermute=False) + launch_extract_probs(probs, mask, ref_probs_buffer) + + # --- Test: Fused kernel launch --- + fused_hidden_buffer = torch.zeros((MAX_OUT, H), device=device, dtype=dtype) + fused_probs_buffer = torch.zeros(MAX_OUT, device=device, dtype=dtype) + + mask_T, dest_indices = launch_fused_permute_and_probs( + hidden_states, probs, mask, + fused_hidden_buffer, fused_probs_buffer + ) + + # --- Verify outputs match 
--- + num_active = mask.sum().item() + + # Compare hidden states (only active portion) + torch.testing.assert_close( + fused_hidden_buffer[:num_active], + ref_hidden_buffer[:num_active], + rtol=1e-5, atol=1e-5 + ) + + # Compare probs (only active portion) + torch.testing.assert_close( + fused_probs_buffer[:num_active], + ref_probs_buffer[:num_active], + rtol=1e-5, atol=1e-5 + ) + + # Verify cached values are valid for reuse + assert mask_T.shape == (E, T), f"mask_T shape mismatch: {mask_T.shape}" + assert dest_indices.shape == (E * T,), f"dest_indices shape mismatch: {dest_indices.shape}" \ No newline at end of file diff --git a/megatron/core/transformer/moe/moe_layer_inference.py b/megatron/core/transformer/moe/moe_layer_inference.py index f8fb9115d34..3a140325dd1 100644 --- a/megatron/core/transformer/moe/moe_layer_inference.py +++ b/megatron/core/transformer/moe/moe_layer_inference.py @@ -109,12 +109,18 @@ def set_is_cuda_graphed_iteration(self, set_to): self.is_cuda_graphed_iteration = set_to def activate_inference_token_dispatcher(self): + # replace the token dispatcher with the inference-optimized version self.old_token_dispatcher = self.token_dispatcher - self.old_expert_overlap = self.shared_expert_overlap self.token_dispatcher = self.inference_token_dispatcher + + # disable shared expert overlap during inference as it is not + # supported in InferenceAllGatherTokenDispatcher + self.old_expert_overlap = self.shared_expert_overlap self.shared_expert_overlap = False def deactivate_inference_token_dispatcher(self): + # restore the original token dispatcher + # and shared expert overlap setting self.token_dispatcher = self.old_token_dispatcher self.shared_expert_overlap = self.old_expert_overlap diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 821a6d2c1ac..4f99447c783 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ 
b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -17,8 +17,10 @@ from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.tensor_parallel import gather_from_sequence_parallel_region -from megatron.core.transformer.moe.moe_utils import permute -from megatron.core.transformer.moe.inference_kernels import launch_moe_kernels, launch_extract_probs +from megatron.core.transformer.moe.inference_kernels import ( + launch_fused_permute_and_probs, + launch_unpermute_kernel, +) import logging @@ -126,47 +128,56 @@ def test_permute_probs_output(self, local_probs, probs_workspace, mask): def dispatch_postprocess(self, hidden_states, probs): """After gathering in token_dispatch, this method identifies tokens for local experts and permutes them for expert processing. + + Optimized to use a fused kernel that: + 1. Computes mask_T and dest_indices ONCE (instead of twice) + 2. Permutes hidden states AND extracts probs in a single kernel launch """ self.hidden_shape_before_permute = hidden_states.shape - # The routing map and probs that for local experts. + # The routing map and probs for local experts. self.local_map = self.routing_map[ :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1 ].contiguous() - # probs of global token assignment to local experts. + # Probs of global token assignment to local experts. self.local_probs = probs[ :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1 ].contiguous() - # Change 1: Keep tokens_per_expert on GPU for CUDA graph compatibility. - tokens_per_expert = self.local_map.sum(dim=0).long() #.cpu() + # Keep tokens_per_expert on GPU for CUDA graph compatibility. 
+ tokens_per_expert = self.local_map.sum(dim=0).long() + + # Pre-allocate workspaces + max_out = hidden_states.size(0) * min(self.topk, self.num_local_experts) tokens_workspace = torch.empty( - hidden_states.size(0) * min(self.topk, self.num_local_experts), - hidden_states.size(1), + max_out, hidden_states.size(1), dtype=hidden_states.dtype, device=hidden_states.device, ) - launch_moe_kernels(hidden_states, self.local_map, tokens_workspace, unpermute=False) - probs_workspace = torch.empty( - self.local_probs.size(0) * min(self.topk, self.num_local_experts), - 1, + max_out, dtype=probs.dtype, device=probs.device, ) - launch_extract_probs(self.local_probs, self.local_map, probs_workspace) - - permuted_local_hidden_states = tokens_workspace - permuted_local_probs = probs_workspace.squeeze(-1) - self.local_probs = permuted_local_probs + # Fused kernel launch: permute hidden states + extract probs in one pass + # Also returns cached mask_T and dest_indices for reuse in unpermute + self._cached_mask_T, self._cached_dest_indices = launch_fused_permute_and_probs( + hidden_states, self.local_probs, self.local_map, + tokens_workspace, probs_workspace + ) + + self.local_probs = probs_workspace self.routing_map = None - return permuted_local_hidden_states, tokens_per_expert, self.local_probs + return tokens_workspace, tokens_per_expert, probs_workspace def combine_preprocess(self, permuted_expert_outputs): """ Reverses token permutation to restore original ordering. Handles Top-K summation into original hidden state positions. + + Uses cached mask_T and dest_indices from dispatch_postprocess to avoid + recomputing them (saves 2 kernel launches). """ # 1. Pre-allocate static output buffer unpermuted_hidden = torch.zeros( @@ -175,14 +186,14 @@ def combine_preprocess(self, permuted_expert_outputs): device=permuted_expert_outputs.device ) - # 2. Launch the Un-permute kernel + # 2. 
Launch the Un-permute kernel with cached intermediates # It handles the Expert-grouped -> Token-major transition. - # We use the same self.local_map and self.local_probs we cached during dispatch. - launch_moe_kernels( - unpermuted_hidden, # The [Tokens, Hidden] destination - self.local_map, # The boolean mask [Tokens, Experts] + launch_unpermute_kernel( + unpermuted_hidden, # The [Tokens, Hidden] destination + self.local_map, # The boolean mask [Tokens, Experts] permuted_expert_outputs, # The [MAX_OUT, Hidden] source - unpermute=True + self._cached_mask_T, + self._cached_dest_indices, ) return unpermuted_hidden From 10da287dbe44d1a6e655e2f8d423a644759fb648 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 2 Feb 2026 01:42:52 -0800 Subject: [PATCH 15/92] reduce kernel calls --- .../core/transformer/moe/inference_kernels.py | 88 +++++++++++-------- .../moe/token_dispatcher_inference.py | 44 +++++----- 2 files changed, 73 insertions(+), 59 deletions(-) diff --git a/megatron/core/transformer/moe/inference_kernels.py b/megatron/core/transformer/moe/inference_kernels.py index 5ef2497b41c..352f1442b60 100644 --- a/megatron/core/transformer/moe/inference_kernels.py +++ b/megatron/core/transformer/moe/inference_kernels.py @@ -29,26 +29,22 @@ def moe_unpermute_kernel( # Atomic add in FP32 (Triton handles the casting/locking) tl.atomic_add(output_ptr + (token_idx * stride_out_t) + offsets, row_data_fp32, mask=mask_h) -def launch_unpermute_kernel(hidden_states, mask, static_buffer, - mask_T, dest_indices): +def launch_unpermute_kernel(hidden_states, static_buffer, mask_T, dest_indices): """ Launch the unpermute kernel. 
Args: - hidden_states: Input tensor for permute, output tensor for unpermute - mask: Boolean mask [T, E] - static_buffer: Output buffer for permute, input buffer for unpermute - unpermute: If True, run unpermute kernel - mask_T: Optional pre-computed transposed mask (avoids recomputation) - dest_indices: Optional pre-computed cumsum indices (avoids recomputation) + hidden_states: [T, H] output tensor to accumulate into (should be zeroed) + static_buffer: [max_out, H] permuted expert outputs + mask_T: [E, T] pre-transposed mask (reused from dispatch) + dest_indices: [E*T] cumsum indices (reused from dispatch) """ T, H = hidden_states.shape - E = mask.size(1) + E = mask_T.size(0) # mask_T is [E, T] grid = (E * T,) BLOCK_SIZE = triton.next_power_of_2(H) - # For unpermute, hidden_states is the 'output' we write back into # ensure that hidden states is zeroed out before accumulation moe_unpermute_kernel[grid]( @@ -60,16 +56,17 @@ def launch_unpermute_kernel(hidden_states, mask, static_buffer, @triton.jit def moe_fused_permute_extract_kernel( - hidden_ptr, probs_ptr_T, mask_ptr_T, dest_idx_ptr, + hidden_ptr, probs_ptr, mask_ptr_T, dest_idx_ptr, out_hidden_ptr, out_probs_ptr, - stride_h_t, num_tokens, hidden_size, + stride_h_t, stride_probs_t, stride_probs_e, + num_tokens, hidden_size, BLOCK_SIZE: tl.constexpr, ): """ Fused kernel: permute hidden states AND extract probs in one pass. - This eliminates the need for separate kernel launches and avoids - recomputing mask_T and dest_indices twice. + This kernel avoids transposing probs by using stride-based indexing. + The mask is expected to be pre-transposed [E, T] for efficient expert-major access. """ pid = tl.program_id(0) @@ -79,6 +76,7 @@ def moe_fused_permute_extract_kernel( out_row_idx = tl.load(dest_idx_ptr + pid) - 1 token_idx = pid % num_tokens + expert_idx = pid // num_tokens # 1. 
Permute hidden states (vectorized load/store) offsets = tl.arange(0, BLOCK_SIZE) @@ -86,41 +84,48 @@ def moe_fused_permute_extract_kernel( row_data = tl.load(hidden_ptr + (token_idx * stride_h_t) + offsets, mask=mask_h) tl.store(out_hidden_ptr + (out_row_idx * hidden_size) + offsets, row_data, mask=mask_h) - # 2. Extract probability (single scalar per program) - prob = tl.load(probs_ptr_T + pid) + # 2. Extract probability using stride-based indexing (avoids probs transpose) + # probs is [T, E], so index as probs[token_idx, expert_idx] + prob = tl.load(probs_ptr + token_idx * stride_probs_t + expert_idx * stride_probs_e) tl.store(out_probs_ptr + out_row_idx, prob) -def launch_fused_permute_and_probs(hidden_states, probs, mask, +def launch_fused_permute_and_probs(hidden_states, probs, mask_T, hidden_workspace, probs_workspace): """ Fused launcher that: - 1. Computes mask_T, probs_T, and dest_indices ONCE - 2. Launches a single fused kernel for both permute + prob extraction + 1. Accepts pre-transposed mask_T [E, T] (caller fuses slice+transpose) + 2. Uses stride-based probs access (no transpose needed) + 3. 
Launches a single fused kernel for both permute + prob extraction + + Args: + hidden_states: [T, H] input hidden states + probs: [T, E] routing probabilities (NOT transposed) + mask_T: [E, T] pre-transposed routing mask (caller provides this) + hidden_workspace: [max_out, H] output buffer for permuted hidden states + probs_workspace: [max_out] output buffer for extracted probs Returns: - mask_T: Transposed mask (cached for potential reuse in unpermute) dest_indices: Cumsum indices (cached for potential reuse in unpermute) """ T, H = hidden_states.shape - E = mask.size(1) + E = mask_T.size(0) # mask_T is [E, T] - # Compute shared intermediates once - mask_T = mask.t().contiguous() - probs_T = probs.t().contiguous() + # Only compute dest_indices (mask_T is provided by caller) dest_indices = torch.cumsum(mask_T.view(-1).long(), dim=0).to(torch.int32) grid = (E * T,) BLOCK_SIZE = triton.next_power_of_2(H) moe_fused_permute_extract_kernel[grid]( - hidden_states, probs_T, mask_T, dest_indices, + hidden_states, probs, mask_T, dest_indices, hidden_workspace, probs_workspace, - hidden_states.stride(0), T, H, BLOCK_SIZE=BLOCK_SIZE + hidden_states.stride(0), probs.stride(0), probs.stride(1), + T, H, BLOCK_SIZE=BLOCK_SIZE ) - # Return cached values for potential reuse in unpermute - return mask_T, dest_indices + # Return cached dest_indices for potential reuse in unpermute + return dest_indices @@ -136,7 +141,7 @@ def launch_fused_permute_and_probs(hidden_states, probs, mask, @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) def test_fused_permute_and_probs(T, E, H, sparsity, dtype): """ - Test that the fused kernel produces identical results to separate kernels. + Test that the fused kernel produces identical results to reference implementation. 
""" device = "cuda" MAX_OUT = T * E @@ -150,25 +155,33 @@ def test_fused_permute_and_probs(T, E, H, sparsity, dtype): if not mask.any(): mask[0, 0] = True - # --- Reference: Separate kernel launches --- + # Pre-transpose mask (simulating the fused slice+transpose in dispatcher) + mask_T = mask.t().contiguous() # [E, T] + + # --- Reference: Python-based verification --- + num_active = int(mask.sum().item()) ref_hidden_buffer = torch.zeros((MAX_OUT, H), device=device, dtype=dtype) ref_probs_buffer = torch.zeros(MAX_OUT, device=device, dtype=dtype) - launch_moe_kernels(hidden_states, mask, ref_hidden_buffer, unpermute=False) - launch_extract_probs(probs, mask, ref_probs_buffer) + # Expert-major ordering reference + buffer_idx = 0 + for e_idx in range(E): + for t_idx in range(T): + if mask[t_idx, e_idx]: + ref_hidden_buffer[buffer_idx] = hidden_states[t_idx] + ref_probs_buffer[buffer_idx] = probs[t_idx, e_idx] + buffer_idx += 1 # --- Test: Fused kernel launch --- fused_hidden_buffer = torch.zeros((MAX_OUT, H), device=device, dtype=dtype) fused_probs_buffer = torch.zeros(MAX_OUT, device=device, dtype=dtype) - mask_T, dest_indices = launch_fused_permute_and_probs( - hidden_states, probs, mask, + dest_indices = launch_fused_permute_and_probs( + hidden_states, probs, mask_T, fused_hidden_buffer, fused_probs_buffer ) # --- Verify outputs match --- - num_active = mask.sum().item() - # Compare hidden states (only active portion) torch.testing.assert_close( fused_hidden_buffer[:num_active], @@ -183,6 +196,5 @@ def test_fused_permute_and_probs(T, E, H, sparsity, dtype): rtol=1e-5, atol=1e-5 ) - # Verify cached values are valid for reuse - assert mask_T.shape == (E, T), f"mask_T shape mismatch: {mask_T.shape}" + # Verify dest_indices shape assert dest_indices.shape == (E * T,), f"dest_indices shape mismatch: {dest_indices.shape}" \ No newline at end of file diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py 
b/megatron/core/transformer/moe/token_dispatcher_inference.py index 4f99447c783..bd040870fc8 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -129,23 +129,26 @@ def dispatch_postprocess(self, hidden_states, probs): """After gathering in token_dispatch, this method identifies tokens for local experts and permutes them for expert processing. - Optimized to use a fused kernel that: - 1. Computes mask_T and dest_indices ONCE (instead of twice) - 2. Permutes hidden states AND extracts probs in a single kernel launch + Optimized to: + 1. Fuse slice + transpose for mask (single kernel instead of two) + 2. Use stride-based probs access in kernel (avoids probs transpose entirely) + 3. Permute hidden states AND extract probs in a single kernel launch """ self.hidden_shape_before_permute = hidden_states.shape - # The routing map and probs for local experts. - self.local_map = self.routing_map[ + # Fuse slice + transpose for mask: [T, num_experts] -> [num_local_experts, T] + # This produces mask_T directly, avoiding a separate transpose kernel + self._cached_mask_T = self.routing_map[ :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1 - ].contiguous() - # Probs of global token assignment to local experts. - self.local_probs = probs[ + ].t().contiguous() # [E, T] layout + + # Probs: just slice, no transpose needed (kernel uses stride-based access) + local_probs = probs[ :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1 - ].contiguous() + ].contiguous() # [T, E] layout - # Keep tokens_per_expert on GPU for CUDA graph compatibility. 
- tokens_per_expert = self.local_map.sum(dim=0).long() + # tokens_per_expert from transposed mask: sum over tokens (dim=1) for each expert + tokens_per_expert = self._cached_mask_T.sum(dim=1) # Pre-allocate workspaces max_out = hidden_states.size(0) * min(self.topk, self.num_local_experts) @@ -161,14 +164,14 @@ def dispatch_postprocess(self, hidden_states, probs): ) # Fused kernel launch: permute hidden states + extract probs in one pass - # Also returns cached mask_T and dest_indices for reuse in unpermute - self._cached_mask_T, self._cached_dest_indices = launch_fused_permute_and_probs( - hidden_states, self.local_probs, self.local_map, + # Pass mask_T directly (already transposed), probs as [T, E] (kernel uses strides) + self._cached_dest_indices = launch_fused_permute_and_probs( + hidden_states, local_probs, self._cached_mask_T, tokens_workspace, probs_workspace ) - self.local_probs = probs_workspace self.routing_map = None + self.local_probs = probs_workspace return tokens_workspace, tokens_per_expert, probs_workspace def combine_preprocess(self, permuted_expert_outputs): @@ -179,7 +182,7 @@ def combine_preprocess(self, permuted_expert_outputs): Uses cached mask_T and dest_indices from dispatch_postprocess to avoid recomputing them (saves 2 kernel launches). """ - # 1. Pre-allocate static output buffer + # 1. Pre-allocate static output buffer (zeros for atomic accumulation) unpermuted_hidden = torch.zeros( self.hidden_shape_before_permute, dtype=permuted_expert_outputs.dtype, @@ -189,11 +192,10 @@ def combine_preprocess(self, permuted_expert_outputs): # 2. Launch the Un-permute kernel with cached intermediates # It handles the Expert-grouped -> Token-major transition. 
launch_unpermute_kernel( - unpermuted_hidden, # The [Tokens, Hidden] destination - self.local_map, # The boolean mask [Tokens, Experts] - permuted_expert_outputs, # The [MAX_OUT, Hidden] source - self._cached_mask_T, - self._cached_dest_indices, + unpermuted_hidden, # The [T, H] destination + permuted_expert_outputs, # The [max_out, H] source + self._cached_mask_T, # Cached [E, T] mask + self._cached_dest_indices # Cached cumsum indices ) return unpermuted_hidden From 1983688892b9bb44d3c5e0ecb8bb957792789050 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 2 Feb 2026 03:49:30 -0800 Subject: [PATCH 16/92] symmetric memory AG for hidden states --- megatron/core/parallel_state.py | 42 ++++-- .../core/tensor_parallel/inference_layers.py | 6 +- .../moe/token_dispatcher_inference.py | 141 +++++++++++++++--- 3 files changed, 154 insertions(+), 35 deletions(-) diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index c5a73600ee1..39c8dc7942f 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -139,8 +139,9 @@ # Memory buffers to avoid dynamic memory allocation _GLOBAL_MEMORY_BUFFER = None -# Global symmetric memory buffer for inference -_GLOBAL_SYMMETRIC_MEMORY_BUFFER = None +# Global symmetric memory buffers for inference +_GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = None +_GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = None # List of all process groups # Used for updating the timeout for all process groups @@ -2000,14 +2001,20 @@ def _set_global_memory_buffer(): def _set_global_symmetric_memory_buffer(): """Initialize global buffer.""" - global _GLOBAL_SYMMETRIC_MEMORY_BUFFER - assert _GLOBAL_SYMMETRIC_MEMORY_BUFFER is None, "global memory buffer is already initialized" + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP, _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP + assert _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP is None, "global symmetric memory buffer for TP is already initialized" + assert _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP is None, 
"global symmetric memory buffer for EP is already initialized" - _GLOBAL_SYMMETRIC_MEMORY_BUFFER = GlobalSymmetricMemoryBuffer( + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = GlobalSymmetricMemoryBuffer( size_in_mb=256, # todo: set from an argument? process_group=get_tensor_model_parallel_group(), ) + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = GlobalSymmetricMemoryBuffer( + size_in_mb=256, # todo: set from an argument? + process_group=get_expert_model_parallel_group(), + ) + def get_global_memory_buffer(): """Return the global GlobalMemoryBuffer object""" @@ -2015,12 +2022,19 @@ def get_global_memory_buffer(): return _GLOBAL_MEMORY_BUFFER -def get_global_symmetric_memory_buffer(): +def get_global_symmetric_memory_buffer_tp(): + """Return the global GlobalSymmetricMemoryBuffer object""" + assert ( + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP is not None + ), "global symmetric memory buffer is not initialized" + return _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP + +def get_global_symmetric_memory_buffer_ep(): """Return the global GlobalSymmetricMemoryBuffer object""" assert ( - _GLOBAL_SYMMETRIC_MEMORY_BUFFER is not None + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP is not None ), "global symmetric memory buffer is not initialized" - return _GLOBAL_SYMMETRIC_MEMORY_BUFFER + return _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP def destroy_global_memory_buffer(): @@ -2031,8 +2045,9 @@ def destroy_global_memory_buffer(): def destroy_global_symmetric_memory_buffer(): """Sets the global symmetric memory buffer to None""" - global _GLOBAL_SYMMETRIC_MEMORY_BUFFER - _GLOBAL_SYMMETRIC_MEMORY_BUFFER = None + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP, _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = None + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = None def get_all_ranks(): @@ -2110,8 +2125,11 @@ def destroy_model_parallel(): global _GLOBAL_MEMORY_BUFFER _GLOBAL_MEMORY_BUFFER = None - global _GLOBAL_SYMMETRIC_MEMORY_BUFFER - _GLOBAL_SYMMETRIC_MEMORY_BUFFER = None + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP + 
_GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = None + + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = None global _DATA_PARALLEL_GROUP_GLOO if ( diff --git a/megatron/core/tensor_parallel/inference_layers.py b/megatron/core/tensor_parallel/inference_layers.py index 9c1adbc6717..72b0cb8e93b 100644 --- a/megatron/core/tensor_parallel/inference_layers.py +++ b/megatron/core/tensor_parallel/inference_layers.py @@ -14,7 +14,7 @@ multimem_reduce_scatter, ) from megatron.core.model_parallel_config import ModelParallelConfig -from megatron.core.parallel_state import get_global_symmetric_memory_buffer +from megatron.core.parallel_state import get_global_symmetric_memory_buffer_tp from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import get_tensor_model_parallel_group_if_none @@ -103,7 +103,7 @@ def _maybe_allocate_symmetric_buffer(self, x: torch.Tensor): """ symm_mem_buffer_dims = list(x.size()) symm_mem_buffer_dims[0] *= self.tp_size - symm_mem_buffer = get_global_symmetric_memory_buffer().maybe_get_tensor( + symm_mem_buffer = get_global_symmetric_memory_buffer_tp().maybe_get_tensor( symm_mem_buffer_dims, dtype=x.dtype ) return symm_mem_buffer @@ -223,7 +223,7 @@ def _matmul_reduce_scatter(self, x, residual=None): # 3. 
attempt to ask for symmetric memory symm_mem_buffer_dims = list(x.size()) symm_mem_buffer_dims[-1] = self.weight.size(0) - symm_mem_buffer = get_global_symmetric_memory_buffer().maybe_get_tensor( + symm_mem_buffer = get_global_symmetric_memory_buffer_tp().maybe_get_tensor( symm_mem_buffer_dims, dtype=x.dtype ) has_enough_symmetric_memory = symm_mem_buffer["handle"] is not None diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index bd040870fc8..fd36d74e0e9 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -5,6 +5,9 @@ This implementation keeps tokens_per_expert GPU-resident to enable use of torch._grouped_mm without host synchronization. + +Supports latency-optimized NVLS collectives (multimem all-gather/reduce-scatter) +on Hopper+ GPUs with BF16, with automatic fallback to NCCL via superclass methods. """ import torch @@ -16,11 +19,16 @@ ) from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.tensor_parallel import gather_from_sequence_parallel_region from megatron.core.transformer.moe.inference_kernels import ( launch_fused_permute_and_probs, launch_unpermute_kernel, -) +) +from megatron.core.tensor_parallel import gather_from_sequence_parallel_region +from megatron.core.parallel_state import get_global_symmetric_memory_buffer_ep +from megatron.core.inference.communication.torch_symm_triton import ( + multimem_all_gather, + multimem_reduce_scatter, +) import logging @@ -61,24 +69,74 @@ def __init__( ) self.topk = config.moe_router_topk + # Cache for NVLS eligibility + self._nvls_eligible = None + + def _check_nvls_eligibility(self, x: torch.Tensor) -> bool: + """ + Check if we can use NVLS (latency-optimized) collectives. + Requirements: BF16 dtype, Hopper+ GPU (SM >= 9). 
+ """ + is_bf16 = x.dtype == torch.bfloat16 + is_hopper_or_newer = torch.cuda.get_device_properties(x.device).major >= 9 + return is_bf16 and is_hopper_or_newer + + def _maybe_allocate_ag_buffer(self, x: torch.Tensor) -> dict: + """ + Allocate symmetric memory buffer for all-gather output. + Output shape: [local_tokens * ep_size, hidden_dim] + """ + ag_output_dims = list(x.size()) + ag_output_dims[0] *= self.ep_size + symm_mem_buffer = get_global_symmetric_memory_buffer_ep().maybe_get_tensor( + ag_output_dims, dtype=x.dtype + ) + return symm_mem_buffer + + def _maybe_allocate_rs_buffer(self, x: torch.Tensor) -> dict: + """ + Allocate symmetric memory buffer for reduce-scatter input. + Input shape matches x (the unpermuted hidden states). + """ + symm_mem_buffer = get_global_symmetric_memory_buffer_ep().maybe_get_tensor( + list(x.size()), dtype=x.dtype + ) + return symm_mem_buffer + def token_dispatch(self, hidden_states, probs): - """Gathers tokens from all TP*EP ranks using AllGather.""" - - # Permute the tokens across the expert parallel devices. - if self.tp_size > 1 or self.ep_size > 1: - ## local_indices calculation - with torch.no_grad(): - # [num_local_tokens, num_experts] -> [num_global_tokens, num_experts], where: - # num_local_tokens=(S/TP)*B, num_global_tokens=S*B*EP - self.routing_map = gather_from_sequence_parallel_region( - self.routing_map, group=self.tp_ep_group - ) - - ## local_probs calculation - # max_prob: [S/TP*B, num_experts] -> global_probs: [S*B*EP, num_experts] - probs = gather_from_sequence_parallel_region(probs, group=self.tp_ep_group) - # Note that this allgather spans the communication domain of TP*EP. - # [(S/TP)*B, H] -> [((S/TP)*B)*(TP*EP), H] = [S*B*EP, H] + """ + Gathers tokens from all EP ranks using AllGather. + + Uses NCCL for routing_map and probs. + Uses latency-optimized NVLS multimem_all_gather for hidden_states on Hopper+ GPUs + with BF16 when symmetric memory is available, falls back to NCCL otherwise. 
+ """ + if self.ep_size == 1: + return hidden_states, probs + + # All-gather routing_map and probs using NCCL + self.routing_map = gather_from_sequence_parallel_region( + self.routing_map, group=self.tp_ep_group + ) + + # [local_tokens, num_experts] -> [global_tokens, num_experts] + probs = gather_from_sequence_parallel_region(probs, group=self.tp_ep_group) + + # All-gather hidden_states: try NVLS first, fallback to NCCL + nvls_eligible = self._check_nvls_eligibility(hidden_states) + ag_buffer = None + + if nvls_eligible: + ag_buffer = self._maybe_allocate_ag_buffer(hidden_states) + + can_use_nvls = nvls_eligible and ag_buffer["handle"] is not None + + if can_use_nvls: + # Use latency-optimized NVLS all-gather for hidden_states + multimem_all_gather(ag_buffer["tensor"], hidden_states, ag_buffer["handle"]) + hidden_states = ag_buffer["tensor"] + else: + # Fallback to NCCL for hidden_states hidden_states = gather_from_sequence_parallel_region( hidden_states, group=self.tp_ep_group ) @@ -182,7 +240,7 @@ def combine_preprocess(self, permuted_expert_outputs): Uses cached mask_T and dest_indices from dispatch_postprocess to avoid recomputing them (saves 2 kernel launches). """ - # 1. Pre-allocate static output buffer (zeros for atomic accumulation) + # 1. Pre-allocate output buffer w/ zeros. unpermuted_hidden = torch.zeros( self.hidden_shape_before_permute, dtype=permuted_expert_outputs.dtype, @@ -200,3 +258,46 @@ def combine_preprocess(self, permuted_expert_outputs): return unpermuted_hidden + def token_combine(self, hidden_states): + """ + Combines expert outputs using Reduce-Scatter. + + Uses latency-optimized NVLS multimem_reduce_scatter on Hopper+ GPUs with BF16 + when symmetric memory is available. Falls back to NCCL via superclass otherwise. 
+ + Args: + hidden_states: [global_tokens, hidden_dim] tensor to reduce-scatter + + Returns: + [local_tokens, hidden_dim] tensor after reduce-scatter + """ + if self.ep_size == 1: + return hidden_states + + # Check NVLS eligibility and try to allocate symmetric memory + nvls_eligible = self._check_nvls_eligibility(hidden_states) + rs_buffer = None + + if nvls_eligible: + rs_buffer = self._maybe_allocate_rs_buffer(hidden_states) + + can_use_nvls = nvls_eligible and rs_buffer["handle"] is not None + + if can_use_nvls: + # Copy input to symmetric memory for reduce-scatter + rs_buffer["tensor"].copy_(hidden_states) + + # Allocate output tensor + output_shape = list(hidden_states.size()) + output_shape[0] = hidden_states.size(0) // self.ep_size + output = torch.empty( + output_shape, dtype=hidden_states.dtype, device=hidden_states.device + ) + + # Use latency-optimized NVLS reduce-scatter + multimem_reduce_scatter(output, rs_buffer["tensor"], rs_buffer["handle"]) + return output + else: + # Fallback to NCCL via superclass + return super().token_combine(hidden_states) + From 02f315a0aaa60c32a53da577eb22927e91170fe8 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 2 Feb 2026 04:44:10 -0800 Subject: [PATCH 17/92] nvls all gathers for all three tensors. 
nvls rs on hidden state --- .../torch_symm_triton/collectives.py | 15 +- .../moe/token_dispatcher_inference.py | 157 ++++++++++++++---- 2 files changed, 141 insertions(+), 31 deletions(-) diff --git a/megatron/core/inference/communication/torch_symm_triton/collectives.py b/megatron/core/inference/communication/torch_symm_triton/collectives.py index 4bc4dbde42b..24a463350fa 100644 --- a/megatron/core/inference/communication/torch_symm_triton/collectives.py +++ b/megatron/core/inference/communication/torch_symm_triton/collectives.py @@ -32,6 +32,7 @@ def _multimem_all_gather_kernel( multicast_ptr, signal_pad_ptrs, numel, + byte_offset, BLOCK_SIZE: tl.constexpr, NUMEL_PER_THREAD: tl.constexpr, RANK: tl.constexpr, @@ -39,6 +40,9 @@ def _multimem_all_gather_kernel( ): """ Triton kernel to perform multicast all-gather over nvlink using multimem instructions. + + Args: + byte_offset: Byte offset into the multicast buffer where this tensor starts. """ # an all-gather is simply a multicast store operation # we only need a barrier at the end to ensure visibility of writes @@ -56,10 +60,12 @@ def _multimem_all_gather_kernel( mask = offsets < numel_per_rank # Each pointer points to a 128-bit bit pack + # byte_offset // 8 -> converts byte offset to uint64 offset # RANK * numel_per_rank -> brings us to the start of our rank's segment # offsets -> brings us to the right offset within our rank's segment + # * 2 -> each 128-bit pack is 2 uint64s multicast_ptrs = ( - multicast_ptr.to(tl.pointer_type(tl.uint64)) + (RANK * numel_per_rank + offsets) * 2 + multicast_ptr.to(tl.pointer_type(tl.uint64)) + byte_offset // 8 + (RANK * numel_per_rank + offsets) * 2 ) local_ptrs = local_ptr.to(tl.pointer_type(tl.uint64)) + offsets * 2 (x, y, z, w) = ld_128(local_ptrs, mask=mask, multicast_op=False) @@ -82,6 +88,7 @@ def multimem_all_gather( output_tensor: torch.Tensor, input_tensor: torch.Tensor, symm_mem_hdl: _SymmetricMemory, + byte_offset: int = 0, **kwargs, ) -> torch.Tensor: """ @@ -92,6 
+99,7 @@ def multimem_all_gather( output_tensor: torch.Tensor - output tensor to be all-gathered into input_tensor: torch.Tensor - input tensor to be all-gathered from symm_mem_hdl: _SymmetricMemory - handle to the symmetric memory buffer for output_tensor + byte_offset: int - byte offset into the multicast buffer where output_tensor starts Returns: torch.Tensor - all-gathered tensor, which is output_tensor """ @@ -102,8 +110,8 @@ def multimem_all_gather( "num_warps": kwargs.get("num_warps", 32), "BLOCK_SIZE": kwargs.get("BLOCK_SIZE", 1024), } - assert input_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." - assert output_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." + # assert input_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." + # assert output_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." numel_per_thread = 128 // (input_tensor.element_size() * 8) assert ( @@ -118,6 +126,7 @@ def multimem_all_gather( symm_mem_hdl.multicast_ptr, symm_mem_hdl.signal_pad_ptrs_dev, numel=output_tensor.numel(), + byte_offset=byte_offset, BLOCK_SIZE=config["BLOCK_SIZE"], NUMEL_PER_THREAD=numel_per_thread, RANK=symm_mem_hdl.rank, diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index fd36d74e0e9..01bb9bdc0ec 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -81,17 +81,90 @@ def _check_nvls_eligibility(self, x: torch.Tensor) -> bool: is_hopper_or_newer = torch.cuda.get_device_properties(x.device).major >= 9 return is_bf16 and is_hopper_or_newer - def _maybe_allocate_ag_buffer(self, x: torch.Tensor) -> dict: + def _maybe_allocate_ag_buffers( + self, + routing_map: torch.Tensor, + probs: torch.Tensor, + hidden_states: torch.Tensor, + ) -> dict: """ - Allocate symmetric memory buffer for all-gather output. 
- Output shape: [local_tokens * ep_size, hidden_dim] + Allocate a single symmetric memory buffer for all-gather outputs of + routing_map, probs and hidden_states. Returns sliced views for each. + + All tensors are gathered from ep_size ranks, so output shapes are: + - routing_map: [local_tokens * ep_size, num_experts] + - probs: [local_tokens * ep_size, num_experts] + - hidden_states: [local_tokens * ep_size, hidden_dim] + + Returns dict with: + - "handle": symmetric memory handle (or None if unavailable) + - "routing_map": view for routing_map output + - "routing_map_offset": byte offset of routing_map in the symmetric buffer + - "probs": view for probs output + - "probs_offset": byte offset of probs in the symmetric buffer + - "hidden_states": view for hidden_states output + - "hidden_states_offset": byte offset of hidden_states in the symmetric buffer """ - ag_output_dims = list(x.size()) - ag_output_dims[0] *= self.ep_size - symm_mem_buffer = get_global_symmetric_memory_buffer_ep().maybe_get_tensor( - ag_output_dims, dtype=x.dtype - ) - return symm_mem_buffer + symm_buffer_mgr = get_global_symmetric_memory_buffer_ep() + if symm_buffer_mgr.symm_mem_hdl is None: + return { + "handle": None, + "routing_map": None, "routing_map_offset": 0, + "probs": None, "probs_offset": 0, + "hidden_states": None, "hidden_states_offset": 0, + } + + # Calculate output shapes after all-gather + local_tokens = probs.size(0) + global_tokens = local_tokens * self.ep_size + num_experts = probs.size(1) + hidden_dim = hidden_states.size(1) + + # Calculate bytes needed for each tensor (with 16-byte alignment) + def aligned_bytes(numel, dtype): + elem_size = torch.tensor([], dtype=dtype).element_size() + raw_bytes = numel * elem_size + # Align to 16 bytes for 128-bit access + return ((raw_bytes + 15) // 16) * 16 + + routing_map_bytes = aligned_bytes(global_tokens * num_experts, routing_map.dtype) + probs_bytes = aligned_bytes(global_tokens * num_experts, probs.dtype) + hidden_states_bytes 
= aligned_bytes(global_tokens * hidden_dim, hidden_states.dtype) + total_bytes = routing_map_bytes + probs_bytes + hidden_states_bytes + + # Check if buffer has enough space + if total_bytes > symm_buffer_mgr.symm_buffer.numel(): + return { + "handle": None, + "routing_map": None, "routing_map_offset": 0, + "probs": None, "probs_offset": 0, + "hidden_states": None, "hidden_states_offset": 0, + } + + # Slice the raw buffer and create views, tracking byte offsets + # [routing_map_bytes | probs_bytes | hidden_states_bytes] + # offset=0 offset=rm offset=rm+probs + + raw_buffer = symm_buffer_mgr.symm_buffer + + routing_map_offset = 0 + routing_map_buffer = raw_buffer[routing_map_offset : routing_map_offset + routing_map_bytes] + + probs_offset = routing_map_bytes + probs_buffer = raw_buffer[probs_offset : probs_offset + probs_bytes] + + hidden_states_offset = probs_offset + probs_bytes + hidden_states_buffer = raw_buffer[hidden_states_offset : hidden_states_offset + hidden_states_bytes] + + return { + "handle": symm_buffer_mgr.symm_mem_hdl, + "routing_map": routing_map_buffer, + "routing_map_offset": routing_map_offset, + "probs": probs_buffer, + "probs_offset": probs_offset, + "hidden_states": hidden_states_buffer, + "hidden_states_offset": hidden_states_offset, + } def _maybe_allocate_rs_buffer(self, x: torch.Tensor) -> dict: """ @@ -107,36 +180,64 @@ def token_dispatch(self, hidden_states, probs): """ Gathers tokens from all EP ranks using AllGather. - Uses NCCL for routing_map and probs. - Uses latency-optimized NVLS multimem_all_gather for hidden_states on Hopper+ GPUs - with BF16 when symmetric memory is available, falls back to NCCL otherwise. + Uses latency-optimized NVLS multimem_all_gather for routing_map, probs and hidden_states + on Hopper+ GPUs with BF16. Falls back to NCCL via superclass otherwise. 
""" if self.ep_size == 1: return hidden_states, probs - # All-gather routing_map and probs using NCCL - self.routing_map = gather_from_sequence_parallel_region( - self.routing_map, group=self.tp_ep_group - ) - - # [local_tokens, num_experts] -> [global_tokens, num_experts] - probs = gather_from_sequence_parallel_region(probs, group=self.tp_ep_group) - - # All-gather hidden_states: try NVLS first, fallback to NCCL + # Check NVLS eligibility nvls_eligible = self._check_nvls_eligibility(hidden_states) - ag_buffer = None + ag_buffers = None if nvls_eligible: - ag_buffer = self._maybe_allocate_ag_buffer(hidden_states) + ag_buffers = self._maybe_allocate_ag_buffers(self.routing_map, probs, hidden_states) - can_use_nvls = nvls_eligible and ag_buffer["handle"] is not None + can_use_nvls = nvls_eligible and ag_buffers["handle"] is not None if can_use_nvls: - # Use latency-optimized NVLS all-gather for hidden_states - multimem_all_gather(ag_buffer["tensor"], hidden_states, ag_buffer["handle"]) - hidden_states = ag_buffer["tensor"] + # Capture shapes for reshaping after all-gather + # Output shape: [local_tokens * ep_size, dim] + local_tokens = probs.size(0) + global_tokens = local_tokens * self.ep_size + num_experts = probs.size(1) + hidden_dim = hidden_states.size(1) + routing_map_dtype = self.routing_map.dtype + probs_dtype = probs.dtype + hidden_dtype = hidden_states.dtype + + # Use latency-optimized NVLS all-gather for routing_map, probs and hidden_states + # Pass byte_offset so kernel writes to correct location in multicast buffer + multimem_all_gather( + ag_buffers["routing_map"].view(torch.bfloat16), + self.routing_map.view(torch.bfloat16), + ag_buffers["handle"], + byte_offset=ag_buffers["routing_map_offset"], + ) + self.routing_map = ag_buffers["routing_map"].view(routing_map_dtype).view(global_tokens, num_experts) + + multimem_all_gather( + ag_buffers["probs"].view(torch.bfloat16), + probs.view(torch.bfloat16), + ag_buffers["handle"], + 
byte_offset=ag_buffers["probs_offset"], + ) + probs = ag_buffers["probs"].view(probs_dtype).view(global_tokens, num_experts) + + multimem_all_gather( + ag_buffers["hidden_states"].view(torch.bfloat16), + hidden_states.view(torch.bfloat16), + ag_buffers["handle"], + byte_offset=ag_buffers["hidden_states_offset"], + ) + hidden_states = ag_buffers["hidden_states"].view(hidden_dtype).view(global_tokens, hidden_dim) else: - # Fallback to NCCL for hidden_states + # Fallback to NCCL for all tensors + with torch.no_grad(): + self.routing_map = gather_from_sequence_parallel_region( + self.routing_map, group=self.tp_ep_group + ) + probs = gather_from_sequence_parallel_region(probs, group=self.tp_ep_group) hidden_states = gather_from_sequence_parallel_region( hidden_states, group=self.tp_ep_group ) From 0fac929e23faaa8412b15da4e32360d333a09b3e Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 2 Feb 2026 05:20:57 -0800 Subject: [PATCH 18/92] full model cg optimizations and bump up max blocks for blackwell --- .../communication/torch_symm_triton/collectives.py | 4 ++-- megatron/core/transformer/cuda_graphs.py | 14 ++++++++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/megatron/core/inference/communication/torch_symm_triton/collectives.py b/megatron/core/inference/communication/torch_symm_triton/collectives.py index 24a463350fa..9d48fc8b341 100644 --- a/megatron/core/inference/communication/torch_symm_triton/collectives.py +++ b/megatron/core/inference/communication/torch_symm_triton/collectives.py @@ -106,7 +106,7 @@ def multimem_all_gather( assert HAVE_TRITON, "Triton is required for multimem all-gather." config = { - "max_num_blocks": kwargs.get("max_num_blocks", 24), + "max_num_blocks": kwargs.get("max_num_blocks", 128), "num_warps": kwargs.get("num_warps", 32), "BLOCK_SIZE": kwargs.get("BLOCK_SIZE", 1024), } @@ -209,7 +209,7 @@ def multimem_reduce_scatter( assert HAVE_TRITON, "Triton is required for multimem reduce-scatter." 
config = { - "max_num_blocks": kwargs.get("max_num_blocks", 24), + "max_num_blocks": kwargs.get("max_num_blocks", 128), "num_warps": kwargs.get("num_warps", 32), "BLOCK_SIZE": kwargs.get("BLOCK_SIZE", 1024), } diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 1e3e3edc558..3ae794967db 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -951,6 +951,12 @@ def record_graph_capture(self, args, kwargs): # issues, for instance with pipeline parallelism return tuple(o.clone() if torch.is_tensor(o) else o for o in out) + def _get_cached_parameters_set_for_inference(self): + """Return cached parameters for inference mode.""" + if not hasattr(self, '_cached_parameters_set_for_inference'): + self._cached_parameters_set_for_inference = tuple(self.parameters()) + return self._cached_parameters_set_for_inference + def replay_graph_capture(self, is_first_microbatch, args, kwargs): """Replay the fwd cuda graph with autograd.""" @@ -963,7 +969,11 @@ def replay_graph_capture(self, is_first_microbatch, args, kwargs): raise AssertionError(error_msg) inp_tensors = self.get_tensors(args, kwargs) - func_args = inp_tensors + tuple(self.parameters()) + is_inference_mode = 'inference_context' in kwargs.keys() and kwargs['inference_context'] + if not is_inference_mode: + func_args = inp_tensors + tuple(self.parameters()) + else: + func_args = inp_tensors + self._get_cached_parameters_set_for_inference() out = _CudagraphReplayNode.apply(self, is_first_microbatch, *func_args) out = list(out) @@ -1338,7 +1348,6 @@ def __call__(self, megatron_module, args, kwargs): if 'inference_context' in kwargs.keys() and kwargs['inference_context']: # Inference generation mode creates graphs immediately runner = self.get_cudagraph_runner(megatron_module, args, kwargs) - runner.eval() if not runner.fwd_graph_recorded: # Reuse graph input-output buffers for inference @@ -1375,6 +1384,7 @@ def __call__(self, 
megatron_module, args, kwargs): _CudagraphGlobalRecord.cudagraph_inference_record.append( (runner, "fwd", args, kwargs) ) + runner = runner.eval() # Now replay the graph out = runner.replay_graph_capture(self.is_first_microbatch, args, kwargs) From 371043c39c024930a4a1acdc3a37cd9d6dbddf16 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 4 Feb 2026 06:51:39 -0800 Subject: [PATCH 19/92] fix full model CG for mamba --- .../text_generation_controller.py | 3 ++- megatron/core/ssm/mamba_block.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 92d1720fc55..62e3a57e0e4 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -525,7 +525,7 @@ def _dynamic_step_context_init( moe_pad_experts_for_cuda_graph_inference = ( self.model_config.moe_pad_experts_for_cuda_graph_inference ) - is_inference_optimized = inference_wrapper_config.transformer_impl == "inference_optimized" + is_inference_optimized = self.model_config.transformer_impl == "inference_optimized" if is_inference_optimized: assert not moe_pad_experts_for_cuda_graph_inference, ( "moe_pad_experts_for_cuda_graph_inference cannot be True when " @@ -537,6 +537,7 @@ def _dynamic_step_context_init( set_decode_expert_padding(unwrapped_model, True, capacity_factor=capacity_factor) else: set_decode_expert_padding(unwrapped_model, False) + if is_inference_optimized and model_config.expert_model_parallel_size > 1: set_is_cuda_graphed_iteration_for_ep_inference(unwrapped_model, context.using_cuda_graph_this_step()) diff --git a/megatron/core/ssm/mamba_block.py b/megatron/core/ssm/mamba_block.py index 3d684b82dce..6ff7001d63c 100644 --- a/megatron/core/ssm/mamba_block.py +++ 
b/megatron/core/ssm/mamba_block.py @@ -25,7 +25,7 @@ from megatron.core.transformer import TransformerConfig from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.identity_op import IdentityOp -from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.module import MegatronModule, GraphableMegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_layer import TransformerLayer from megatron.core.transformer.utils import sharded_state_dict_default @@ -44,7 +44,7 @@ class MambaStackSubmodules: moe_layer: Union[ModuleSpec, type] = IdentityOp -class MambaStack(MegatronModule): +class MambaStack(GraphableMegatronModule, MegatronModule): """ Constructor for the MambaStack class. @@ -231,6 +231,7 @@ def __call__(self, *args, **kwargs): if isinstance(kwargs['hidden_states'], WrappedTensor) else kwargs['hidden_states'] ) + return super().__call__(*args, **kwargs)[0] return super().__call__(*args, **kwargs) def forward( From 01cd40fc2b3254813b0c77674f6c804cab71990f Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 4 Feb 2026 07:13:56 -0800 Subject: [PATCH 20/92] remove requirement for moe permute fusion --- .../core/transformer/transformer_config.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 06e99016e81..e871b504898 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1075,17 +1075,17 @@ def __post_init__(self): "(moe_expert_capacity_factor=None and moe_router_padding_for_quantization=False). 
" ) - if self.transformer_impl == "inference_optimized" and self.num_moe_experts is not None: - if not self.moe_permute_fusion: - raise ValueError( - "Inference-optimized MoE layers require moe_permute_fusion=True " - "to use TE fused kernels that support GPU-resident metadata." - ) - # if not self.moe_router_fusion: - # raise ValueError( - # "Inference-optimized MoE layers require moe_router_fusion=True " - # "to use TE fused router kernels." - # ) + # if self.transformer_impl == "inference_optimized" and self.num_moe_experts is not None: + # if not self.moe_permute_fusion: + # raise ValueError( + # "Inference-optimized MoE layers require moe_permute_fusion=True " + # "to use TE fused kernels that support GPU-resident metadata." + # ) + # # if not self.moe_router_fusion: + # # raise ValueError( + # # "Inference-optimized MoE layers require moe_router_fusion=True " + # # "to use TE fused router kernels." + # # ) if self.num_moe_experts is not None and self.num_moe_experts <= 0: raise ValueError("num_moe_experts must be non-negative.") From 30d8cf35a2f264a2872526b1645973826f090e56 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 4 Feb 2026 12:13:28 -0800 Subject: [PATCH 21/92] failed attempt at optimizing router and permute --- megatron/core/models/gpt/moe_module_specs.py | 23 +- .../core/transformer/moe/inference_kernels.py | 484 ++++++++++++------ .../transformer/moe/moe_layer_inference.py | 1 + megatron/core/transformer/moe/router.py | 82 ++- .../moe/token_dispatcher_inference.py | 185 +++---- 5 files changed, 497 insertions(+), 278 deletions(-) diff --git a/megatron/core/models/gpt/moe_module_specs.py b/megatron/core/models/gpt/moe_module_specs.py index 7f9fc211552..3b02e19bb2d 100755 --- a/megatron/core/models/gpt/moe_module_specs.py +++ b/megatron/core/models/gpt/moe_module_specs.py @@ -9,6 +9,7 @@ from megatron.core.transformer.moe.moe_layer_inference import InferenceMoELayer from megatron.core.transformer.moe.shared_experts import SharedExpertMLP 
from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.moe.router import InferenceTopKRouter def get_moe_module_spec( @@ -86,14 +87,18 @@ def get_moe_module_spec_for_backend( # Select MoE layer class based on inference_optimized flag if inference_optimized: - moe_layer_class = InferenceMoELayer + moe_module_spec = ModuleSpec( + module=InferenceMoELayer, + submodules=MoESubmodules(router=InferenceTopKRouter, + experts=experts, + shared_experts=shared_experts), + metainfo={"fuse_pre_mlp_layernorm": False}, + ) else: - moe_layer_class = MoELayer - - # MoE module spec - moe_module_spec = ModuleSpec( - module=moe_layer_class, - submodules=MoESubmodules(experts=experts, shared_experts=shared_experts), - metainfo={"fuse_pre_mlp_layernorm": False}, - ) + # MoE module spec + moe_module_spec = ModuleSpec( + module=MoELayer, + submodules=MoESubmodules(experts=experts, shared_experts=shared_experts), + metainfo={"fuse_pre_mlp_layernorm": False}, + ) return moe_module_spec diff --git a/megatron/core/transformer/moe/inference_kernels.py b/megatron/core/transformer/moe/inference_kernels.py index 352f1442b60..430482b55a2 100644 --- a/megatron/core/transformer/moe/inference_kernels.py +++ b/megatron/core/transformer/moe/inference_kernels.py @@ -1,200 +1,362 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +""" +Triton kernels for MoE inference optimizations. 
+""" + import torch import triton import triton.language as tl -import pytest -import torch @triton.jit -def moe_unpermute_kernel( - permuted_ptr, mask_ptr_T, dest_idx_ptr, output_ptr, - stride_out_t, num_tokens, hidden_size, +def shift_and_mark_indices_kernel( + topk_indices_ptr, # Input: [num_tokens, topk] + shifted_indices_ptr, # Output: [num_tokens, topk] + num_tokens: tl.constexpr, + topk: tl.constexpr, + local_start: tl.constexpr, # First local expert index + local_end: tl.constexpr, # Last local expert index + sentinel: tl.constexpr, # num_local_experts (sentinel for invalid) BLOCK_SIZE: tl.constexpr, ): + """ + Shifts topk indices to local coordinate system and marks invalid indices. + + For each index: + - If index in [local_start, local_end]: shift to 0-based (index - local_start) + - Otherwise: mark as sentinel value + """ + # Each program handles one block of elements pid = tl.program_id(0) - if not tl.load(mask_ptr_T + pid): return - src_row_idx = tl.load(dest_idx_ptr + pid) - 1 - token_idx = pid % num_tokens - - offsets = tl.arange(0, BLOCK_SIZE) - mask_h = offsets < hidden_size - - # Load as current dtype - row_data = tl.load(permuted_ptr + (src_row_idx * hidden_size) + offsets, mask=mask_h) + # Calculate total elements + num_elements = num_tokens * topk + + # Process BLOCK_SIZE elements per program + offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offset < num_elements + + # Load indices + indices = tl.load(topk_indices_ptr + offset, mask=mask, other=0) + + # Check if index is in local range + is_valid = (indices >= local_start) & (indices <= local_end) - # Cast to float32 for the accumulation to avoid BF16 rounding errors - row_data_fp32 = row_data.to(tl.float32) + # Shift valid indices, mark invalid with sentinel + shifted = tl.where(is_valid, indices - local_start, sentinel) - # Atomic add in FP32 (Triton handles the casting/locking) - tl.atomic_add(output_ptr + (token_idx * stride_out_t) + offsets, row_data_fp32, mask=mask_h) + # 
Store result + tl.store(shifted_indices_ptr + offset, shifted, mask=mask) -def launch_unpermute_kernel(hidden_states, static_buffer, mask_T, dest_indices): - """ - Launch the unpermute kernel. +def shift_topk_indices( + topk_indices: torch.Tensor, + local_start: int, + local_end: int, + num_local_experts: int, +) -> torch.Tensor: + """ + Shift topk indices to local coordinate system using Triton kernel. + Args: - hidden_states: [T, H] output tensor to accumulate into (should be zeroed) - static_buffer: [max_out, H] permuted expert outputs - mask_T: [E, T] pre-transposed mask (reused from dispatch) - dest_indices: [E*T] cumsum indices (reused from dispatch) + topk_indices: [num_tokens, topk] tensor of expert indices + local_start: First local expert global index + local_end: Last local expert global index + num_local_experts: Number of local experts + + Returns: + shifted_indices: [num_tokens, topk] with local indices or sentinel """ - T, H = hidden_states.shape - E = mask_T.size(0) # mask_T is [E, T] - - grid = (E * T,) - BLOCK_SIZE = triton.next_power_of_2(H) - - # For unpermute, hidden_states is the 'output' we write back into - # ensure that hidden states is zeroed out before accumulation - moe_unpermute_kernel[grid]( - static_buffer, mask_T, dest_indices, hidden_states, - hidden_states.stride(0), T, H, BLOCK_SIZE=BLOCK_SIZE + num_tokens, topk = topk_indices.shape + shifted_indices = torch.empty_like(topk_indices) + + num_elements = num_tokens * topk + BLOCK_SIZE = 1024 + grid = lambda meta: (triton.cdiv(num_elements, meta['BLOCK_SIZE']),) + + shift_and_mark_indices_kernel[grid]( + topk_indices, + shifted_indices, + num_tokens=num_tokens, + topk=topk, + local_start=local_start, + local_end=local_end, + sentinel=num_local_experts, + BLOCK_SIZE=BLOCK_SIZE, ) - + + return shifted_indices @triton.jit -def moe_fused_permute_extract_kernel( - hidden_ptr, probs_ptr, mask_ptr_T, dest_idx_ptr, - out_hidden_ptr, out_probs_ptr, - stride_h_t, stride_probs_t, stride_probs_e, 
- num_tokens, hidden_size, +def permute_and_count_kernel( + # Input tensors + hidden_states_ptr, # [num_tokens, hidden_dim] + probs_ptr, # [num_tokens, topk] + expert_assignments_ptr, # [num_tokens * topk] - local expert index per token-k pair + permutation_ptr, # [num_tokens * topk] - argsort result + # Output tensors + permuted_hidden_ptr, # [max_out, hidden_dim] + permuted_probs_ptr, # [max_out] + tokens_per_expert_ptr, # [num_local_experts] + # Scalars + num_tokens: tl.constexpr, + topk: tl.constexpr, + hidden_dim: tl.constexpr, + num_local_experts: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): """ - Fused kernel: permute hidden states AND extract probs in one pass. - - This kernel avoids transposing probs by using stride-based indexing. - The mask is expected to be pre-transposed [E, T] for efficient expert-major access. + Permute hidden states and probs according to permutation, count tokens per expert. + + Each program handles one output position. Skips sentinel values (expert == num_local_experts). """ + # Each program handles one output position pid = tl.program_id(0) - - # Early exit if this expert-token pair is inactive - if not tl.load(mask_ptr_T + pid): + + # Total elements to process + num_elements = num_tokens * topk + + if pid >= num_elements: return - - out_row_idx = tl.load(dest_idx_ptr + pid) - 1 - token_idx = pid % num_tokens - expert_idx = pid // num_tokens - - # 1. Permute hidden states (vectorized load/store) - offsets = tl.arange(0, BLOCK_SIZE) - mask_h = offsets < hidden_size - row_data = tl.load(hidden_ptr + (token_idx * stride_h_t) + offsets, mask=mask_h) - tl.store(out_hidden_ptr + (out_row_idx * hidden_size) + offsets, row_data, mask=mask_h) - - # 2. 
Extract probability using stride-based indexing (avoids probs transpose) - # probs is [T, E], so index as probs[token_idx, expert_idx] - prob = tl.load(probs_ptr + token_idx * stride_probs_t + expert_idx * stride_probs_e) - tl.store(out_probs_ptr + out_row_idx, prob) - - -def launch_fused_permute_and_probs(hidden_states, probs, mask_T, - hidden_workspace, probs_workspace): + + # Load the permutation index - where to read from + perm_idx = tl.load(permutation_ptr + pid) + + # Load the expert index for this position + expert_idx = tl.load(expert_assignments_ptr + perm_idx) + + # Skip if this is a sentinel value + if expert_idx >= num_local_experts: + return + + # Compute source token and k indices + # perm_idx tells us position in flattened [num_tokens * topk] array + token_idx = perm_idx // topk + k_idx = perm_idx % topk + + # Copy hidden state: load from [token_idx, :] and store to [pid, :] + for d in range(0, hidden_dim, BLOCK_SIZE): + offset = d + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_dim + + hidden_val = tl.load( + hidden_states_ptr + token_idx * hidden_dim + offset, + mask=mask, + other=0.0 + ) + tl.store( + permuted_hidden_ptr + pid * hidden_dim + offset, + hidden_val, + mask=mask + ) + + # Copy prob: load from [token_idx, k_idx] + prob_val = tl.load(probs_ptr + token_idx * topk + k_idx) + tl.store(permuted_probs_ptr + pid, prob_val) + + # Atomically increment tokens_per_expert[expert_idx] + tl.atomic_add(tokens_per_expert_ptr + expert_idx, 1) + + +def permute_tokens_and_probs( + hidden_states: torch.Tensor, + probs: torch.Tensor, + expert_assignments: torch.Tensor, + permutation: torch.Tensor, + num_local_experts: int, + max_tokens: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - Fused launcher that: - 1. Accepts pre-transposed mask_T [E, T] (caller fuses slice+transpose) - 2. Uses stride-based probs access (no transpose needed) - 3. 
Launches a single fused kernel for both permute + prob extraction - + Permute hidden states and probs, count tokens per expert using Triton kernel. + Args: - hidden_states: [T, H] input hidden states - probs: [T, E] routing probabilities (NOT transposed) - mask_T: [E, T] pre-transposed routing mask (caller provides this) - hidden_workspace: [max_out, H] output buffer for permuted hidden states - probs_workspace: [max_out] output buffer for extracted probs - + hidden_states: [num_tokens, hidden_dim] + probs: [num_tokens, topk] + expert_assignments: [num_tokens * topk] local expert index per token-k pair + permutation: [num_tokens * topk] argsort result + num_local_experts: Number of local experts + max_tokens: Maximum output size + Returns: - dest_indices: Cumsum indices (cached for potential reuse in unpermute) + permuted_hidden: [max_tokens, hidden_dim] + permuted_probs: [max_tokens] + tokens_per_expert: [num_local_experts] """ - T, H = hidden_states.shape - E = mask_T.size(0) # mask_T is [E, T] - - # Only compute dest_indices (mask_T is provided by caller) - dest_indices = torch.cumsum(mask_T.view(-1).long(), dim=0).to(torch.int32) - - grid = (E * T,) - BLOCK_SIZE = triton.next_power_of_2(H) - - moe_fused_permute_extract_kernel[grid]( - hidden_states, probs, mask_T, dest_indices, - hidden_workspace, probs_workspace, - hidden_states.stride(0), probs.stride(0), probs.stride(1), - T, H, BLOCK_SIZE=BLOCK_SIZE + num_tokens, hidden_dim = hidden_states.shape + topk = probs.size(1) + + # Allocate outputs + permuted_hidden = torch.empty( + (max_tokens, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device ) - - # Return cached dest_indices for potential reuse in unpermute - return dest_indices - + permuted_probs = torch.empty( + max_tokens, + dtype=probs.dtype, + device=probs.device + ) + tokens_per_expert = torch.zeros( + num_local_experts, + dtype=torch.int32, + device=hidden_states.device + ) + + # Launch kernel - one program per output position + 
num_elements = num_tokens * topk + + # Adapt BLOCK_SIZE to hidden_dim for optimal memory access + # Use next power of 2 for better vectorization + BLOCK_SIZE = triton.next_power_of_2(hidden_dim) + # Cap at reasonable maximum to avoid register pressure + BLOCK_SIZE = min(BLOCK_SIZE, 2048) + + grid = (num_elements,) + + permute_and_count_kernel[grid]( + hidden_states, + probs, + expert_assignments, + permutation, + permuted_hidden, + permuted_probs, + tokens_per_expert, + num_tokens=num_tokens, + topk=topk, + hidden_dim=hidden_dim, + num_local_experts=num_local_experts, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return permuted_hidden, permuted_probs, tokens_per_expert -@pytest.mark.parametrize("T, E, H", [ - (1, 1, 128), # Minimal case - (64, 8, 512), # Standard small - (128, 16, 1024), # Medium - (256, 32, 2048), # Large (LLM Scale) - (1024, 1, 128), # Single Expert - (32, 64, 64), # High expert count -]) -@pytest.mark.parametrize("sparsity", [0.1, 0.5, 0.9]) -@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) -def test_fused_permute_and_probs(T, E, H, sparsity, dtype): +@triton.jit +def unpermute_and_combine_kernel( + # Input tensors + permuted_hidden_ptr, # [max_out, hidden_dim] - expert outputs + permutation_ptr, # [num_tokens * topk] - argsort result (forward permutation) + expert_assignments_ptr, # [num_tokens * topk] - local expert index per token-k pair + # Output tensor + output_ptr, # [num_tokens, hidden_dim] - unpermuted output + # Scalars + num_tokens: tl.constexpr, + topk: tl.constexpr, + hidden_dim: tl.constexpr, + num_local_experts: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + Unpermute expert outputs back to original token positions. 
+ + Each program handles one position in the permutation array: + - Loads permutation[pid] to find source flat_pos (token_idx, k_idx) + - Loads expert output from permuted arrays at position pid + - Atomically accumulates output into output[token_idx] + + Note: Probability weighting is handled by the experts (via moe_apply_probs_on_input), + so this kernel only does unpermutation and accumulation. + """ + # Each program handles one permuted position + pid = tl.program_id(0) + + num_elements = num_tokens * topk + if pid >= num_elements: + return + + # Load source position from permutation + flat_pos = tl.load(permutation_ptr + pid) + + # Compute source token index + token_idx = flat_pos // topk + + # Load expert index to check validity + expert_idx = tl.load(expert_assignments_ptr + flat_pos) + + # Skip if sentinel (not a valid local expert) + if expert_idx >= num_local_experts: + return + + # Process each dimension chunk + for d in range(0, hidden_dim, BLOCK_SIZE): + offset = d + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_dim + + # Load expert output (already weighted by experts if configured) + hidden_val = tl.load( + permuted_hidden_ptr + pid * hidden_dim + offset, + mask=mask, + other=0.0 + ) + + # Atomically accumulate into output[token_idx] + tl.atomic_add(output_ptr + token_idx * hidden_dim + offset, hidden_val, mask=mask) + + +def unpermute_and_combine( + permuted_hidden: torch.Tensor, + expert_assignments: torch.Tensor, + permutation: torch.Tensor, + num_tokens: int, + topk: int, + num_local_experts: int, +) -> torch.Tensor: """ - Test that the fused kernel produces identical results to reference implementation. + Unpermute expert outputs back to original token order. 
+ + Args: + permuted_hidden: [max_out, hidden_dim] expert outputs (already weighted by experts) + expert_assignments: [num_tokens * topk] local expert index per token-k pair + permutation: [num_tokens * topk] argsort result from dispatch + num_tokens: Number of original tokens + topk: Number of experts per token + num_local_experts: Number of local experts + + Returns: + output: [num_tokens, hidden_dim] unpermuted output + + Note: The expert outputs should already be weighted by routing probabilities + if moe_apply_probs_on_input is enabled in the config. """ - device = "cuda" - MAX_OUT = T * E - - # Setup inputs - hidden_states = torch.randn(T, H, device=device, dtype=dtype) * 1e-3 - probs = torch.rand(T, E, device=device, dtype=dtype) - mask = torch.rand(T, E, device=device) > sparsity - - # Ensure at least one active token-expert pair - if not mask.any(): - mask[0, 0] = True - - # Pre-transpose mask (simulating the fused slice+transpose in dispatcher) - mask_T = mask.t().contiguous() # [E, T] - - # --- Reference: Python-based verification --- - num_active = int(mask.sum().item()) - ref_hidden_buffer = torch.zeros((MAX_OUT, H), device=device, dtype=dtype) - ref_probs_buffer = torch.zeros(MAX_OUT, device=device, dtype=dtype) - - # Expert-major ordering reference - buffer_idx = 0 - for e_idx in range(E): - for t_idx in range(T): - if mask[t_idx, e_idx]: - ref_hidden_buffer[buffer_idx] = hidden_states[t_idx] - ref_probs_buffer[buffer_idx] = probs[t_idx, e_idx] - buffer_idx += 1 - - # --- Test: Fused kernel launch --- - fused_hidden_buffer = torch.zeros((MAX_OUT, H), device=device, dtype=dtype) - fused_probs_buffer = torch.zeros(MAX_OUT, device=device, dtype=dtype) - - dest_indices = launch_fused_permute_and_probs( - hidden_states, probs, mask_T, - fused_hidden_buffer, fused_probs_buffer + hidden_dim = permuted_hidden.size(1) + + # Allocate output (zeroed for atomic accumulation) + output = torch.zeros( + (num_tokens, hidden_dim), + dtype=permuted_hidden.dtype, + 
device=permuted_hidden.device ) - - # --- Verify outputs match --- - # Compare hidden states (only active portion) - torch.testing.assert_close( - fused_hidden_buffer[:num_active], - ref_hidden_buffer[:num_active], - rtol=1e-5, atol=1e-5 + + # Adapt BLOCK_SIZE to hidden_dim + BLOCK_SIZE = triton.next_power_of_2(hidden_dim) + BLOCK_SIZE = min(BLOCK_SIZE, 2048) + + # Launch kernel - one program per permuted position (same pattern as permute kernel) + num_elements = num_tokens * topk + grid = (num_elements,) + + unpermute_and_combine_kernel[grid]( + permuted_hidden, + permutation, + expert_assignments, + output, + num_tokens=num_tokens, + topk=topk, + hidden_dim=hidden_dim, + num_local_experts=num_local_experts, + BLOCK_SIZE=BLOCK_SIZE, ) + + return output + + +def launch_fused_permute_and_probs(*args, **kwargs): + """Placeholder for future fused permute kernel.""" + raise NotImplementedError("launch_fused_permute_and_probs not yet implemented") - # Compare probs (only active portion) - torch.testing.assert_close( - fused_probs_buffer[:num_active], - ref_probs_buffer[:num_active], - rtol=1e-5, atol=1e-5 - ) - # Verify dest_indices shape - assert dest_indices.shape == (E * T,), f"dest_indices shape mismatch: {dest_indices.shape}" \ No newline at end of file +def launch_unpermute_kernel(*args, **kwargs): + """Placeholder for future unpermute kernel.""" + raise NotImplementedError("launch_unpermute_kernel not yet implemented") diff --git a/megatron/core/transformer/moe/moe_layer_inference.py b/megatron/core/transformer/moe/moe_layer_inference.py index 3a140325dd1..2991369b2b9 100644 --- a/megatron/core/transformer/moe/moe_layer_inference.py +++ b/megatron/core/transformer/moe/moe_layer_inference.py @@ -107,6 +107,7 @@ def __init__( ) def set_is_cuda_graphed_iteration(self, set_to): self.is_cuda_graphed_iteration = set_to + self.router.set_is_cuda_graphed_iteration(set_to) def activate_inference_token_dispatcher(self): # replace the token dispatcher with the 
inference-optimized version diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index 4be97401748..0730d13a49f 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -23,7 +23,7 @@ ) from megatron.core.transformer.moe.router_replay import RouterReplay from megatron.core.transformer.transformer_config import TransformerConfig - +import logging class Router(ABC, MegatronModule): """Base Router class""" @@ -669,3 +669,83 @@ def _save_to_state_dict(self, *args, **kwargs): """Save the state dict of the router.""" self._maintain_float32_expert_bias() # switch to float32 before saving return super()._save_to_state_dict(*args, **kwargs) + + +class InferenceTopKRouter(TopKRouter): + """Specialized top-k router optimized for inference with specific constraints. + + This router enforces: + - moe_router_num_groups: None (no group-limited routing) + - moe_router_score_function: sigmoid + - moe_router_enable_expert_bias: True + """ + + def __init__( + self, config: TransformerConfig, pg_collection: Optional[ProcessGroupCollection] = None + ) -> None: + """Initialize the specialized inference top-k router. + + Args: + config (TransformerConfig): The configuration for the transformer model. + pg_collection (ProcessGroupCollection, optional): Process groups for MoE operations. 
+ """ + # Enforce constraints before calling super().__init__ + assert ( + config.moe_router_num_groups is None + ), f"InferenceTopKRouter requires moe_router_num_groups=None, got {config.moe_router_num_groups}" + assert ( + config.moe_router_score_function == "sigmoid" + ), f"InferenceTopKRouter requires moe_router_score_function='sigmoid', got '{config.moe_router_score_function}'" + assert ( + config.moe_router_enable_expert_bias is True + ), f"InferenceTopKRouter requires moe_router_enable_expert_bias=True, got {config.moe_router_enable_expert_bias}" + + super().__init__(config=config, pg_collection=pg_collection) + + self.is_cuda_graphed_iteration = False + + def set_is_cuda_graphed_iteration(self, set_to: bool): + self.is_cuda_graphed_iteration = set_to + + def _forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): + logits = self.gating(input) # [num_tokens, num_experts] + + # Apply sigmoid to get independent scores per expert + scores = torch.sigmoid(logits.float()).type_as(logits) # [num_tokens, num_experts] + + # Add expert bias for topk selection (helps with load balancing) + scores_for_routing = scores + self.expert_bias # [num_experts] broadcasted + + # Select top-k experts based on biased scores + _, topk_indices = torch.topk(scores_for_routing, k=self.topk, dim=-1) # [num_tokens, topk] + + # Gather the original sigmoid scores (without bias) for selected experts + topk_probs = torch.gather(scores, dim=-1, index=topk_indices) # [num_tokens, topk] + + # Normalize to get routing probabilities (sum to 1 per token) + if self.topk > 1: + topk_probs = topk_probs / (topk_probs.sum(dim=-1, keepdim=True) + 1e-20) + + # NOTE: Return format differs from parent class for efficiency: + # - Parent: Returns sparse tensors [num_tokens, num_experts] (routing_probs, routing_map) + # - This: Returns dense tensors [num_tokens, topk] (topk_probs, topk_indices) + return topk_probs.squeeze(1), topk_indices.squeeze(1) + + def forward(self, input: 
torch.Tensor, padding_mask: Optional[torch.Tensor] = None): + """Simplified forward pass for inference - returns dense tensors only. + + Args: + input (torch.Tensor): Input tensor of shape [seq_length, bsz, hidden_size]. + padding_mask (torch.Tensor, optional): Not used in inference. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - probs: Normalized routing probabilities [num_tokens, topk] + - top_indices: Selected expert indices [num_tokens, topk] + """ + # Compute logits via gating network + + if not self.is_cuda_graphed_iteration: + return super().forward(input, padding_mask) + + return self._forward(input, padding_mask) \ No newline at end of file diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 01bb9bdc0ec..78e751bd472 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -20,8 +20,9 @@ from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.moe.inference_kernels import ( - launch_fused_permute_and_probs, - launch_unpermute_kernel, + shift_topk_indices, + permute_tokens_and_probs, + unpermute_and_combine, ) from megatron.core.tensor_parallel import gather_from_sequence_parallel_region from megatron.core.parallel_state import get_global_symmetric_memory_buffer_ep @@ -117,8 +118,8 @@ def _maybe_allocate_ag_buffers( # Calculate output shapes after all-gather local_tokens = probs.size(0) global_tokens = local_tokens * self.ep_size - num_experts = probs.size(1) - hidden_dim = hidden_states.size(1) + topk = probs.size(-1) + hidden_dim = hidden_states.size(-1) # Calculate bytes needed for each tensor (with 16-byte alignment) def aligned_bytes(numel, dtype): @@ -127,8 +128,8 @@ def aligned_bytes(numel, dtype): # Align to 16 bytes for 128-bit access return ((raw_bytes + 15) // 16) * 16 - routing_map_bytes = aligned_bytes(global_tokens * 
num_experts, routing_map.dtype) - probs_bytes = aligned_bytes(global_tokens * num_experts, probs.dtype) + routing_map_bytes = aligned_bytes(global_tokens * topk, routing_map.dtype) + probs_bytes = aligned_bytes(global_tokens * topk, probs.dtype) hidden_states_bytes = aligned_bytes(global_tokens * hidden_dim, hidden_states.dtype) total_bytes = routing_map_bytes + probs_bytes + hidden_states_bytes @@ -200,7 +201,7 @@ def token_dispatch(self, hidden_states, probs): # Output shape: [local_tokens * ep_size, dim] local_tokens = probs.size(0) global_tokens = local_tokens * self.ep_size - num_experts = probs.size(1) + topk = probs.size(1) hidden_dim = hidden_states.size(1) routing_map_dtype = self.routing_map.dtype probs_dtype = probs.dtype @@ -214,7 +215,7 @@ def token_dispatch(self, hidden_states, probs): ag_buffers["handle"], byte_offset=ag_buffers["routing_map_offset"], ) - self.routing_map = ag_buffers["routing_map"].view(routing_map_dtype).view(global_tokens, num_experts) + self.routing_map = ag_buffers["routing_map"].view(routing_map_dtype).view(global_tokens, topk) multimem_all_gather( ag_buffers["probs"].view(torch.bfloat16), @@ -222,7 +223,7 @@ def token_dispatch(self, hidden_states, probs): ag_buffers["handle"], byte_offset=ag_buffers["probs_offset"], ) - probs = ag_buffers["probs"].view(probs_dtype).view(global_tokens, num_experts) + probs = ag_buffers["probs"].view(probs_dtype).view(global_tokens, topk) multimem_all_gather( ag_buffers["hidden_states"].view(torch.bfloat16), @@ -244,120 +245,90 @@ def token_dispatch(self, hidden_states, probs): return hidden_states, probs - def test_permute_output(self, hidden_states, permute_output, mask): - # Verification of Grouped-by-Expert layout - E = self.local_map.size(1) - T = hidden_states.size(0) - mask = self.local_map - buffer_idx = 0 - for e_idx in range(E): - for t_idx in range(T): - if mask[t_idx, e_idx]: - assert torch.allclose(permute_output[buffer_idx], hidden_states[t_idx]) - buffer_idx += 1 - - #assert 
static_buffer[buffer_idx:].sum() == 0, "Stale data found in buffer tail" - - def test_permute_probs_output(self, local_probs, probs_workspace, mask): - """ - Verification of Grouped-by-Expert layout for probabilities. - local_probs: [Tokens, Experts] - probs_workspace: [MAX_OUT, 1] (or [MAX_OUT]) - mask: [Tokens, Experts] boolean mask - """ - T = local_probs.size(0) - E = local_probs.size(1) - - buffer_idx = 0 - # Expert-major traversal (Outer loop: Experts, Inner loop: Tokens) - for e_idx in range(E): - for t_idx in range(T): - if mask[t_idx, e_idx]: - # Extract the expected probability from the source [Tokens, Experts] - expected_prob = local_probs[t_idx, e_idx] - # Using a slightly relaxed atol for BF16 if necessary - actual_prob = probs_workspace[buffer_idx] - assert torch.allclose( - actual_prob, - expected_prob - ), f"Prob mismatch at buffer index {buffer_idx} (Expert {e_idx}, Token {t_idx})" - - buffer_idx += 1 def dispatch_postprocess(self, hidden_states, probs): """After gathering in token_dispatch, this method identifies tokens for local experts and permutes them for expert processing. - Optimized to: - 1. Fuse slice + transpose for mask (single kernel instead of two) - 2. Use stride-based probs access in kernel (avoids probs transpose entirely) - 3. Permute hidden states AND extract probs in a single kernel launch + Algorithm: + 1. Filter topk_indices to keep only those for local experts + 2. Shift valid indices to local coordinate system (0-indexed) + 3. Mark invalid indices with sentinel value + 4. Argsort to get token permutation that groups by expert + 5. Permute tokens and probs using this map + 6. 
Bincount to get tokens per expert """ self.hidden_shape_before_permute = hidden_states.shape - - # Fuse slice + transpose for mask: [T, num_experts] -> [num_local_experts, T] - # This produces mask_T directly, avoiding a separate transpose kernel - self._cached_mask_T = self.routing_map[ - :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1 - ].t().contiguous() # [E, T] layout - - # Probs: just slice, no transpose needed (kernel uses stride-based access) - local_probs = probs[ - :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1 - ].contiguous() # [T, E] layout - - # tokens_per_expert from transposed mask: sum over tokens (dim=1) for each expert - tokens_per_expert = self._cached_mask_T.sum(dim=1) - - # Pre-allocate workspaces - max_out = hidden_states.size(0) * min(self.topk, self.num_local_experts) - tokens_workspace = torch.empty( - max_out, hidden_states.size(1), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - probs_workspace = torch.empty( - max_out, - dtype=probs.dtype, - device=probs.device, + + # self.routing_map is actually topk_indices: [num_tokens, topk] + topk_indices = self.routing_map + num_tokens, topk = topk_indices.shape + + # Shift global expert indices to local coordinate system using Triton kernel. + # For each index: + # - If in range [local_expert_indices[0], local_expert_indices[-1]]: + # shift to 0-based (e.g., expert 4 -> 0 if local_expert_indices[0] == 4) + # - Otherwise: mark with sentinel value (num_local_experts) + # This prepares indices for argsort which will group tokens by local expert. 
+ # Result: [num_tokens, topk] with local indices or sentinels + adjusted_topk_indices = shift_topk_indices( + topk_indices, + local_start=self.local_expert_indices[0], + local_end=self.local_expert_indices[-1], + num_local_experts=self.num_local_experts, ) - - # Fused kernel launch: permute hidden states + extract probs in one pass - # Pass mask_T directly (already transposed), probs as [T, E] (kernel uses strides) - self._cached_dest_indices = launch_fused_permute_and_probs( - hidden_states, local_probs, self._cached_mask_T, - tokens_workspace, probs_workspace + + # Flatten and argsort to get permutation that groups tokens by expert. + # After argsort, all tokens for expert 0 are at the beginning, then expert 1, etc. + # Sentinel values (num_local_experts) sort to the end. + # flat_indices: [num_tokens * topk] + # permutation: [num_tokens * topk] indices for reordering + flat_indices = adjusted_topk_indices.flatten() + self._cached_permutation = torch.argsort(flat_indices, stable=True) + + # Allocate workspace for permuted outputs + # Max possible tokens = num_tokens * min(topk, num_local_experts) + max_tokens = num_tokens * min(topk, self.num_local_experts) + + # Permute tokens and probs, count tokens per expert using Triton kernel + # Each thread block handles one output position, skipping sentinels + permuted_hidden, permuted_probs, tokens_per_expert = permute_tokens_and_probs( + hidden_states, + probs, + flat_indices, + self._cached_permutation, + num_local_experts=self.num_local_experts, + max_tokens=max_tokens, ) - + + # Cache data needed for unpermute in combine_preprocess + self._cached_flat_indices = flat_indices + self._cached_num_tokens = num_tokens + self._cached_topk = topk + self.routing_map = None - self.local_probs = probs_workspace - return tokens_workspace, tokens_per_expert, probs_workspace - + self.local_probs = permuted_probs + return permuted_hidden, tokens_per_expert, permuted_probs + + def combine_preprocess(self, permuted_expert_outputs): 
""" Reverses token permutation to restore original ordering. - Handles Top-K summation into original hidden state positions. - Uses cached mask_T and dest_indices from dispatch_postprocess to avoid - recomputing them (saves 2 kernel launches). + Uses cached permutation and expert_assignments from dispatch_postprocess. + Note: Probability weighting is handled by experts via moe_apply_probs_on_input. """ - # 1. Pre-allocate output buffer w/ zeros. - unpermuted_hidden = torch.zeros( - self.hidden_shape_before_permute, - dtype=permuted_expert_outputs.dtype, - device=permuted_expert_outputs.device - ) - - # 2. Launch the Un-permute kernel with cached intermediates - # It handles the Expert-grouped -> Token-major transition. - launch_unpermute_kernel( - unpermuted_hidden, # The [T, H] destination - permuted_expert_outputs, # The [max_out, H] source - self._cached_mask_T, # Cached [E, T] mask - self._cached_dest_indices # Cached cumsum indices + # Unpermute expert outputs using cached data + output = unpermute_and_combine( + permuted_expert_outputs, + self._cached_flat_indices, + self._cached_permutation, + num_tokens=self._cached_num_tokens, + topk=self._cached_topk, + num_local_experts=self.num_local_experts, ) - - return unpermuted_hidden + + return output def token_combine(self, hidden_states): """ From 6cb8a8aaeb45ea61aaf3694a512493cc1e978a7c Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 4 Feb 2026 12:30:47 -0800 Subject: [PATCH 22/92] tseted with qwen --- megatron/core/transformer/moe/router.py | 33 +++++++++++++++---------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index 0730d13a49f..dadd7592ee5 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -693,12 +693,10 @@ def __init__( assert ( config.moe_router_num_groups is None ), f"InferenceTopKRouter requires moe_router_num_groups=None, got 
{config.moe_router_num_groups}" - assert ( - config.moe_router_score_function == "sigmoid" - ), f"InferenceTopKRouter requires moe_router_score_function='sigmoid', got '{config.moe_router_score_function}'" - assert ( - config.moe_router_enable_expert_bias is True - ), f"InferenceTopKRouter requires moe_router_enable_expert_bias=True, got {config.moe_router_enable_expert_bias}" + assert config.moe_router_score_function in [ + "sigmoid", + "softmax", + ], f"InferenceTopKRouter requires moe_router_score_function in ['sigmoid', 'softmax'], got '{config.moe_router_score_function}'" super().__init__(config=config, pg_collection=pg_collection) @@ -707,19 +705,28 @@ def __init__( def set_is_cuda_graphed_iteration(self, set_to: bool): self.is_cuda_graphed_iteration = set_to + @torch.compile() def _forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): logits = self.gating(input) # [num_tokens, num_experts] - # Apply sigmoid to get independent scores per expert - scores = torch.sigmoid(logits.float()).type_as(logits) # [num_tokens, num_experts] - - # Add expert bias for topk selection (helps with load balancing) - scores_for_routing = scores + self.expert_bias # [num_experts] broadcasted + # Apply score function to get scores per expert + if self.score_function == "sigmoid": + # Sigmoid: independent scores per expert + scores = torch.sigmoid(logits.float()).type_as(logits) # [num_tokens, num_experts] + else: # softmax + # Softmax: normalized scores across all experts + scores = torch.softmax(logits.float(), dim=-1).type_as(logits) # [num_tokens, num_experts] + + # Add expert bias for topk selection if enabled (helps with load balancing) + if self.expert_bias is not None: + scores_for_routing = scores + self.expert_bias # [num_experts] broadcasted + else: + scores_for_routing = scores - # Select top-k experts based on biased scores + # Select top-k experts based on scores (with or without bias) _, topk_indices = torch.topk(scores_for_routing, 
k=self.topk, dim=-1) # [num_tokens, topk] - # Gather the original sigmoid scores (without bias) for selected experts + # Gather the original scores (without bias) for selected experts topk_probs = torch.gather(scores, dim=-1, index=topk_indices) # [num_tokens, topk] # Normalize to get routing probabilities (sum to 1 per token) From b85e8fec83a5078166e9214096f1dedd514fca09 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 4 Feb 2026 19:21:03 -0800 Subject: [PATCH 23/92] add cutlass kernel --- .../transformer/moe/moe_layer_inference.py | 56 ++++++++++++++ .../moe/token_dispatcher_inference.py | 76 +------------------ 2 files changed, 60 insertions(+), 72 deletions(-) diff --git a/megatron/core/transformer/moe/moe_layer_inference.py b/megatron/core/transformer/moe/moe_layer_inference.py index 2991369b2b9..5c81a9a5d66 100644 --- a/megatron/core/transformer/moe/moe_layer_inference.py +++ b/megatron/core/transformer/moe/moe_layer_inference.py @@ -44,13 +44,17 @@ from typing import Optional import torch +import torch.nn.functional as F from megatron.core import utils +from megatron.core.activations import squared_relu from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.moe.moe_utils import get_default_pg_collection from megatron.core.transformer.moe.token_dispatcher_inference import InferenceAllGatherTokenDispatcher +import flashinfer.fused_moe as fused_moe +from flashinfer.fused_moe.core import ActivationType import logging @@ -146,4 +150,56 @@ def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tens return forward_pass_output + def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tensor): + """Computes the output of the routed experts on the dispatched tokens. 
+ + This method first post-processes the dispatched input to get permuted tokens + for each expert. It then passes the tokens through the local experts. + The output from the experts is preprocessed for the combine step. + """ + if not self.is_cuda_graphed_iteration: + # todo: can we go down the flashinfer path even if not cuda graphed? + return super().routed_experts_compute(hidden_states, probs) + + # Currently only squared_relu (non-gated) is supported with FlashInfer + assert not self.config.gated_linear_unit, ( + "FlashInfer MoE kernel currently only supports non-gated activations. " + f"Got gated_linear_unit={self.config.gated_linear_unit}" + ) + assert self.config.activation_func == squared_relu, ( + "FlashInfer MoE kernel currently only supports squared_relu activation. " + f"Got activation_func={self.config.activation_func}" + ) + + # Get dtype from input + output_dtype = hidden_states.dtype + + # Get expert weights from self.experts (GroupedMLP) + w1 = self.experts._fc1_weight + w2 = self.experts._fc2_weight + + # Get routing information (stored from route() step) + selected_experts = self.token_dispatcher.routing_map + routing_weights = probs + + # Get EP attributes + ep_size = utils.get_pg_size(self.ep_group) + ep_rank = utils.get_pg_rank(self.ep_group) + + # Call FlashInfer fused MoE kernel with Relu2 (squared ReLU) + output = fused_moe.cutlass_fused_moe( + hidden_states, + selected_experts.to(torch.int), + routing_weights.float(), + w1, + w2, + output_dtype, + quant_scales=None, + activation_type=ActivationType.Relu2, + ep_size=ep_size, + ep_rank=ep_rank, + )[0] + + return output, None + diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 78e751bd472..863203c8755 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -247,68 +247,10 @@ def token_dispatch(self, hidden_states, 
probs): def dispatch_postprocess(self, hidden_states, probs): - """After gathering in token_dispatch, this method identifies tokens for local experts and - permutes them for expert processing. - - Algorithm: - 1. Filter topk_indices to keep only those for local experts - 2. Shift valid indices to local coordinate system (0-indexed) - 3. Mark invalid indices with sentinel value - 4. Argsort to get token permutation that groups by expert - 5. Permute tokens and probs using this map - 6. Bincount to get tokens per expert """ - self.hidden_shape_before_permute = hidden_states.shape - - # self.routing_map is actually topk_indices: [num_tokens, topk] - topk_indices = self.routing_map - num_tokens, topk = topk_indices.shape - - # Shift global expert indices to local coordinate system using Triton kernel. - # For each index: - # - If in range [local_expert_indices[0], local_expert_indices[-1]]: - # shift to 0-based (e.g., expert 4 -> 0 if local_expert_indices[0] == 4) - # - Otherwise: mark with sentinel value (num_local_experts) - # This prepares indices for argsort which will group tokens by local expert. - # Result: [num_tokens, topk] with local indices or sentinels - adjusted_topk_indices = shift_topk_indices( - topk_indices, - local_start=self.local_expert_indices[0], - local_end=self.local_expert_indices[-1], - num_local_experts=self.num_local_experts, - ) - - # Flatten and argsort to get permutation that groups tokens by expert. - # After argsort, all tokens for expert 0 are at the beginning, then expert 1, etc. - # Sentinel values (num_local_experts) sort to the end. 
- # flat_indices: [num_tokens * topk] - # permutation: [num_tokens * topk] indices for reordering - flat_indices = adjusted_topk_indices.flatten() - self._cached_permutation = torch.argsort(flat_indices, stable=True) - - # Allocate workspace for permuted outputs - # Max possible tokens = num_tokens * min(topk, num_local_experts) - max_tokens = num_tokens * min(topk, self.num_local_experts) - - # Permute tokens and probs, count tokens per expert using Triton kernel - # Each thread block handles one output position, skipping sentinels - permuted_hidden, permuted_probs, tokens_per_expert = permute_tokens_and_probs( - hidden_states, - probs, - flat_indices, - self._cached_permutation, - num_local_experts=self.num_local_experts, - max_tokens=max_tokens, - ) - - # Cache data needed for unpermute in combine_preprocess - self._cached_flat_indices = flat_indices - self._cached_num_tokens = num_tokens - self._cached_topk = topk - - self.routing_map = None - self.local_probs = permuted_probs - return permuted_hidden, tokens_per_expert, permuted_probs + No op for flashinfer + """ + raise NotImplementedError def combine_preprocess(self, permuted_expert_outputs): @@ -318,17 +260,7 @@ def combine_preprocess(self, permuted_expert_outputs): Uses cached permutation and expert_assignments from dispatch_postprocess. Note: Probability weighting is handled by experts via moe_apply_probs_on_input. 
""" - # Unpermute expert outputs using cached data - output = unpermute_and_combine( - permuted_expert_outputs, - self._cached_flat_indices, - self._cached_permutation, - num_tokens=self._cached_num_tokens, - topk=self._cached_topk, - num_local_experts=self.num_local_experts, - ) - - return output + raise NotImplementedError def token_combine(self, hidden_states): """ From 98a4d9ff3b7bea5e4d54f476fa2fe0347bbe8ddb Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Thu, 5 Feb 2026 00:02:34 -0800 Subject: [PATCH 24/92] optimize dummy forwards --- .../inference/contexts/dynamic_context.py | 48 +++++++++++++++---- .../core/inference/engines/dynamic_engine.py | 2 +- .../text_generation_controller.py | 10 ++-- 3 files changed, 46 insertions(+), 14 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index e1b55363b37..73cacb3caf2 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1095,7 +1095,9 @@ def add_dummy_requests_parallel( ) self.token_to_block_idx[token_slice] = dummy_block_idx + if self.is_hybrid_model: + torch.cuda.nvtx.range_push("allocate mamba states for dummy requests") for logical_idx, request_idx in enumerate(range(start_request_idx, end_request_idx)): mamba_idx = self.mamba_metadata.allocate_slot() if mamba_idx is None: @@ -1105,6 +1107,7 @@ def add_dummy_requests_parallel( self.mamba_conv_states[:, mamba_idx] = 0.0 self.mamba_ssm_states[:, mamba_idx] = 0.0 self.mamba_metadata.request_to_mamba_state_idx[request_idx] = mamba_idx + torch.cuda.nvtx.range_pop() self.active_token_count = token_end self.total_request_count = end_request_idx @@ -1174,7 +1177,10 @@ def num_decode_requests(self) -> int: return self.total_request_count - self.paused_request_count - self.num_prefill_requests def initialize_attention_state( - self, *, construct_graph_dimensions: Optional[InferenceBatchDimensions] = None + self, + 
*, + construct_graph_dimensions: Optional[InferenceBatchDimensions] = None, + ep_dummy_batch_dimensions: Optional[InferenceBatchDimensions] = None, ) -> None: """Initialize attention state so that every layer can use it. @@ -1184,15 +1190,26 @@ def initialize_attention_state( None. """ # if in recording mode, add dummy requests for cuda graph capture + torch.cuda.nvtx.range_push("init attention state") + if construct_graph_dimensions is not None: + assert ep_dummy_batch_dimensions is None + torch.cuda.nvtx.range_push("add dummy requests....") + rank = torch.distributed.get_rank() + logging.info(f"rank = {rank}: adding dummy requests.....!!!!!") self.add_dummy_requests_for_cudagraph_capture(construct_graph_dimensions) - - batch_dimensions = InferenceBatchDimensions( - token_count=self.active_token_count, - prefill_req_count=self.num_prefill_requests, - decode_req_count=self.num_decode_requests, - has_explicit_chunked_prefill_req=self.has_explicit_chunked_prefill_req, - ) + torch.cuda.nvtx.range_pop() + + if ep_dummy_batch_dimensions is not None: + batch_dimensions = ep_dummy_batch_dimensions + else: + batch_dimensions = InferenceBatchDimensions( + token_count=self.active_token_count, + prefill_req_count=self.num_prefill_requests, + decode_req_count=self.num_decode_requests, + has_explicit_chunked_prefill_req=self.has_explicit_chunked_prefill_req, + ) + self.batch_dimensions = batch_dimensions best_graph = CUDAGraphBatchDimensionBuilder.match_graph_config( batch_dimensions, @@ -1205,7 +1222,18 @@ def initialize_attention_state( if self.using_cuda_graph_this_step(): self.padded_batch_dimensions = best_graph + if ep_dummy_batch_dimensions is not None: + # no requests should exist in the system + # we will only have padding tokens + # and dummy block idxes. 
+ assert not self.active_token_count + assert not self.paused_request_count + self.total_request_count = ep_dummy_batch_dimensions.prefill_req_count + \ + ep_dummy_batch_dimensions.decode_req_count else: + if ep_dummy_batch_dimensions is not None: + return + padded_token_count = self.round_up_tokens(self.active_token_count) if self.is_decode_only(): padded_token_count = min( @@ -1279,7 +1307,7 @@ def initialize_attention_state( batch_dimensions=attn_dimensions, padded_batch_dimensions=self.padded_batch_dimensions, ) - + torch.cuda.nvtx.range_push("mamba metadata update") if self.is_hybrid_model: active_mamba_indices_view = self.mamba_metadata.request_to_mamba_state_idx[active_slice] token_to_request_idx_view = self.token_to_request_idx[: self.active_token_count] @@ -1293,6 +1321,8 @@ def initialize_attention_state( batch_dimensions=attn_dimensions, padded_batch_dimensions=self.padded_batch_dimensions, ) + torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_pop() def reset(self) -> None: """Reset entire context. diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 134ce3b124d..5a244f26e7b 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -1650,7 +1650,7 @@ async def _ep_group_has_work(self, local_work: int) -> bool: # Note that it is important to use a non-blocking asyncio-friendly all-reduce here. # The user may have other tasks running in the event loop that need to be serviced. # Do not using a torch.distributed blocking all-reduce here using nccl/gloo. - # We have tried that and it blocks the event loop is megatron-rl. + # We have tried that and it blocks the event loop in megatron-rl. 
max_global_work = await self.expert_parallel_zmq_communicator.all_reduce_max(local_work) else: max_global_work = local_work diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 62e3a57e0e4..a1ccdce70e5 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -494,7 +494,7 @@ def unpad_input_prompt_tokens( def _dynamic_step_context_init( self, construct_graph_dimensions: Optional[InferenceBatchDimensions] = None, - is_dummy_forward: bool = False, + ep_dummy_batch_dimensions: Optional[InferenceBatchDimensions] = None, ): """Initializes the inference context for dynamic batching. @@ -507,6 +507,8 @@ def _dynamic_step_context_init( input_ids (Tensor): The active input IDs. position_ids (Tensor): The active position IDs. """ + is_dummy_forward = ep_dummy_batch_dimensions is not None + context = self.inference_wrapped_model.inference_context active_request_slice = slice(context.paused_request_count, context.total_request_count) @@ -515,7 +517,8 @@ def _dynamic_step_context_init( model_config = get_model_config(unwrapped_model) # Initialize attention state. - context.initialize_attention_state(construct_graph_dimensions=construct_graph_dimensions) + context.initialize_attention_state(construct_graph_dimensions=construct_graph_dimensions, + ep_dummy_batch_dimensions=ep_dummy_batch_dimensions) # If using symmetric kernels and we are using using nccl # for prefill turn off symmetric kernels @@ -800,8 +803,7 @@ def dummy_forward(self): # a dummy cuda graph. 
input_ids, position_ids = self._dynamic_step_context_init( # try to use the smallest cuda-graph config for dummy forward - construct_graph_dimensions=min(context.cuda_graph_batch_dimensions_list), - is_dummy_forward=True, + ep_dummy_batch_dimensions=min(context.cuda_graph_batch_dimensions_list) ) # _dynamic_step_context_init tries to find a cuda-graph that is compatible From acbc84192a32bc87aa9271c5660676431138bd20 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Thu, 5 Feb 2026 02:56:56 -0800 Subject: [PATCH 25/92] bugfix in inference router --- megatron/core/transformer/moe/moe_layer_inference.py | 1 + megatron/core/transformer/moe/router.py | 9 +++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/megatron/core/transformer/moe/moe_layer_inference.py b/megatron/core/transformer/moe/moe_layer_inference.py index 5c81a9a5d66..90ad04c5299 100644 --- a/megatron/core/transformer/moe/moe_layer_inference.py +++ b/megatron/core/transformer/moe/moe_layer_inference.py @@ -129,6 +129,7 @@ def deactivate_inference_token_dispatcher(self): self.token_dispatcher = self.old_token_dispatcher self.shared_expert_overlap = self.old_expert_overlap + # ==================== Simplified Forward Pass ==================== def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): """ diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index dadd7592ee5..0a4fd966762 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -733,6 +733,10 @@ def _forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = N if self.topk > 1: topk_probs = topk_probs / (topk_probs.sum(dim=-1, keepdim=True) + 1e-20) + # Apply scaling factor if configured + if self.config.moe_router_topk_scaling_factor: + topk_probs = topk_probs * self.config.moe_router_topk_scaling_factor + # NOTE: Return format differs from parent class for efficiency: # - Parent: Returns 
sparse tensors [num_tokens, num_experts] (routing_probs, routing_map) # - This: Returns dense tensors [num_tokens, topk] (topk_probs, topk_indices) @@ -750,9 +754,10 @@ def forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = No - probs: Normalized routing probabilities [num_tokens, topk] - top_indices: Selected expert indices [num_tokens, topk] """ - # Compute logits via gating network + # Maintain float32 expert bias (important for bf16/fp16) + self._maintain_float32_expert_bias() if not self.is_cuda_graphed_iteration: return super().forward(input, padding_mask) - + return self._forward(input, padding_mask) \ No newline at end of file From 986e2a1768dcf218abfc330161751c311f2ff95a Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 9 Feb 2026 11:23:08 -0800 Subject: [PATCH 26/92] latest --- .../inference/contexts/dynamic_context.py | 49 ++++++++++++++----- .../text_generation_controller.py | 6 ++- 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 73cacb3caf2..b8a3b0aec55 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -425,6 +425,9 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC self.padded_active_request_count = 0 self.paused_tokens = None + # Debug: track last 5 steps' dummy forward status + self._dummy_forward_history = [] + # Block ids. 
self.max_kv_block_count = math.ceil(self.max_sequence_length / self.block_size_tokens) @@ -1176,6 +1179,7 @@ def num_decode_requests(self) -> int: """ return self.total_request_count - self.paused_request_count - self.num_prefill_requests + def initialize_attention_state( self, *, @@ -1220,16 +1224,29 @@ def initialize_attention_state( ) self._using_cuda_graph_this_step = best_graph is not None + # Track dummy forward history (last 5 steps) + is_dummy_forward = ep_dummy_batch_dimensions is not None + self._dummy_forward_history.append({ + 'is_dummy': is_dummy_forward, + 'ep_dims': str(ep_dummy_batch_dimensions) if is_dummy_forward else None, + 'active_tokens': self.active_token_count, + 'total_reqs': self.total_request_count, + 'using_graph': self._using_cuda_graph_this_step, + 'best_graph': str(best_graph) if best_graph else None, + }) + if len(self._dummy_forward_history) > 5: + self._dummy_forward_history.pop(0) + if self.using_cuda_graph_this_step(): self.padded_batch_dimensions = best_graph if ep_dummy_batch_dimensions is not None: - # no requests should exist in the system - # we will only have padding tokens - # and dummy block idxes. - assert not self.active_token_count - assert not self.paused_request_count self.total_request_count = ep_dummy_batch_dimensions.prefill_req_count + \ ep_dummy_batch_dimensions.decode_req_count + # Zero out request_query_lengths to prevent stale data from causing + # out of bounds memory accesses in last_token_logits. + # When we move finished requests to the right, we never + # zero out their request lengths. + self.request_query_lengths[0:self.total_request_count].fill_(0) else: if ep_dummy_batch_dimensions is not None: return @@ -1423,13 +1440,21 @@ def last_token_logits(self, logits: Tensor) -> Tensor: # Last token logits. 
logits = logits.squeeze(0) - last_token_idxs = ( - torch.cumsum( - self.request_query_lengths[self.paused_request_count : self.total_request_count], - dim=0, - ) - - 1 - ) + query_lengths_slice = self.request_query_lengths[self.paused_request_count : self.total_request_count] + last_token_idxs = torch.cumsum(query_lengths_slice, dim=0) - 1 + + # Debug check for OOB + max_idx = last_token_idxs.max().item() if last_token_idxs.numel() > 0 else -1 + if max_idx >= logits.shape[0]: + print(f"OOB ERROR: max_idx={max_idx}, logits_dim={logits.shape[0]}") + print(f"query_lengths={query_lengths_slice}") + print(f"paused={self.paused_request_count}, total={self.total_request_count}") + print(f"active_token_count={self.active_token_count}, padded={self.padded_active_token_count}") + print(f"Dummy forward history (last 5 steps):") + for i, h in enumerate(self._dummy_forward_history): + print(f" Step -{len(self._dummy_forward_history)-i}: {h}") + raise RuntimeError(f"last_token_logits OOB: max_idx={max_idx} >= logits_dim={logits.shape[0]}") + last_token_logits = logits[last_token_idxs, :] return last_token_logits diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index a1ccdce70e5..8eceac104ba 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -786,8 +786,12 @@ def _dynamic_step_calculate_top_n_logprobs( def dummy_forward(self): """Perform a dummy forward pass. This is used in expert model parallelism on ranks that do not have any real requests.""" - + context = self.inference_wrapped_model.inference_context + # no requests should exist in the system + # we will only have padding tokens + # and dummy block idxes. 
+ # context.reset() # if no cuda graphs, directly use dummy forward if not context.cuda_graph_batch_dimensions_list: # initialize symmetric memory if needed From bb8890dbb8a9dc5346da602bb7035309be25e376 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 13 Feb 2026 11:55:10 -0800 Subject: [PATCH 27/92] return usage characteristics from text gen server --- .../endpoints/completions.py | 178 ++++++++++-------- 1 file changed, 97 insertions(+), 81 deletions(-) diff --git a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/completions.py b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/completions.py index b749205cdfd..97a509555a3 100644 --- a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/completions.py +++ b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/completions.py @@ -125,90 +125,106 @@ async def completions(): request_idx = 0 for record in batch_results: - for result in record.requests: - full_text = result.generated_text or "" - text_output = (prompts_as_strings[request_idx] + full_text) if echo else full_text - - logprobs_data = None - if sampling_params.return_log_probs: - # Get prompt tokens and logprobs - prompt_tokens_list = [] - if result.prompt_tokens is not None: - if hasattr(result.prompt_tokens, 'tolist'): - prompt_tokens_list = result.prompt_tokens.tolist() - else: - prompt_tokens_list = list(result.prompt_tokens) - - prompt_log_probs = getattr(result, 'prompt_log_probs', None) or [] - prompt_top_n_logprobs = getattr(result, 'prompt_top_n_logprobs', None) or [] - - # Get generated tokens and logprobs - generated_tokens_list = ( - list(result.generated_tokens) if result.generated_tokens else [] - ) - generated_log_probs = getattr(result, 'generated_log_probs', None) or [] - generated_top_n_logprobs = ( - getattr(result, 'generated_top_n_logprobs', None) or [] - ) - - if echo: - # When echo=True, include prompt 
tokens and their logprobs - # Prompt logprobs are for tokens [1:] (first token has no logprob) - all_token_ids = prompt_tokens_list + generated_tokens_list - tokens = [tokenizer.detokenize([tok]) for tok in all_token_ids] - - # Build token_logprobs: [None] for first token, then prompt logprobs, - # then generated logprobs - token_logprobs = [None] + list(prompt_log_probs) + list(generated_log_probs) - - # Build top_logprobs: [None] for first token, then prompt top_n, - # then generated top_n - top_logprobs = None - if prompt_top_n_logprobs or generated_top_n_logprobs: - top_logprobs = ( - [None] - + list(prompt_top_n_logprobs) - + list(generated_top_n_logprobs) - ) - - # Calculate text_offset: cumulative character positions starting from 0 - text_offset = [] - current_offset = 0 - for tok_str in tokens: - text_offset.append(current_offset) - current_offset += len(tok_str) + # for result in record.requests: + result = record.merge() + full_text = result.generated_text or "" + text_output = (prompts_as_strings[request_idx] + full_text) if echo else full_text + + logprobs_data = None + if sampling_params.return_log_probs: + # Get prompt tokens and logprobs + prompt_tokens_list = [] + if result.prompt_tokens is not None: + if hasattr(result.prompt_tokens, 'tolist'): + prompt_tokens_list = result.prompt_tokens.tolist() else: - # When echo=False, only return generated tokens and their logprobs - tokens = [tokenizer.detokenize([tok]) for tok in generated_tokens_list] - - # Prepend [None] to match OpenAI format - token_logprobs = [None] + list(generated_log_probs) - - # Build top_logprobs - top_logprobs = None - if generated_top_n_logprobs: - top_logprobs = [None] + list(generated_top_n_logprobs) - - # Calculate text_offset for generated tokens only - text_offset = [] - current_offset = 0 - for tok_str in tokens: - text_offset.append(current_offset) - current_offset += len(tok_str) - - logprobs_data = { - "token_logprobs": token_logprobs, - "tokens": tokens, - 
"text_offset": text_offset, - "top_logprobs": top_logprobs, - } - - choices.append( - {"index": request_idx, "text": text_output, "logprobs": logprobs_data} + prompt_tokens_list = list(result.prompt_tokens) + + prompt_log_probs = getattr(result, 'prompt_log_probs', None) or [] + prompt_top_n_logprobs = getattr(result, 'prompt_top_n_logprobs', None) or [] + + # Get generated tokens and logprobs + generated_tokens_list = ( + list(result.generated_tokens) if result.generated_tokens else [] ) - request_idx += 1 + generated_log_probs = getattr(result, 'generated_log_probs', None) or [] + generated_top_n_logprobs = ( + getattr(result, 'generated_top_n_logprobs', None) or [] + ) + + if echo: + # When echo=True, include prompt tokens and their logprobs + # Prompt logprobs are for tokens [1:] (first token has no logprob) + all_token_ids = prompt_tokens_list + generated_tokens_list + tokens = [tokenizer.detokenize([tok]) for tok in all_token_ids] + + # Build token_logprobs: [None] for first token, then prompt logprobs, + # then generated logprobs + token_logprobs = [None] + list(prompt_log_probs) + list(generated_log_probs) + + # Build top_logprobs: [None] for first token, then prompt top_n, + # then generated top_n + top_logprobs = None + if prompt_top_n_logprobs or generated_top_n_logprobs: + top_logprobs = ( + [None] + + list(prompt_top_n_logprobs) + + list(generated_top_n_logprobs) + ) + + # Calculate text_offset: cumulative character positions starting from 0 + text_offset = [] + current_offset = 0 + for tok_str in tokens: + text_offset.append(current_offset) + current_offset += len(tok_str) + else: + # When echo=False, only return generated tokens and their logprobs + tokens = [tokenizer.detokenize([tok]) for tok in generated_tokens_list] + + # Prepend [None] to match OpenAI format + token_logprobs = [None] + list(generated_log_probs) + + # Build top_logprobs + top_logprobs = None + if generated_top_n_logprobs: + top_logprobs = [None] + list(generated_top_n_logprobs) + 
+ # Calculate text_offset for generated tokens only + text_offset = [] + current_offset = 0 + for tok_str in tokens: + text_offset.append(current_offset) + current_offset += len(tok_str) + + logprobs_data = { + "token_logprobs": token_logprobs, + "tokens": tokens, + "text_offset": text_offset, + "top_logprobs": top_logprobs, + } + + choices.append( + {"index": request_idx, "text": text_output, "logprobs": logprobs_data} + ) + request_idx += 1 + + prompt_tokens_total = sum(len(p) for p in prompts_as_tokens) + completion_tokens_total = sum( + len(result.generated_tokens) + for record in batch_results + for result in record.requests + if result.generated_tokens is not None + ) - return jsonify({"choices": choices}) + return jsonify({ + "choices": choices, + "usage": { + "prompt_tokens": prompt_tokens_total, + "completion_tokens": completion_tokens_total, + "total_tokens": prompt_tokens_total + completion_tokens_total, + }, + }) except ImportError as e: logger.warning(f"Could not import flask: {e}") From 9d062c7cd94c6f71536a3ed4b1cad563695a9046 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Sun, 15 Feb 2026 21:20:30 -0800 Subject: [PATCH 28/92] add vllm cg utils --- .../core/inference/batch_dimensions_utils.py | 55 +++++++++++++------ 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py index a2f10c6d11b..e7298758573 100644 --- a/megatron/core/inference/batch_dimensions_utils.py +++ b/megatron/core/inference/batch_dimensions_utils.py @@ -9,6 +9,7 @@ """ import math +import os from dataclasses import dataclass from typing import List, Optional, Tuple @@ -268,30 +269,48 @@ def _calculate_cuda_graph_token_counts( cuda_graph_max_tokens > 0 ), f"cuda_graph_max_tokens must be > 0, got {cuda_graph_max_tokens}" - # Cuda graph step size. 
- cuda_graph_step_size = cuda_graph_max_tokens / num_cuda_graphs - cuda_graph_step_size = CUDAGraphBatchDimensionBuilder.CUDA_GRAPH_ROUNDER * int( - math.ceil(int(cuda_graph_step_size) / CUDAGraphBatchDimensionBuilder.CUDA_GRAPH_ROUNDER) - ) - # Make sure divisible by TP size - cuda_graph_step_size = math.ceil(cuda_graph_step_size / tp_size) * tp_size - # round down cuda graph max tokens to be multiple of TP size cuda_graph_max_tokens = (cuda_graph_max_tokens // tp_size) * tp_size - # Cuda graph token counts. - if num_cuda_graphs == 1: - cuda_graph_token_counts = [cuda_graph_max_tokens] - else: - cuda_graph_token_counts = list( - range(cuda_graph_step_size, cuda_graph_max_tokens, cuda_graph_step_size) + if os.environ.get("VLLM_CG_CALC", "0") == "1": + # vLLM-style capture sizes: dense at small counts, coarser at larger counts. + cuda_graph_token_counts = [1, 2, 4] + list(range(8, 256, 8)) + list( + range(256, cuda_graph_max_tokens + 1, 16) ) - if ( - len(cuda_graph_token_counts) == 0 - or cuda_graph_token_counts[-1] != cuda_graph_max_tokens - ): + # Align each entry to TP size + cuda_graph_token_counts = list(dict.fromkeys( + math.ceil(s / tp_size) * tp_size for s in cuda_graph_token_counts + )) + # Clamp to max tokens + cuda_graph_token_counts = [s for s in cuda_graph_token_counts if s <= cuda_graph_max_tokens] + if not cuda_graph_token_counts or cuda_graph_token_counts[-1] != cuda_graph_max_tokens: cuda_graph_token_counts.append(cuda_graph_max_tokens) cuda_graph_token_counts.reverse() + else: + # Default: evenly-spaced token counts. + # Cuda graph step size. + cuda_graph_step_size = cuda_graph_max_tokens / num_cuda_graphs + cuda_graph_step_size = CUDAGraphBatchDimensionBuilder.CUDA_GRAPH_ROUNDER * int( + math.ceil( + int(cuda_graph_step_size) / CUDAGraphBatchDimensionBuilder.CUDA_GRAPH_ROUNDER + ) + ) + # Make sure divisible by TP size + cuda_graph_step_size = math.ceil(cuda_graph_step_size / tp_size) * tp_size + + # Cuda graph token counts. 
+ if num_cuda_graphs == 1: + cuda_graph_token_counts = [cuda_graph_max_tokens] + else: + cuda_graph_token_counts = list( + range(cuda_graph_step_size, cuda_graph_max_tokens, cuda_graph_step_size) + ) + if ( + len(cuda_graph_token_counts) == 0 + or cuda_graph_token_counts[-1] != cuda_graph_max_tokens + ): + cuda_graph_token_counts.append(cuda_graph_max_tokens) + cuda_graph_token_counts.reverse() return cuda_graph_token_counts From 7a834ae1f2adcb9ad17e8193ab6dbba77d434ae8 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Sun, 15 Feb 2026 23:19:58 -0800 Subject: [PATCH 29/92] print engine time in ms instead of seconds --- megatron/core/inference/engines/dynamic_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 5a244f26e7b..f6eceeb2479 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -1338,7 +1338,7 @@ async def async_bookkeep( mem = torch.cuda.memory_stats() step_type = "decode" if context_state["is_decode_only"] else "non-decode" output_str = ( - "* rank %d | step %d | %s ... time: %.3f%s ... " + "* rank %d | step %d | %s ... time: %.3f ms%s ... " "reqs: a %d/%d, p %d, w %d, f %d, e %d ... " "blocks: a %d/%d, p %d/%d ... " "mem: tensors %d, alloc %.1f gb, res %.1f gb." 
@@ -1346,7 +1346,7 @@ async def async_bookkeep( self.rank, step_count, datetime.now().strftime("%H:%M:%S"), - step_time, + step_time * 1000, ( " [%s + real config %s + cuda graph %s]" % ( From 8281d86d39c7c2064512205925f0f7f2dc2ced91 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Sun, 15 Feb 2026 23:20:17 -0800 Subject: [PATCH 30/92] sleep 0 --- megatron/core/inference/engines/async_zmq_communicator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/core/inference/engines/async_zmq_communicator.py b/megatron/core/inference/engines/async_zmq_communicator.py index 7076bb283bd..d31894e4230 100644 --- a/megatron/core/inference/engines/async_zmq_communicator.py +++ b/megatron/core/inference/engines/async_zmq_communicator.py @@ -85,7 +85,7 @@ async def all_reduce_max(self, local_val: int) -> int: msg = self.gather_sock.recv(flags=zmq.NOBLOCK) values.append(struct.unpack('!i', msg)[0]) except zmq.Again: - await asyncio.sleep(0.001) # Yield to event loop + await asyncio.sleep(0) # Yield to event loop max_val = max(values) self.bcast_sock.send(struct.pack('!i', max_val)) @@ -100,7 +100,7 @@ async def all_reduce_max(self, local_val: int) -> int: msg = self.bcast_sock.recv(flags=zmq.NOBLOCK) return struct.unpack('!i', msg)[0] except zmq.Again: - await asyncio.sleep(0.001) # Yield to event loop + await asyncio.sleep(0) # Yield to event loop def close(self): """ From d4f00ca0e125d68a0a4480de75a04e8039fce76d Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Sun, 15 Feb 2026 23:20:42 -0800 Subject: [PATCH 31/92] add fused 3 tensor all gather --- .../torch_symm_triton/__init__.py | 2 +- .../torch_symm_triton/collectives.py | 120 ++++++++++++++++++ .../moe/token_dispatcher_inference.py | 25 ++-- 3 files changed, 129 insertions(+), 18 deletions(-) diff --git a/megatron/core/inference/communication/torch_symm_triton/__init__.py b/megatron/core/inference/communication/torch_symm_triton/__init__.py index ca58663d9ec..282c98008f0 100644 --- 
a/megatron/core/inference/communication/torch_symm_triton/__init__.py +++ b/megatron/core/inference/communication/torch_symm_triton/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -from .collectives import multimem_all_gather, multimem_reduce_scatter +from .collectives import multimem_all_gather, multimem_all_gather_3, multimem_reduce_scatter from .fused_collectives import fused_multimem_rs_add_norm_ag diff --git a/megatron/core/inference/communication/torch_symm_triton/collectives.py b/megatron/core/inference/communication/torch_symm_triton/collectives.py index 9d48fc8b341..6c482d39395 100644 --- a/megatron/core/inference/communication/torch_symm_triton/collectives.py +++ b/megatron/core/inference/communication/torch_symm_triton/collectives.py @@ -137,6 +137,126 @@ def multimem_all_gather( return output_tensor +# ── Fused 3-tensor all-gather ─────────────────────────────────────────────── +# Processes routing_map, probs, and hidden_states in a single kernel launch +# with a single barrier, eliminating 2 kernel launches + 2 barriers. 
+ + +@triton.jit +def _ag_phase(local_ptr, multicast_ptr, byte_offset, numel, BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE): + """One all-gather phase: load from local memory, multicast-store to symmetric buffer.""" + pid = tl.program_id(axis=0) + tid = get_flat_tid() + + numel_128 = numel // NUMEL_PER_THREAD + numel_per_rank = tl.cdiv(numel_128, WORLD_SIZE) + block_start = pid * BLOCK_SIZE + + while block_start < numel_per_rank: + offsets = block_start + tid + mask = offsets < numel_per_rank + + multicast_ptrs = ( + multicast_ptr.to(tl.pointer_type(tl.uint64)) + + byte_offset // 8 + + (RANK * numel_per_rank + offsets) * 2 + ) + local_ptrs = local_ptr.to(tl.pointer_type(tl.uint64)) + offsets * 2 + (x, y, z, w) = ld_128(local_ptrs, mask=mask, multicast_op=False) + st_128(multicast_ptrs, x, y, z, w, mask=mask, multicast_op=True) + + block_start += tl.num_programs(axis=0) * BLOCK_SIZE + + +@triton.jit +def _multimem_all_gather_3_kernel( + local_ptr_0, local_ptr_1, local_ptr_2, + multicast_ptr, + signal_pad_ptrs, + numel_0, byte_offset_0, + numel_1, byte_offset_1, + numel_2, byte_offset_2, + BLOCK_SIZE: tl.constexpr, + NUMEL_PER_THREAD: tl.constexpr, + RANK: tl.constexpr, + WORLD_SIZE: tl.constexpr, +): + """ + Fused 3-tensor multicast all-gather. Processes three tensors in sequence + then synchronizes once, eliminating 2 kernel launches and 2 barriers + compared to three separate multimem_all_gather calls. 
+ """ + # Phase 1: routing_map + _ag_phase(local_ptr_0, multicast_ptr, byte_offset_0, numel_0, + BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) + + # Phase 2: probs + _ag_phase(local_ptr_1, multicast_ptr, byte_offset_1, numel_1, + BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) + + # Phase 3: hidden_states + _ag_phase(local_ptr_2, multicast_ptr, byte_offset_2, numel_2, + BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) + + # Single barrier for all three tensors + sync_threads() + symm_mem_sync( + signal_pad_ptrs, + None, + RANK, + WORLD_SIZE, + hasPreviousMemAccess=True, + hasSubsequentMemAccess=True, + ) + + +def multimem_all_gather_3( + output_0: torch.Tensor, input_0: torch.Tensor, byte_offset_0: int, + output_1: torch.Tensor, input_1: torch.Tensor, byte_offset_1: int, + output_2: torch.Tensor, input_2: torch.Tensor, byte_offset_2: int, + symm_mem_hdl: _SymmetricMemory, + **kwargs, +) -> None: + """ + Fused 3-tensor multicast all-gather. Equivalent to calling multimem_all_gather + three times but with a single kernel launch and a single barrier. + + All tensors must share the same symmetric memory handle and be BF16. + """ + assert HAVE_TRITON, "Triton is required for multimem all-gather." + + config = { + "max_num_blocks": kwargs.get("max_num_blocks", 128), + "num_warps": kwargs.get("num_warps", 32), + "BLOCK_SIZE": kwargs.get("BLOCK_SIZE", 1024), + } + + numel_per_thread = 128 // (input_0.element_size() * 8) + + assert output_0.numel() % numel_per_thread == 0, "Tensor 0 must be 128-bit aligned." + assert output_1.numel() % numel_per_thread == 0, "Tensor 1 must be 128-bit aligned." + assert output_2.numel() % numel_per_thread == 0, "Tensor 2 must be 128-bit aligned." 
+ + # Size grid to the largest tensor + max_numel = max(output_0.numel(), output_1.numel(), output_2.numel()) + num_threads = triton.cdiv(max_numel // numel_per_thread, symm_mem_hdl.world_size) + num_blocks = min(triton.cdiv(num_threads, config["BLOCK_SIZE"]), config["max_num_blocks"]) + + _multimem_all_gather_3_kernel[(num_blocks, 1, 1)]( + input_0.data_ptr(), input_1.data_ptr(), input_2.data_ptr(), + symm_mem_hdl.multicast_ptr, + symm_mem_hdl.signal_pad_ptrs_dev, + numel_0=output_0.numel(), byte_offset_0=byte_offset_0, + numel_1=output_1.numel(), byte_offset_1=byte_offset_1, + numel_2=output_2.numel(), byte_offset_2=byte_offset_2, + BLOCK_SIZE=config["BLOCK_SIZE"], + NUMEL_PER_THREAD=numel_per_thread, + RANK=symm_mem_hdl.rank, + WORLD_SIZE=symm_mem_hdl.world_size, + num_warps=config["num_warps"], + ) + + @triton.jit def _multimem_reduce_scatter_kernel( local_ptr, diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 863203c8755..76ba0d5d733 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -27,7 +27,7 @@ from megatron.core.tensor_parallel import gather_from_sequence_parallel_region from megatron.core.parallel_state import get_global_symmetric_memory_buffer_ep from megatron.core.inference.communication.torch_symm_triton import ( - multimem_all_gather, + multimem_all_gather_3, multimem_reduce_scatter, ) @@ -207,30 +207,21 @@ def token_dispatch(self, hidden_states, probs): probs_dtype = probs.dtype hidden_dtype = hidden_states.dtype - # Use latency-optimized NVLS all-gather for routing_map, probs and hidden_states - # Pass byte_offset so kernel writes to correct location in multicast buffer - multimem_all_gather( + # Fused NVLS all-gather: single kernel launch + single barrier for all 3 tensors + multimem_all_gather_3( ag_buffers["routing_map"].view(torch.bfloat16), 
self.routing_map.view(torch.bfloat16), - ag_buffers["handle"], - byte_offset=ag_buffers["routing_map_offset"], - ) - self.routing_map = ag_buffers["routing_map"].view(routing_map_dtype).view(global_tokens, topk) - - multimem_all_gather( + ag_buffers["routing_map_offset"], ag_buffers["probs"].view(torch.bfloat16), probs.view(torch.bfloat16), - ag_buffers["handle"], - byte_offset=ag_buffers["probs_offset"], - ) - probs = ag_buffers["probs"].view(probs_dtype).view(global_tokens, topk) - - multimem_all_gather( + ag_buffers["probs_offset"], ag_buffers["hidden_states"].view(torch.bfloat16), hidden_states.view(torch.bfloat16), + ag_buffers["hidden_states_offset"], ag_buffers["handle"], - byte_offset=ag_buffers["hidden_states_offset"], ) + self.routing_map = ag_buffers["routing_map"].view(routing_map_dtype).view(global_tokens, topk) + probs = ag_buffers["probs"].view(probs_dtype).view(global_tokens, topk) hidden_states = ag_buffers["hidden_states"].view(hidden_dtype).view(global_tokens, hidden_dim) else: # Fallback to NCCL for all tensors From 6251bb9e28181784693dc2fec41df58b0d8adfa8 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 17 Feb 2026 14:46:49 -0800 Subject: [PATCH 32/92] restore some delay in zmq asyncio --- megatron/core/inference/engines/async_zmq_communicator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/core/inference/engines/async_zmq_communicator.py b/megatron/core/inference/engines/async_zmq_communicator.py index d31894e4230..155cb6d002f 100644 --- a/megatron/core/inference/engines/async_zmq_communicator.py +++ b/megatron/core/inference/engines/async_zmq_communicator.py @@ -85,7 +85,7 @@ async def all_reduce_max(self, local_val: int) -> int: msg = self.gather_sock.recv(flags=zmq.NOBLOCK) values.append(struct.unpack('!i', msg)[0]) except zmq.Again: - await asyncio.sleep(0) # Yield to event loop + await asyncio.sleep(0.0001) # Yield to event loop max_val = max(values) self.bcast_sock.send(struct.pack('!i', 
max_val)) @@ -100,7 +100,7 @@ async def all_reduce_max(self, local_val: int) -> int: msg = self.bcast_sock.recv(flags=zmq.NOBLOCK) return struct.unpack('!i', msg)[0] except zmq.Again: - await asyncio.sleep(0) # Yield to event loop + await asyncio.sleep(0.0001) # Yield to event loop def close(self): """ From 72674ee1a8ca467a016e57f3569a74d7538761a5 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 20 Feb 2026 15:35:34 -0800 Subject: [PATCH 33/92] faster dummy ep cg codepath --- .../inference/contexts/dynamic_context.py | 39 ++++++++++++++++--- .../text_generation_controller.py | 4 +- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 9f7556f1312..237d5e3d60f 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1261,26 +1261,34 @@ def num_decode_requests(self) -> int: return self.total_request_count - self.paused_request_count - self.num_prefill_requests def initialize_attention_state( - self, *, construct_graph_dimensions: Optional[InferenceBatchDimensions] = None + self, *, construct_graph_dimensions: Optional[InferenceBatchDimensions] = None, + is_expert_parallel_dummy_cuda_graph_step: bool = False ) -> None: """Initialize attention state so that every layer can use it. Args: construct_graph_dimensions (Optional[InferenceBatchDimensions]): The graph config to use for constructing the cuda graphs. + is_expert_parallel_dummy_cuda_graph_step (bool): Whether this is a dummy expert model parallel step. Return: None. """ self.is_creating_cuda_graphs = construct_graph_dimensions is not None + assert not (self.is_creating_cuda_graphs and is_expert_parallel_dummy_cuda_graph_step), "Dummy expert model parallel steps should not be creating cuda graphs." 
# If in CUDA graph creation mode, add dummy requests for CUDA graph capture if self.is_creating_cuda_graphs: self.add_dummy_requests_for_cudagraph_capture(construct_graph_dimensions) - batch_dimensions = InferenceBatchDimensions( - token_count=self.active_token_count, - prefill_req_count=self.num_prefill_requests, - decode_req_count=self.num_decode_requests, - ) + if is_expert_parallel_dummy_cuda_graph_step: + # attempt to use the smallest possible cuda graph for the dummy forward + smallest_cuda_graph_dimensions = min(self.cuda_graph_batch_dimensions_list) + batch_dimensions = smallest_cuda_graph_dimensions + else: + batch_dimensions = InferenceBatchDimensions( + token_count=self.active_token_count, + prefill_req_count=self.num_prefill_requests, + decode_req_count=self.num_decode_requests, + ) self.batch_dimensions = batch_dimensions best_graph = CUDAGraphBatchDimensionBuilder.match_graph_config( batch_dimensions, @@ -1294,7 +1302,26 @@ def initialize_attention_state( if self.using_cuda_graph_this_step(): self.padded_batch_dimensions = best_graph + if is_expert_parallel_dummy_cuda_graph_step: + # do minimum setup just so that we can run the dummy forward pass with + # a cuda graph. + # 1. Adjust total request count. Pretend that the smallest cuda graph + # dimensions represent the actual batch dimensions! + self.total_request_count = smallest_cuda_graph_dimensions.prefill_req_count + \ + smallest_cuda_graph_dimensions.decode_req_count + # 2. Reset request query lengths + # Zero out request_query_lengths to prevent stale data from causing + # out of bounds memory accesses in last_token_logits. + # This is needed because when we move finished requests to the right, we never + # zero out their request lengths. 
+ self.request_query_lengths[0:self.total_request_count].fill_(0) else: + if is_expert_parallel_dummy_cuda_graph_step: + # If we are here, this means that CUDAGraphBatchDimensionBuilder.match_graph_config + # could not find a compatible cuda graph for the dummy forward step. + # Now, we need not do the remaining setup. The controller + # will directly call the model forward pass with a single token. + return padded_token_count = self.round_up_tokens(self.active_token_count) if self.is_decode_only(): padded_token_count = min( diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index f56e5b1c761..2309e96aece 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -517,7 +517,8 @@ def _dynamic_step_context_init( model_config = get_model_config(unwrapped_model) # Initialize attention state. - context.initialize_attention_state(construct_graph_dimensions=construct_graph_dimensions) + context.initialize_attention_state(construct_graph_dimensions=construct_graph_dimensions, + is_expert_parallel_dummy_cuda_graph_step=is_dummy_forward) # If using symmetric kernels and we are using using nccl # for prefill turn off symmetric kernels @@ -846,7 +847,6 @@ def dummy_forward(self): # a dummy cuda graph. 
input_ids, position_ids = self._dynamic_step_context_init( # try to use the smallest cuda-graph config for dummy forward - construct_graph_dimensions=min(context.cuda_graph_batch_dimensions_list), is_dummy_forward=True, ) From 0b55c2d33f358b4ff3946a5a5b5151f869056ed4 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 20 Feb 2026 16:07:19 -0800 Subject: [PATCH 34/92] format --- .../inference/contexts/dynamic_context.py | 38 +++++++++++-------- .../text_generation_controller.py | 8 ++-- 2 files changed, 27 insertions(+), 19 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 237d5e3d60f..7d0675d4dd1 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1261,8 +1261,10 @@ def num_decode_requests(self) -> int: return self.total_request_count - self.paused_request_count - self.num_prefill_requests def initialize_attention_state( - self, *, construct_graph_dimensions: Optional[InferenceBatchDimensions] = None, - is_expert_parallel_dummy_cuda_graph_step: bool = False + self, + *, + construct_graph_dimensions: Optional[InferenceBatchDimensions] = None, + is_expert_parallel_dummy_cuda_graph_step: bool = False, ) -> None: """Initialize attention state so that every layer can use it. @@ -1273,7 +1275,9 @@ def initialize_attention_state( None. """ self.is_creating_cuda_graphs = construct_graph_dimensions is not None - assert not (self.is_creating_cuda_graphs and is_expert_parallel_dummy_cuda_graph_step), "Dummy expert model parallel steps should not be creating cuda graphs." + assert not ( + self.is_creating_cuda_graphs and is_expert_parallel_dummy_cuda_graph_step + ), "Dummy expert model parallel steps should not be creating cuda graphs." 
# If in CUDA graph creation mode, add dummy requests for CUDA graph capture if self.is_creating_cuda_graphs: @@ -1302,26 +1306,28 @@ def initialize_attention_state( if self.using_cuda_graph_this_step(): self.padded_batch_dimensions = best_graph - if is_expert_parallel_dummy_cuda_graph_step: - # do minimum setup just so that we can run the dummy forward pass with - # a cuda graph. - # 1. Adjust total request count. Pretend that the smallest cuda graph + if is_expert_parallel_dummy_cuda_graph_step: + # do minimum setup just so that we can run the dummy forward pass with + # a cuda graph. + # 1. Adjust total request count. Pretend that the smallest cuda graph # dimensions represent the actual batch dimensions! - self.total_request_count = smallest_cuda_graph_dimensions.prefill_req_count + \ - smallest_cuda_graph_dimensions.decode_req_count + self.total_request_count = ( + smallest_cuda_graph_dimensions.prefill_req_count + + smallest_cuda_graph_dimensions.decode_req_count + ) # 2. Reset request query lengths # Zero out request_query_lengths to prevent stale data from causing - # out of bounds memory accesses in last_token_logits. - # This is needed because when we move finished requests to the right, we never + # out of bounds memory accesses in last_token_logits. + # This is needed because when we move finished requests to the right, we never # zero out their request lengths. - self.request_query_lengths[0:self.total_request_count].fill_(0) + self.request_query_lengths[0 : self.total_request_count].fill_(0) else: if is_expert_parallel_dummy_cuda_graph_step: # If we are here, this means that CUDAGraphBatchDimensionBuilder.match_graph_config - # could not find a compatible cuda graph for the dummy forward step. - # Now, we need not do the remaining setup. The controller - # will directly call the model forward pass with a single token. - return + # could not find a compatible cuda graph for the dummy forward step. + # Now, we need not do the remaining setup. 
The controller + # will directly call the model forward pass with a single token. + return padded_token_count = self.round_up_tokens(self.active_token_count) if self.is_decode_only(): padded_token_count = min( diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 2309e96aece..c7a64c93e9c 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -517,8 +517,10 @@ def _dynamic_step_context_init( model_config = get_model_config(unwrapped_model) # Initialize attention state. - context.initialize_attention_state(construct_graph_dimensions=construct_graph_dimensions, - is_expert_parallel_dummy_cuda_graph_step=is_dummy_forward) + context.initialize_attention_state( + construct_graph_dimensions=construct_graph_dimensions, + is_expert_parallel_dummy_cuda_graph_step=is_dummy_forward, + ) # If using symmetric kernels and we are using using nccl # for prefill turn off symmetric kernels @@ -847,7 +849,7 @@ def dummy_forward(self): # a dummy cuda graph. 
input_ids, position_ids = self._dynamic_step_context_init( # try to use the smallest cuda-graph config for dummy forward - is_dummy_forward=True, + is_dummy_forward=True ) # _dynamic_step_context_init tries to find a cuda-graph that is compatible From 126d6c1a6abc163ccdbc2f7258a67b488e8e7e35 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 23 Feb 2026 15:34:19 -0800 Subject: [PATCH 35/92] refactor + make safer --- .../attention_context/mamba_metadata.py | 17 +++++ .../inference/contexts/dynamic_context.py | 74 ++++++++++++------- 2 files changed, 64 insertions(+), 27 deletions(-) diff --git a/megatron/core/inference/contexts/attention_context/mamba_metadata.py b/megatron/core/inference/contexts/attention_context/mamba_metadata.py index d7fcf7436a2..3588e72292d 100644 --- a/megatron/core/inference/contexts/attention_context/mamba_metadata.py +++ b/megatron/core/inference/contexts/attention_context/mamba_metadata.py @@ -305,6 +305,23 @@ def allocate_slot(self) -> Optional[int]: return mamba_idx + def batch_allocate_slots(self, num_slots: int) -> Optional[torch.Tensor]: + """ + Allocates new slots for the given number of requests in the Mamba state buffers. + + Returns: + torch.Tensor: The indices of the allocated slots. + Returns None if not enough slots are available. + """ + if self.mamba_state_free_slot_count < num_slots: + return None + + # Get free slots + self.mamba_state_free_slot_count -= num_slots + mamba_idx = self.mamba_state_free_slots[self.mamba_state_free_slot_count:self.mamba_state_free_slot_count + num_slots] + + return mamba_idx + def free_slots(self, request_indices: torch.Tensor) -> None: """ Frees the Mamba state slots associated with the given request indices. 
diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 7d0675d4dd1..a75c24da0ed 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1260,6 +1260,41 @@ def num_decode_requests(self) -> int: """ return self.total_request_count - self.paused_request_count - self.num_prefill_requests + def add_dummy_requests_for_expert_parallel_step(self) -> None: + """Minimal context setup for a dummy EP forward pass. + + This is the fast alternative to add_dummy_requests_for_cudagraph_capture. + We only initialize state that is actually read during + initialize_attention_state and the subsequent forward pass. + """ + smallest_cuda_graph_dimensions = min(self.cuda_graph_batch_dimensions_list) + # the smallest cuda graph is decode only. + assert smallest_cuda_graph_dimensions.prefill_req_count == 0 + + N = smallest_cuda_graph_dimensions.decode_req_count + dummy_block_idx = self.block_allocator.dummy_block_idx + + # do minimum setup just so that we can run the dummy forward pass with + # a cuda graph. + + # 1. Request counts and token count (decode-only: 1 token per request). + self.total_request_count = N + self.active_token_count = N + + # 2. Per-request state consumed by mha_metadata.update(). + self.request_query_lengths[0:N].fill_(1) + self.request_kv_length_offsets[0:N].fill_(0) + self.request_to_kv_block_ids[0:N, 0] = dummy_block_idx + + # 3. Token-level state (needed for padding / mamba slicing). + self.token_to_request_idx[0:N] = torch.arange( + 0, N, device=self.token_to_request_idx.device, dtype=self.token_to_request_idx.dtype + ) + + # 4. Mamba state: point every dummy request at slot 0. 
+ if self.is_hybrid_model: + self.mamba_metadata.request_to_mamba_state_idx[0:N] = self.mamba_metadata.batch_allocate_slots(N) + def initialize_attention_state( self, *, @@ -1280,19 +1315,18 @@ def initialize_attention_state( ), "Dummy expert model parallel steps should not be creating cuda graphs." # If in CUDA graph creation mode, add dummy requests for CUDA graph capture - if self.is_creating_cuda_graphs: - self.add_dummy_requests_for_cudagraph_capture(construct_graph_dimensions) - if is_expert_parallel_dummy_cuda_graph_step: - # attempt to use the smallest possible cuda graph for the dummy forward - smallest_cuda_graph_dimensions = min(self.cuda_graph_batch_dimensions_list) - batch_dimensions = smallest_cuda_graph_dimensions + self.add_dummy_requests_for_expert_parallel_step() + batch_dimensions = min(self.cuda_graph_batch_dimensions_list) else: + if self.is_creating_cuda_graphs: + self.add_dummy_requests_for_cudagraph_capture(construct_graph_dimensions) batch_dimensions = InferenceBatchDimensions( token_count=self.active_token_count, prefill_req_count=self.num_prefill_requests, decode_req_count=self.num_decode_requests, ) + self.batch_dimensions = batch_dimensions best_graph = CUDAGraphBatchDimensionBuilder.match_graph_config( batch_dimensions, @@ -1304,30 +1338,16 @@ def initialize_attention_state( ) self._using_cuda_graph_this_step = best_graph is not None + if is_expert_parallel_dummy_cuda_graph_step and not self.using_cuda_graph_this_step(): + # If we are here, this means that CUDAGraphBatchDimensionBuilder.match_graph_config + # could not find a compatible cuda graph for the dummy forward step. + # Now, we need not do the remaining setup. The controller + # will directly call the model forward pass with a single token. + return + if self.using_cuda_graph_this_step(): self.padded_batch_dimensions = best_graph - if is_expert_parallel_dummy_cuda_graph_step: - # do minimum setup just so that we can run the dummy forward pass with - # a cuda graph. - # 1. 
Adjust total request count. Pretend that the smallest cuda graph - # dimensions represent the actual batch dimensions! - self.total_request_count = ( - smallest_cuda_graph_dimensions.prefill_req_count - + smallest_cuda_graph_dimensions.decode_req_count - ) - # 2. Reset request query lengths - # Zero out request_query_lengths to prevent stale data from causing - # out of bounds memory accesses in last_token_logits. - # This is needed because when we move finished requests to the right, we never - # zero out their request lengths. - self.request_query_lengths[0 : self.total_request_count].fill_(0) else: - if is_expert_parallel_dummy_cuda_graph_step: - # If we are here, this means that CUDAGraphBatchDimensionBuilder.match_graph_config - # could not find a compatible cuda graph for the dummy forward step. - # Now, we need not do the remaining setup. The controller - # will directly call the model forward pass with a single token. - return padded_token_count = self.round_up_tokens(self.active_token_count) if self.is_decode_only(): padded_token_count = min( From 7a7d78d82d2f747aae63c4675a46c511176e947c Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 23 Feb 2026 20:33:50 -0800 Subject: [PATCH 36/92] relegate to strict matching for qwen --- .../inference/contexts/dynamic_context.py | 49 ++++++++++++------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index a75c24da0ed..1454e0d7063 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -543,6 +543,10 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC ) ) + # num_cuda_graphs == -1 creates decode cuda graphs of size [1,2,4,8] + # but mixed prefill cuda graphs still start from size [16], i.e. 
(inference_config.cuda_graph_mixed_prefill_count) + self.is_strict_matching = self.is_hybrid_model or (inference_config.num_cuda_graphs == -1) + self._using_cuda_graph_this_step = False # Deal with chunked prefill self.enable_chunked_prefill = inference_config.enable_chunked_prefill @@ -1261,11 +1265,15 @@ def num_decode_requests(self) -> int: return self.total_request_count - self.paused_request_count - self.num_prefill_requests def add_dummy_requests_for_expert_parallel_step(self) -> None: - """Minimal context setup for a dummy EP forward pass. + """Minimal context setup so an EP rank with no real requests can replay + an already-captured cuda graph without crashing or corrupting memory. + + This is the fast alternative to add_dummy_requests_for_cudagraph_capture + (which goes through the heavyweight add_dummy_requests_parallel path). + + We setup minimal state such the initialize_attention_state and the forward + pass can run without error. - This is the fast alternative to add_dummy_requests_for_cudagraph_capture. - We only initialize state that is actually read during - initialize_attention_state and the subsequent forward pass. """ smallest_cuda_graph_dimensions = min(self.cuda_graph_batch_dimensions_list) # the smallest cuda graph is decode only. @@ -1274,25 +1282,27 @@ def add_dummy_requests_for_expert_parallel_step(self) -> None: N = smallest_cuda_graph_dimensions.decode_req_count dummy_block_idx = self.block_allocator.dummy_block_idx - # do minimum setup just so that we can run the dummy forward pass with - # a cuda graph. - # 1. Request counts and token count (decode-only: 1 token per request). self.total_request_count = N self.active_token_count = N + self.num_prefill_requests = 0 # 2. Per-request state consumed by mha_metadata.update(). self.request_query_lengths[0:N].fill_(1) self.request_kv_length_offsets[0:N].fill_(0) self.request_to_kv_block_ids[0:N, 0] = dummy_block_idx - # 3. Token-level state (needed for padding / mamba slicing). 
- self.token_to_request_idx[0:N] = torch.arange( - 0, N, device=self.token_to_request_idx.device, dtype=self.token_to_request_idx.dtype - ) + # 3. Token-level state consumed by the triton KV append kernel. + self.token_to_block_idx[0:N] = dummy_block_idx + self.token_to_local_position_within_kv_block[0:N] = 0 - # 4. Mamba state: point every dummy request at slot 0. if self.is_hybrid_model: + # 4. token_to_request_idx: needed by mamba_metadata.update() for hybrid models. + self.token_to_request_idx[0:N] = torch.arange( + 0, N, device=self.token_to_request_idx.device, dtype=self.token_to_request_idx.dtype + ) + + # 5. Mamba state: allocate slots for dummy requests. self.mamba_metadata.request_to_mamba_state_idx[0:N] = self.mamba_metadata.batch_allocate_slots(N) def initialize_attention_state( @@ -1317,21 +1327,22 @@ def initialize_attention_state( # If in CUDA graph creation mode, add dummy requests for CUDA graph capture if is_expert_parallel_dummy_cuda_graph_step: self.add_dummy_requests_for_expert_parallel_step() - batch_dimensions = min(self.cuda_graph_batch_dimensions_list) else: if self.is_creating_cuda_graphs: self.add_dummy_requests_for_cudagraph_capture(construct_graph_dimensions) - batch_dimensions = InferenceBatchDimensions( - token_count=self.active_token_count, - prefill_req_count=self.num_prefill_requests, - decode_req_count=self.num_decode_requests, - ) + + batch_dimensions = InferenceBatchDimensions( + token_count=self.active_token_count, + prefill_req_count=self.num_prefill_requests, + decode_req_count=self.num_decode_requests, + ) self.batch_dimensions = batch_dimensions + best_graph = CUDAGraphBatchDimensionBuilder.match_graph_config( batch_dimensions, self.cuda_graph_batch_dimensions_list, - strict=self.is_hybrid_model, + strict=self.is_strict_matching, decode_only_cuda_graphs=(not self.use_cuda_graphs_for_non_decode_steps), explicit_chunked_prefill=self.is_chunked_prefill_enabled() and self.is_hybrid_model, 
ep_group=self.expert_model_parallel_group, From bf6f67808a1d3e1646d39150dd093c14c5444ffa Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 24 Feb 2026 01:13:57 -0800 Subject: [PATCH 37/92] add unit test for ep syncs, and fix a bug in non strict matching --- .../core/inference/batch_dimensions_utils.py | 23 +- .../attention_context/mamba_metadata.py | 4 +- .../inference/contexts/dynamic_context.py | 20 +- megatron/inference/utils.py | 2 +- .../inference/test_batch_dimension_utils.py | 344 ++++++++++++++++++ 5 files changed, 380 insertions(+), 13 deletions(-) create mode 100644 tests/unit_tests/inference/test_batch_dimension_utils.py diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py index 1a202c35af5..71a6ee14f91 100644 --- a/megatron/core/inference/batch_dimensions_utils.py +++ b/megatron/core/inference/batch_dimensions_utils.py @@ -133,6 +133,7 @@ def adjust_batch_dims_for_expert_parallelism( strict: bool, decode_only_cuda_graphs: bool, explicit_chunked_prefill: bool, + cuda_graph_mixed_prefill_count: int, ep_group: Optional[torch.distributed.ProcessGroup] = None, ) -> Optional["InferenceBatchDimensions"]: """Adjusted cuda graph batch dimensions for expert parallelism. @@ -157,6 +158,7 @@ def adjust_batch_dims_for_expert_parallelism( # all reduce local work across expert model parallel group is_non_decode = local_batch_dims.prefill_req_count > 0 + sync_tensor = torch.tensor( [ local_batch_dims.token_count, @@ -193,12 +195,22 @@ def adjust_batch_dims_for_expert_parallelism( adjusted_decode_req_count = ( int(sync_tensor[3].item()) if strict else local_batch_dims.decode_req_count ) + adjusted_token_count = int(sync_tensor[0].item()) + + # When any EP rank has prefill requests (non-strict mode), elevate + # the token count to be >= the smallest prefill/mixed cuda graph. 
+ # This ensures decode-only ranks don't match a fine-grained decode + # graph while prefill ranks match a coarser mixed graph, which would + # produce inconsistent token counts across EP ranks. + if is_any_ep_rank_in_non_decode and not strict: + adjusted_token_count = max(adjusted_token_count, cuda_graph_mixed_prefill_count) adjusted_batch_dim = InferenceBatchDimensions( - token_count=int(sync_tensor[0].item()), + token_count=adjusted_token_count, prefill_req_count=adjusted_prefill_req_count, decode_req_count=adjusted_decode_req_count, ) + return adjusted_batch_dim @@ -360,6 +372,12 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int ): cuda_graph_max_tokens = max_tokens + assert cuda_graph_max_tokens == max_requests, ( + f"cuda_graph_max_tokens ({cuda_graph_max_tokens}) must equal max_requests " + f"({max_requests}). This is required for correctly syncing EP ranks: " + f"prefill and decode graph pools must have the same token count granularity." + ) + if num_cuda_graphs != -1: # if -1, no need to adjust. 
This will be taken care of in # the _calculate_cuda_graph_token_counts function where we will generate @@ -456,6 +474,7 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int def match_graph_config( real_batch_dim: InferenceBatchDimensions, cuda_graph_batch_dimensions_list: List[InferenceBatchDimensions], + cuda_graph_mixed_prefill_count: int, strict: bool = False, decode_only_cuda_graphs: bool = False, explicit_chunked_prefill: bool = False, @@ -490,6 +509,7 @@ def match_graph_config( decode_only_cuda_graphs=decode_only_cuda_graphs, explicit_chunked_prefill=explicit_chunked_prefill, ep_group=ep_group, + cuda_graph_mixed_prefill_count=cuda_graph_mixed_prefill_count, ) if adjusted_batch_dim is None: @@ -512,4 +532,5 @@ def match_graph_config( return None # then find the best batch dimension best_batch_dim = min(graph_batch_dims_applicable) + return best_batch_dim diff --git a/megatron/core/inference/contexts/attention_context/mamba_metadata.py b/megatron/core/inference/contexts/attention_context/mamba_metadata.py index 3588e72292d..bacaf882944 100644 --- a/megatron/core/inference/contexts/attention_context/mamba_metadata.py +++ b/megatron/core/inference/contexts/attention_context/mamba_metadata.py @@ -318,7 +318,9 @@ def batch_allocate_slots(self, num_slots: int) -> Optional[torch.Tensor]: # Get free slots self.mamba_state_free_slot_count -= num_slots - mamba_idx = self.mamba_state_free_slots[self.mamba_state_free_slot_count:self.mamba_state_free_slot_count + num_slots] + mamba_idx = self.mamba_state_free_slots[ + self.mamba_state_free_slot_count : self.mamba_state_free_slot_count + num_slots + ] return mamba_idx diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 1454e0d7063..2dc0fc2efe8 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -543,10 +543,7 @@ def __init__(self, model_config: 
TransformerConfig, inference_config: InferenceC ) ) - # num_cuda_graphs == -1 creates decode cuda graphs of size [1,2,4,8] - # but mixed prefill cuda graphs still start from size [16], i.e. (inference_config.cuda_graph_mixed_prefill_count) - self.is_strict_matching = self.is_hybrid_model or (inference_config.num_cuda_graphs == -1) - + self.cuda_graph_mixed_prefill_count = inference_config.cuda_graph_mixed_prefill_count self._using_cuda_graph_this_step = False # Deal with chunked prefill self.enable_chunked_prefill = inference_config.enable_chunked_prefill @@ -1271,7 +1268,7 @@ def add_dummy_requests_for_expert_parallel_step(self) -> None: This is the fast alternative to add_dummy_requests_for_cudagraph_capture (which goes through the heavyweight add_dummy_requests_parallel path). - We setup minimal state such the initialize_attention_state and the forward + We setup minimal state such the initialize_attention_state and the forward pass can run without error. """ @@ -1302,8 +1299,10 @@ def add_dummy_requests_for_expert_parallel_step(self) -> None: 0, N, device=self.token_to_request_idx.device, dtype=self.token_to_request_idx.dtype ) - # 5. Mamba state: allocate slots for dummy requests. - self.mamba_metadata.request_to_mamba_state_idx[0:N] = self.mamba_metadata.batch_allocate_slots(N) + # 5. Mamba state: allocate slots for dummy requests. 
+ self.mamba_metadata.request_to_mamba_state_idx[0:N] = ( + self.mamba_metadata.batch_allocate_slots(N) + ) def initialize_attention_state( self, @@ -1330,7 +1329,7 @@ def initialize_attention_state( else: if self.is_creating_cuda_graphs: self.add_dummy_requests_for_cudagraph_capture(construct_graph_dimensions) - + batch_dimensions = InferenceBatchDimensions( token_count=self.active_token_count, prefill_req_count=self.num_prefill_requests, @@ -1338,14 +1337,15 @@ def initialize_attention_state( ) self.batch_dimensions = batch_dimensions - + best_graph = CUDAGraphBatchDimensionBuilder.match_graph_config( batch_dimensions, self.cuda_graph_batch_dimensions_list, - strict=self.is_strict_matching, + strict=self.is_hybrid_model, decode_only_cuda_graphs=(not self.use_cuda_graphs_for_non_decode_steps), explicit_chunked_prefill=self.is_chunked_prefill_enabled() and self.is_hybrid_model, ep_group=self.expert_model_parallel_group, + cuda_graph_mixed_prefill_count=self.cuda_graph_mixed_prefill_count, ) self._using_cuda_graph_this_step = best_graph is not None diff --git a/megatron/inference/utils.py b/megatron/inference/utils.py index 92d153755fe..447d7290acc 100644 --- a/megatron/inference/utils.py +++ b/megatron/inference/utils.py @@ -303,7 +303,7 @@ def get_inference_config_from_model_and_args(model: MegatronModule, args): track_paused_request_events=args.inference_dynamic_batching_track_paused_request_events, enable_chunked_prefill=args.enable_chunked_prefill, metrics_writer=metrics_writer, - logging_step_interval=args.inference_logging_step_interval, + logging_step_interval=args.inference_logging_step_interval ) diff --git a/tests/unit_tests/inference/test_batch_dimension_utils.py b/tests/unit_tests/inference/test_batch_dimension_utils.py new file mode 100644 index 00000000000..75a8d48b948 --- /dev/null +++ b/tests/unit_tests/inference/test_batch_dimension_utils.py @@ -0,0 +1,344 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+ +""" +Unit tests for CUDAGraphBatchDimensionBuilder.match_graph_config with expert parallelism. + +Run with 8 GPUs: + torchrun --nproc_per_node=8 -m pytest \ + tests/unit_tests/inference/test_batch_dimension_utils.py -xvs +""" + +import pytest +import torch +import torch.distributed as dist + +from megatron.core import parallel_state as ps +from megatron.core.inference.batch_dimensions_utils import ( + CUDAGraphBatchDimensionBuilder, + InferenceBatchDimensions, +) +from tests.unit_tests.test_utilities import Utils + +BD = InferenceBatchDimensions + +# Common config shared across tests +MAX_REQUESTS = 256 +MAX_TOKENS = 2048 +MAX_SEQ_LEN = 4096 +TP_SIZE = 1 +MIXED_PREFILL_COUNT = 4 + + +def _generate_graphs(num_cuda_graphs, + use_non_decode=True): + """Generate cuda graph batch dimensions using the builder.""" + graph_list, _ = CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list( + tp_size=TP_SIZE, + num_cuda_graphs=num_cuda_graphs, + cuda_graph_max_tokens=MAX_REQUESTS, + cuda_graph_mixed_prefill_count=MIXED_PREFILL_COUNT, + max_requests=MAX_REQUESTS, + max_tokens=MAX_TOKENS, + max_sequence_length=MAX_SEQ_LEN, + use_cuda_graphs_for_non_decode_steps=use_non_decode, + ) + return graph_list + + +def _match(real, graph_list, ep_group, strict=False, decode_only=False, + explicit_chunked_prefill=False): + return CUDAGraphBatchDimensionBuilder.match_graph_config( + real_batch_dim=real, + cuda_graph_batch_dimensions_list=graph_list, + strict=strict, + decode_only_cuda_graphs=decode_only, + explicit_chunked_prefill=explicit_chunked_prefill, + ep_group=ep_group, + cuda_graph_mixed_prefill_count=MIXED_PREFILL_COUNT, + ) + + +def _assert_consistent_across_ranks(result, ep_group): + """Assert that the match result is the same on every EP rank. + + Either all ranks return None, or all ranks return a config with the + same token_count (which is what the all-reduce synchronises). 
+ """ + if result is None: + flag = torch.zeros(1, dtype=torch.int32, device="cuda") + else: + flag = torch.ones(1, dtype=torch.int32, device="cuda") + + # If any rank got None, all must get None; if any rank got a match, all must. + flag_sum = flag.clone() + dist.all_reduce(flag_sum, op=dist.ReduceOp.SUM, group=ep_group) + ep_size = dist.get_world_size(ep_group) + assert flag_sum.item() == 0 or flag_sum.item() == ep_size, ( + f"Inconsistent match: {flag_sum.item()}/{ep_size} ranks got a match" + ) + + if result is not None: + tc = torch.tensor([result.token_count], dtype=torch.int32, device="cuda") + tc_max = tc.clone() + tc_min = tc.clone() + dist.all_reduce(tc_max, op=dist.ReduceOp.MAX, group=ep_group) + dist.all_reduce(tc_min, op=dist.ReduceOp.MIN, group=ep_group) + assert tc_max.item() == tc_min.item(), ( + f"Token count mismatch across EP ranks: min={tc_min.item()}, max={tc_max.item()}" + ) + + +class TestMatchGraphConfigWithEP: + """Tests for match_graph_config with expert parallelism. + + Uses the world group as the EP group (all 8 GPUs form one EP group). + """ + + def setup_method(self, method): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + expert_model_parallel_size=Utils.world_size, + ) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @staticmethod + def _get_ep_group(): + """Return the EP group created by initialize_model_parallel.""" + return ps.get_expert_model_parallel_group() + + # ------------------------------------------------------------------ # + # 1. 
All ranks same decode batch → consistent match + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) + def test_uniform_decode_batch(self, num_cuda_graphs): + """All EP ranks have the same decode-only batch → should all match the same graph.""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + real = BD(token_count=32, prefill_req_count=0, decode_req_count=32) + + result = _match(real, graph_list, ep_group=ep_group) + _assert_consistent_across_ranks(result, ep_group) + assert result is not None, "Should find a matching graph for uniform decode batch" + assert result.token_count == 32 + + # ------------------------------------------------------------------ # + # 2. Different token counts across EP ranks → all-reduce takes max + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) + def test_varying_decode_token_counts(self, num_cuda_graphs): + """EP ranks have different decode token counts. The all-reduce + should take the max, and all ranks should match the same graph.""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + rank = dist.get_rank() + + # Each rank gets a different token count: 8, 16, 24, ... + token_count = (rank + 1) * 8 + real = BD(token_count=token_count, prefill_req_count=0, decode_req_count=token_count) + + result = _match(real, graph_list, ep_group=ep_group) + _assert_consistent_across_ranks(result, ep_group) + assert result is not None + assert result.token_count == (ep_group.size() * 8) + + # ------------------------------------------------------------------ # + # 3. 
decode_only_cuda_graphs=True, some ranks have prefill → all None + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) + def test_decode_only_graphs_with_mixed_ranks(self, num_cuda_graphs): + """When decode_only_cuda_graphs=True and at least one EP rank has a + prefill request, ALL ranks should get None (eager mode).""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + rank = dist.get_rank() + + # Rank 0 has a mixed batch (prefill + decode), all others decode-only + if rank == 0: + real = BD(token_count=64, prefill_req_count=2, decode_req_count=10) + else: + real = BD(token_count=32, prefill_req_count=0, decode_req_count=32) + + result = _match(real, graph_list, ep_group=ep_group, decode_only=True) + _assert_consistent_across_ranks(result, ep_group) + assert result is None, "All ranks should run eager when decode_only=True and some rank has prefill" + + # ------------------------------------------------------------------ # + # 4. 
explicit_chunked_prefill=True, some ranks prefill → all None + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) + def test_explicit_chunked_prefill_with_mixed_ranks(self, num_cuda_graphs): + """When explicit_chunked_prefill=True and some EP rank has prefill, + ALL ranks should get None (eager mode).""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + rank = dist.get_rank() + + if rank == 0: + real = BD(token_count=64, prefill_req_count=2, decode_req_count=10) + else: + real = BD(token_count=32, prefill_req_count=0, decode_req_count=32) + + result = _match( + real, graph_list, ep_group=ep_group, explicit_chunked_prefill=True + ) + _assert_consistent_across_ranks(result, ep_group) + assert result is None, "All ranks should run eager with explicit_chunked_prefill" + + # ------------------------------------------------------------------ # + # 5. Mixed prefill graphs with strict matching + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) + def test_strict_matching_with_mixed_prefill(self, num_cuda_graphs): + """With strict matching, request counts are synced across EP ranks + via all-reduce. All ranks should still get a consistent result.""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + rank = dist.get_rank() + + # Varying prefill/decode split across ranks + prefill = min(rank + 1, MIXED_PREFILL_COUNT) + decode = 16 - prefill + real = BD(token_count=64, prefill_req_count=prefill, decode_req_count=decode) + + result = _match(real, graph_list, ep_group=ep_group, strict=True) + _assert_consistent_across_ranks(result, ep_group) + + # ------------------------------------------------------------------ # + # 6. 
Non-strict matching with mixed prefill + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) + def test_non_strict_matching_with_mixed_prefill(self, num_cuda_graphs): + """Non-strict matching: prefill slots can serve decode. Token count + is synced across EP ranks; result must be consistent.""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + rank = dist.get_rank() + + prefill = min(rank + 1, MIXED_PREFILL_COUNT) + decode = 16 - prefill + real = BD(token_count=64, prefill_req_count=prefill, decode_req_count=decode) + + result = _match(real, graph_list, ep_group=ep_group) + _assert_consistent_across_ranks(result, ep_group) + + # ------------------------------------------------------------------ # + # 7. Mixed decode/prefill across ranks — strict matching + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) + def test_mixed_decode_and_prefill_ranks_strict(self, num_cuda_graphs): + """Some EP ranks are pure decode, others have prefill requests. + With strict matching the all-reduce syncs request counts to the + max across ranks. Result must be consistent.""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + rank = dist.get_rank() + + # Even ranks: pure decode (32 tokens) + # Odd ranks: mixed prefill (64 tokens, 2 prefill + 14 decode) + if rank % 2 == 0: + real = BD(token_count=32, prefill_req_count=0, decode_req_count=32) + else: + real = BD(token_count=64, prefill_req_count=2, decode_req_count=14) + + result = _match(real, graph_list, ep_group=ep_group, strict=True) + _assert_consistent_across_ranks(result, ep_group) + + # ------------------------------------------------------------------ # + # 8. 
Mixed decode/prefill across ranks — non-strict matching + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) + def test_mixed_decode_and_prefill_ranks_non_strict(self, num_cuda_graphs): + """Some EP ranks are pure decode, others have prefill requests. + Non-strict matching only syncs token counts (not request counts). + Result must be consistent.""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + rank = dist.get_rank() + + # Even ranks: pure decode (32 tokens) + # Odd ranks: mixed prefill (64 tokens, 2 prefill + 14 decode) + if rank % 2 == 0: + real = BD(token_count=32, prefill_req_count=0, decode_req_count=32) + else: + real = BD(token_count=64, prefill_req_count=2, decode_req_count=14) + + result = _match(real, graph_list, ep_group=ep_group) + _assert_consistent_across_ranks(result, ep_group) + + # ------------------------------------------------------------------ # + # 9. All ranks decode-only with decode_only_cuda_graphs → should match + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) + def test_decode_only_graphs_all_decode(self, num_cuda_graphs): + """When all EP ranks are decode-only and decode_only_cuda_graphs=True, + a match should be found.""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + rank = dist.get_rank() + + token_count = (rank + 1) * 4 + real = BD(token_count=token_count, prefill_req_count=0, decode_req_count=token_count) + + result = _match(real, graph_list, ep_group=ep_group, decode_only=True) + _assert_consistent_across_ranks(result, ep_group) + assert result is not None, "All-decode batch with decode_only_cuda_graphs should match" + + # ------------------------------------------------------------------ # + # 10. 
Real batch exceeds all graphs → None on all ranks + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) + def test_oversized_batch_returns_none(self, num_cuda_graphs): + """When the real batch is larger than any available graph, all ranks + should get None.""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + + # Token count exceeds MAX_TOKENS on all ranks + real = BD( + token_count=MAX_TOKENS + 100, + prefill_req_count=0, + decode_req_count=min(MAX_TOKENS + 100, MAX_REQUESTS), + ) + + result = _match(real, graph_list, ep_group=ep_group) + _assert_consistent_across_ranks(result, ep_group) + assert result is None, "Oversized batch should not match any graph" + + # ------------------------------------------------------------------ # + # 11. One EP rank has huge batch → all-reduce lifts to max → no match + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) + def test_one_rank_oversized_forces_no_match(self, num_cuda_graphs): + """If one EP rank has a batch exceeding all graph capacities, the + all-reduce max lifts everyone → no match on any rank.""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + rank = dist.get_rank() + + if rank == 0: + # This rank has a batch that exceeds all graphs + real = BD( + token_count=MAX_TOKENS + 100, + prefill_req_count=0, + decode_req_count=min(MAX_TOKENS + 100, MAX_REQUESTS), + ) + else: + real = BD(token_count=8, prefill_req_count=0, decode_req_count=8) + + result = _match(real, graph_list, ep_group=ep_group) + _assert_consistent_across_ranks(result, ep_group) + assert result is None, "All-reduce max from oversized rank should cause no match" From 3f2de16f9275bd50e8b10a1fd8e8d52deeb84a3c Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 24 Feb 2026 
01:20:32 -0800 Subject: [PATCH 38/92] linting --- .../core/inference/batch_dimensions_utils.py | 2 +- .../inference/test_batch_dimension_utils.py | 32 ++++++++----------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py index 71a6ee14f91..1303f61c9d2 100644 --- a/megatron/core/inference/batch_dimensions_utils.py +++ b/megatron/core/inference/batch_dimensions_utils.py @@ -158,7 +158,7 @@ def adjust_batch_dims_for_expert_parallelism( # all reduce local work across expert model parallel group is_non_decode = local_batch_dims.prefill_req_count > 0 - + sync_tensor = torch.tensor( [ local_batch_dims.token_count, diff --git a/tests/unit_tests/inference/test_batch_dimension_utils.py b/tests/unit_tests/inference/test_batch_dimension_utils.py index 75a8d48b948..d155bdf6d7f 100644 --- a/tests/unit_tests/inference/test_batch_dimension_utils.py +++ b/tests/unit_tests/inference/test_batch_dimension_utils.py @@ -2,10 +2,6 @@ """ Unit tests for CUDAGraphBatchDimensionBuilder.match_graph_config with expert parallelism. 
- -Run with 8 GPUs: - torchrun --nproc_per_node=8 -m pytest \ - tests/unit_tests/inference/test_batch_dimension_utils.py -xvs """ import pytest @@ -29,8 +25,7 @@ MIXED_PREFILL_COUNT = 4 -def _generate_graphs(num_cuda_graphs, - use_non_decode=True): +def _generate_graphs(num_cuda_graphs, use_non_decode=True): """Generate cuda graph batch dimensions using the builder.""" graph_list, _ = CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list( tp_size=TP_SIZE, @@ -45,8 +40,9 @@ def _generate_graphs(num_cuda_graphs, return graph_list -def _match(real, graph_list, ep_group, strict=False, decode_only=False, - explicit_chunked_prefill=False): +def _match( + real, graph_list, ep_group, strict=False, decode_only=False, explicit_chunked_prefill=False +): return CUDAGraphBatchDimensionBuilder.match_graph_config( real_batch_dim=real, cuda_graph_batch_dimensions_list=graph_list, @@ -73,9 +69,9 @@ def _assert_consistent_across_ranks(result, ep_group): flag_sum = flag.clone() dist.all_reduce(flag_sum, op=dist.ReduceOp.SUM, group=ep_group) ep_size = dist.get_world_size(ep_group) - assert flag_sum.item() == 0 or flag_sum.item() == ep_size, ( - f"Inconsistent match: {flag_sum.item()}/{ep_size} ranks got a match" - ) + assert ( + flag_sum.item() == 0 or flag_sum.item() == ep_size + ), f"Inconsistent match: {flag_sum.item()}/{ep_size} ranks got a match" if result is not None: tc = torch.tensor([result.token_count], dtype=torch.int32, device="cuda") @@ -83,9 +79,9 @@ def _assert_consistent_across_ranks(result, ep_group): tc_min = tc.clone() dist.all_reduce(tc_max, op=dist.ReduceOp.MAX, group=ep_group) dist.all_reduce(tc_min, op=dist.ReduceOp.MIN, group=ep_group) - assert tc_max.item() == tc_min.item(), ( - f"Token count mismatch across EP ranks: min={tc_min.item()}, max={tc_max.item()}" - ) + assert ( + tc_max.item() == tc_min.item() + ), f"Token count mismatch across EP ranks: min={tc_min.item()}, max={tc_max.item()}" class TestMatchGraphConfigWithEP: @@ -166,7 +162,9 
@@ def test_decode_only_graphs_with_mixed_ranks(self, num_cuda_graphs): result = _match(real, graph_list, ep_group=ep_group, decode_only=True) _assert_consistent_across_ranks(result, ep_group) - assert result is None, "All ranks should run eager when decode_only=True and some rank has prefill" + assert ( + result is None + ), "All ranks should run eager when decode_only=True and some rank has prefill" # ------------------------------------------------------------------ # # 4. explicit_chunked_prefill=True, some ranks prefill → all None @@ -185,9 +183,7 @@ def test_explicit_chunked_prefill_with_mixed_ranks(self, num_cuda_graphs): else: real = BD(token_count=32, prefill_req_count=0, decode_req_count=32) - result = _match( - real, graph_list, ep_group=ep_group, explicit_chunked_prefill=True - ) + result = _match(real, graph_list, ep_group=ep_group, explicit_chunked_prefill=True) _assert_consistent_across_ranks(result, ep_group) assert result is None, "All ranks should run eager with explicit_chunked_prefill" From 05e872bb647fac4615242d4eea99f8ccd886d863 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 24 Feb 2026 02:15:56 -0800 Subject: [PATCH 39/92] attempt to delete unnecessary modifications --- .../core/inference/batch_dimensions_utils.py | 129 +++--- .../inference/contexts/dynamic_context.py | 437 ++++++++++-------- .../engines/async_zmq_communicator.py | 4 +- .../text_generation_controller.py | 108 +++-- .../endpoints/completions.py | 50 +- megatron/core/ssm/mamba_block.py | 120 ++++- megatron/core/transformer/attention.py | 17 +- megatron/core/transformer/cuda_graphs.py | 20 +- .../core/transformer/moe/gpu_resident_ops.py | 110 ----- 9 files changed, 507 insertions(+), 488 deletions(-) delete mode 100644 megatron/core/transformer/moe/gpu_resident_ops.py diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py index e7298758573..1a202c35af5 100644 --- 
a/megatron/core/inference/batch_dimensions_utils.py +++ b/megatron/core/inference/batch_dimensions_utils.py @@ -9,7 +9,6 @@ """ import math -import os from dataclasses import dataclass from typing import List, Optional, Tuple @@ -26,7 +25,6 @@ class InferenceBatchDimensions: token_count : number of total input tokens prefill_req_count : number of prefill requests decode_req_count : number of decode requests - has_explicit_chunked_prefill_req : whether the batch has an explicit chunked prefill request The batch dimensions are ordered by token_count, then by prefill_req_count, then by decode_req_count. @@ -36,7 +34,6 @@ class InferenceBatchDimensions: token_count: int = 0 prefill_req_count: int = 0 decode_req_count: int = 0 - has_explicit_chunked_prefill_req: bool = False def __str__(self): """ @@ -56,9 +53,6 @@ def is_applicable_for_batch_dim( for prefill or decode requests. Otherwise, prefill slots can only be used for prefill requests. """ - if real_batch_dim.has_explicit_chunked_prefill_req != self.has_explicit_chunked_prefill_req: - return False - if real_batch_dim.prefill_req_count == 0: return ( self.token_count >= real_batch_dim.token_count @@ -105,10 +99,6 @@ def is_valid(self, max_requests: int, max_sequence_length: int) -> bool: if self.token_count > self.prefill_req_count * max_sequence_length + self.decode_req_count: return False - # Check if there is an invalid chunked prefill request. - if self.prefill_req_count == 0 and self.has_explicit_chunked_prefill_req: - return False - return True def __hash__(self): @@ -116,14 +106,7 @@ def __hash__(self): Returns a hash of the batch dimension. In cuda graph quick matching, the batch dimension is used as a key in a dictionary. 
""" - return hash( - ( - self.token_count, - self.prefill_req_count, - self.decode_req_count, - self.has_explicit_chunked_prefill_req, - ) - ) + return hash((self.token_count, self.prefill_req_count, self.decode_req_count)) def __eq__(self, other: "InferenceBatchDimensions") -> bool: """ @@ -131,16 +114,10 @@ def __eq__(self, other: "InferenceBatchDimensions") -> bool: """ if other is None: return False - return ( - self.token_count, - self.prefill_req_count, - self.decode_req_count, - self.has_explicit_chunked_prefill_req, - ) == ( + return (self.token_count, self.prefill_req_count, self.decode_req_count) == ( other.token_count, other.prefill_req_count, other.decode_req_count, - other.has_explicit_chunked_prefill_req, ) @property @@ -155,6 +132,7 @@ def adjust_batch_dims_for_expert_parallelism( local_batch_dims, strict: bool, decode_only_cuda_graphs: bool, + explicit_chunked_prefill: bool, ep_group: Optional[torch.distributed.ProcessGroup] = None, ) -> Optional["InferenceBatchDimensions"]: """Adjusted cuda graph batch dimensions for expert parallelism. @@ -164,6 +142,7 @@ def adjust_batch_dims_for_expert_parallelism( local_batch_dims: The local batch dimensions to adjust. strict: Whether to use strict matching for batch dimensions. decode_only_cuda_graphs: Whether CUDA graphs are only used for decode steps. + explicit_chunked_prefill: Whether chunked prefill is enabled with explicit requests ep_group: Optional expert parallel process group. If None, uses global parallel state. When using different EP sizes for inference vs training, pass the inference EP group explicitly. 
@@ -177,13 +156,11 @@ def adjust_batch_dims_for_expert_parallelism( return local_batch_dims # all reduce local work across expert model parallel group - has_explicit_chunked_prefill_req = local_batch_dims.has_explicit_chunked_prefill_req is_non_decode = local_batch_dims.prefill_req_count > 0 sync_tensor = torch.tensor( [ local_batch_dims.token_count, int(is_non_decode), - int(has_explicit_chunked_prefill_req), local_batch_dims.prefill_req_count, local_batch_dims.decode_req_count, ], @@ -195,7 +172,6 @@ def adjust_batch_dims_for_expert_parallelism( sync_tensor = sync_tensor.cpu() is_any_ep_rank_in_non_decode = sync_tensor[1].item() == 1 - any_ep_rank_has_explicit_chunked_prefill_req = sync_tensor[2].item() == 1 # We force eager mode for scenarios where some ranks will run with CUDA graphs # while others will not. Without this check, the all-to-all communication in the @@ -205,28 +181,23 @@ def adjust_batch_dims_for_expert_parallelism( # 1. If we only allow decode CUDA graphs but some ranks are running non-decode batches # 2. Some ranks are running explicit chunked prefill requests # (graphs are not recorded for batches with explicit chunked prefill requests) - if ( - decode_only_cuda_graphs and is_any_ep_rank_in_non_decode - ) or any_ep_rank_has_explicit_chunked_prefill_req: + if is_any_ep_rank_in_non_decode and (decode_only_cuda_graphs or explicit_chunked_prefill): return None # indicate no match, run in eager mode - assert not has_explicit_chunked_prefill_req - # If strict matching is enabled, we sync the request counts across EP ranks # to ensure the graph captures the maximum needed capacity. 
# TODO(ksanthanam): Add functional test for this scenario adjusted_prefill_req_count = ( - int(sync_tensor[3].item()) if strict else local_batch_dims.prefill_req_count + int(sync_tensor[2].item()) if strict else local_batch_dims.prefill_req_count ) adjusted_decode_req_count = ( - int(sync_tensor[4].item()) if strict else local_batch_dims.decode_req_count + int(sync_tensor[3].item()) if strict else local_batch_dims.decode_req_count ) adjusted_batch_dim = InferenceBatchDimensions( token_count=int(sync_tensor[0].item()), prefill_req_count=adjusted_prefill_req_count, decode_req_count=adjusted_decode_req_count, - has_explicit_chunked_prefill_req=False, ) return adjusted_batch_dim @@ -264,53 +235,54 @@ def _calculate_cuda_graph_token_counts( (tp_size=2, num_cuda_graphs=4, cuda_graph_max_tokens=1000) [1000, 752, 504, 256] """ + if num_cuda_graphs == -1: + # automatically determine the number of CUDA graphs to + # capture based on the `max_requests` value + cuda_graph_token_counts = ( + [1, 2, 4] + list(range(8, 256, 8)) + list(range(256, cuda_graph_max_tokens + 1, 16)) + ) + # Align each entry to TP size + cuda_graph_token_counts = list( + dict.fromkeys(math.ceil(s / tp_size) * tp_size for s in cuda_graph_token_counts) + ) + # Clamp to max tokens + cuda_graph_token_counts = [ + s for s in cuda_graph_token_counts if s <= cuda_graph_max_tokens + ] + if not cuda_graph_token_counts or cuda_graph_token_counts[-1] != cuda_graph_max_tokens: + cuda_graph_token_counts.append(cuda_graph_max_tokens) + cuda_graph_token_counts.reverse() + return cuda_graph_token_counts + assert num_cuda_graphs >= 1, f"num_cuda_graphs must be >= 1, got {num_cuda_graphs}" assert ( cuda_graph_max_tokens > 0 ), f"cuda_graph_max_tokens must be > 0, got {cuda_graph_max_tokens}" + # Cuda graph step size. 
+ cuda_graph_step_size = cuda_graph_max_tokens / num_cuda_graphs + cuda_graph_step_size = CUDAGraphBatchDimensionBuilder.CUDA_GRAPH_ROUNDER * int( + math.ceil(int(cuda_graph_step_size) / CUDAGraphBatchDimensionBuilder.CUDA_GRAPH_ROUNDER) + ) + # Make sure divisible by TP size + cuda_graph_step_size = math.ceil(cuda_graph_step_size / tp_size) * tp_size + # round down cuda graph max tokens to be multiple of TP size cuda_graph_max_tokens = (cuda_graph_max_tokens // tp_size) * tp_size - if os.environ.get("VLLM_CG_CALC", "0") == "1": - # vLLM-style capture sizes: dense at small counts, coarser at larger counts. - cuda_graph_token_counts = [1, 2, 4] + list(range(8, 256, 8)) + list( - range(256, cuda_graph_max_tokens + 1, 16) + # Cuda graph token counts. + if num_cuda_graphs == 1: + cuda_graph_token_counts = [cuda_graph_max_tokens] + else: + cuda_graph_token_counts = list( + range(cuda_graph_step_size, cuda_graph_max_tokens, cuda_graph_step_size) ) - # Align each entry to TP size - cuda_graph_token_counts = list(dict.fromkeys( - math.ceil(s / tp_size) * tp_size for s in cuda_graph_token_counts - )) - # Clamp to max tokens - cuda_graph_token_counts = [s for s in cuda_graph_token_counts if s <= cuda_graph_max_tokens] - if not cuda_graph_token_counts or cuda_graph_token_counts[-1] != cuda_graph_max_tokens: + if ( + len(cuda_graph_token_counts) == 0 + or cuda_graph_token_counts[-1] != cuda_graph_max_tokens + ): cuda_graph_token_counts.append(cuda_graph_max_tokens) cuda_graph_token_counts.reverse() - else: - # Default: evenly-spaced token counts. - # Cuda graph step size. - cuda_graph_step_size = cuda_graph_max_tokens / num_cuda_graphs - cuda_graph_step_size = CUDAGraphBatchDimensionBuilder.CUDA_GRAPH_ROUNDER * int( - math.ceil( - int(cuda_graph_step_size) / CUDAGraphBatchDimensionBuilder.CUDA_GRAPH_ROUNDER - ) - ) - # Make sure divisible by TP size - cuda_graph_step_size = math.ceil(cuda_graph_step_size / tp_size) * tp_size - - # Cuda graph token counts. 
- if num_cuda_graphs == 1: - cuda_graph_token_counts = [cuda_graph_max_tokens] - else: - cuda_graph_token_counts = list( - range(cuda_graph_step_size, cuda_graph_max_tokens, cuda_graph_step_size) - ) - if ( - len(cuda_graph_token_counts) == 0 - or cuda_graph_token_counts[-1] != cuda_graph_max_tokens - ): - cuda_graph_token_counts.append(cuda_graph_max_tokens) - cuda_graph_token_counts.reverse() return cuda_graph_token_counts @@ -387,7 +359,12 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int or cuda_graph_max_tokens <= 0 ): cuda_graph_max_tokens = max_tokens - num_cuda_graphs = min(max(num_cuda_graphs, 1), cuda_graph_max_tokens) + + if num_cuda_graphs != -1: + # if -1, no need to adjust. This will be taken care of in + # the _calculate_cuda_graph_token_counts function where we will generate + # the token counts based on the max_tokens value and the step size. + num_cuda_graphs = min(max(num_cuda_graphs, 1), cuda_graph_max_tokens) # Calculate token counts for prefill and mixed graphs. # These need the full cuda_graph_max_tokens to handle variable-length sequences. @@ -481,6 +458,7 @@ def match_graph_config( cuda_graph_batch_dimensions_list: List[InferenceBatchDimensions], strict: bool = False, decode_only_cuda_graphs: bool = False, + explicit_chunked_prefill: bool = False, ep_group: Optional[torch.distributed.ProcessGroup] = None, ) -> Optional[InferenceBatchDimensions]: """ @@ -494,6 +472,7 @@ def match_graph_config( decode_only_cuda_graphs: Used by expert parallel matching. If this is true, and one of the EP ranks is running a non-decode step, we elect to run in eager mode instead of matching a decode-only cuda graph. + explicit_chunked_prefill: Whether chunked prefill is enabled with explicit requests ep_group: Optional expert parallel process group. If None, uses global parallel state. When using different EP sizes for inference vs training, pass the inference EP group explicitly. 
@@ -509,6 +488,7 @@ def match_graph_config( real_batch_dim, strict=strict, decode_only_cuda_graphs=decode_only_cuda_graphs, + explicit_chunked_prefill=explicit_chunked_prefill, ep_group=ep_group, ) @@ -518,6 +498,9 @@ def match_graph_config( # in that case, all ranks have to run in eager mode return None + if explicit_chunked_prefill and real_batch_dim.prefill_req_count > 0: + return None + # first filter out batch dimensions with smaller token count, prefill req count, # or decode req count, as they are not applicable graph_batch_dims_applicable = [ diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index b8a3b0aec55..9f7556f1312 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -15,7 +15,7 @@ CUDAGraphBatchDimensionBuilder, InferenceBatchDimensions, ) -from megatron.core.inference.config import InferenceConfig +from megatron.core.inference.config import InferenceConfig, KVCacheManagementMode from megatron.core.inference.inference_request import DynamicInferenceRequest from megatron.core.inference.sampling_params import SamplingParams from megatron.core.inference.unified_memory import ( @@ -35,6 +35,7 @@ from .attention_context.mha_metadata import GraphedMHAMetadata, NonGraphedMHAMetadata from .base_context import BaseInferenceContext from .dynamic_block_allocator import BlockAllocator +from .routing_metadata import RoutingMetadata try: from .fused_kv_append_kernel import triton_append_key_value_cache @@ -347,10 +348,12 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC mamba_states_memory_per_request *= self.num_mamba_layers mamba_states_memory_per_request *= dtype_size_bytes - # Unified memory. + # Unified memory and general tensor management. 
self.unified_memory_level = inference_config.unified_memory_level - self.persist_cuda_graphs = inference_config.persist_cuda_graphs - if self.unified_memory_level > 0: + self.static_kv_memory_pointers = inference_config.static_kv_memory_pointers + self.kv_cache_management_mode = inference_config.kv_cache_management_mode + + if self.unified_memory_level != 0: try: self.unified_memory_mempool = create_unified_mempool() except UnifiedMemoryUnsupportedError: @@ -359,6 +362,23 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC "Unified memory requested but not available; defaulting to GPU memory." ) self.unified_memory_level = 0 + # If we are in a mode that requires static KV memory pointers, + # we must have either UVM or torch_memory_saver. + if ( + self.static_kv_memory_pointers + and self.kv_cache_management_mode != KVCacheManagementMode.PERSIST + ): + assert HAVE_TORCH_MEMORY_SAVER or self.unified_memory_level != 0, ( + "Static KV memory pointers require UVM or torch_memory_saver when not persisted. " + "Use --rl-kv-cache-management-mode=persist, UVM, or install torch_memory_saver." + ) + + # When not using `torch_memory_saver`, we manually offload/restore tensors. + # We use storage resize, similar to the logic in `core/distributed/param_and_grad_buffer.py` + self._offloadable_tensor_names: set[str] = set() + self._offloadable_cpu_backups: dict[str, torch.Tensor] = {} + self._offloadable_storage_sizes: dict[str, int] = {} + self._uses_torch_memory_saver: bool = False # Initialize block allocator. buffer_size_bytes = int(inference_config.buffer_size_gb * 1024**3) @@ -367,13 +387,33 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC if inference_config.paused_buffer_size_gb is None else int(inference_config.paused_buffer_size_gb * 1024**3) ) - # TODO: Add parameter to control fraction of memory assigned to KV cache - # versus Mamba state. 
- block_count = buffer_size_bytes // (self.block_size_bytes + mamba_states_memory_per_request) - block_count = max(2, block_count) # need >= 1 active block + 1 dummy block - paused_block_count = paused_buffer_size_bytes // ( - self.block_size_bytes + mamba_states_memory_per_request - ) + + mamba_max_requests = float('inf') + + if (mamba_memory_ratio := inference_config.mamba_memory_ratio) is not None: + assert self.is_hybrid_model + assert mamba_memory_ratio > 0 and mamba_memory_ratio < 1 + + # Calculate total memory before partition + total_memory = buffer_size_bytes + paused_buffer_size_bytes + mamba_memory_bytes = total_memory * mamba_memory_ratio + mamba_max_requests = int(mamba_memory_bytes // mamba_states_memory_per_request) + + # Reduce buffer sizes for KV cache + buffer_size_bytes = int(buffer_size_bytes * (1.0 - mamba_memory_ratio)) + paused_buffer_size_bytes = int(paused_buffer_size_bytes * (1.0 - mamba_memory_ratio)) + + block_count = buffer_size_bytes // self.block_size_bytes + block_count = max(2, block_count) # need >= 1 active block + 1 dummy block + paused_block_count = paused_buffer_size_bytes // self.block_size_bytes + else: + block_count = buffer_size_bytes // ( + self.block_size_bytes + mamba_states_memory_per_request + ) + block_count = max(2, block_count) # need >= 1 active block + 1 dummy block + paused_block_count = paused_buffer_size_bytes // ( + self.block_size_bytes + mamba_states_memory_per_request + ) # If using pipeline parallelism synchronize the total block count in case the # pipeline stages have different layer allocations. Non-uniform block counts @@ -425,9 +465,6 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC self.padded_active_request_count = 0 self.paused_tokens = None - # Debug: track last 5 steps' dummy forward status - self._dummy_forward_history = [] - # Block ids. 
self.max_kv_block_count = math.ceil(self.max_sequence_length / self.block_size_tokens) @@ -435,12 +472,21 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC if inference_config.max_requests is None: # Maximize compute utilization by defaulting to 1 block per request. self.max_requests = self.block_allocator.total_count - 1 # -1 for dummy block + + # Adjust max_requests for Mamba memory constraints if necessary + if self.is_hybrid_model and mamba_max_requests < self.max_requests: + self.max_requests = int(mamba_max_requests) + self.max_requests = self.max_requests // tp_size * tp_size self.max_requests = self.max_requests // self.REQUEST_ROUNDER * self.REQUEST_ROUNDER else: # User can control request overflow via max_requests. self.max_requests = inference_config.max_requests + assert ( + self.max_requests % tp_size == 0 + ), f"max_requests must be divisible by tp_size ({tp_size}), but got {self.max_requests}" + self.max_tokens = inference_config.max_tokens or self.DEFAULT_MAX_TOKENS assert self.max_tokens >= self.max_requests, ( @@ -455,6 +501,7 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC self.graph_attn_metadata = {} self.non_graph_attn_metadata = {} self.active_attn_metadata = None + self.is_creating_cuda_graphs = False self.graph_attn_metadata["mha_metadata"] = GraphedMHAMetadata( block_count_total=self.block_allocator.total_count, @@ -472,6 +519,13 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC max_seqlen=self.max_sequence_length, ) + self.moe_enable_routing_replay = model_config.moe_enable_routing_replay + if self.moe_enable_routing_replay: + assert ( + model_config.num_moe_experts is not None + ), "Router recording/replay requested but no MoE experts specified!" 
+ self.moe_routing_metadata = RoutingMetadata(self, model_config.moe_router_topk) + # CUDA graph config list self.use_cuda_graphs_for_non_decode_steps = ( inference_config.use_cuda_graphs_for_non_decode_steps @@ -489,16 +543,10 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC ) ) - # Whether to offload the KV cache. Determines where the KV cache is allocated within memory. - self.offload_kv_cache = inference_config.offload_kv_cache - assert not ( - self.offload_kv_cache and self.unified_memory_level - ), "The KV cache should not be instantiated in unified memory when it is offloaded during training." - self._using_cuda_graph_this_step = False # Deal with chunked prefill + self.enable_chunked_prefill = inference_config.enable_chunked_prefill self.chunked_prefill_request_id = -1 - self.has_explicit_chunked_prefill_req = False # FlashInfer. if inference_config.use_flashinfer_fused_rope is True: @@ -510,7 +558,7 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC # Allocate GPU state. self.is_tensor_state_allocated = False self.is_symmetric_memory_initialized = False - self.allocate_all_tensors(is_init=True) + self.initialize_all_tensors() # Print info. logging.info( @@ -521,23 +569,76 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC ) ) - def allocate_all_tensors(self, *, is_init: bool) -> None: - """Allocate GPU state. - - This method is used for both 1) initial allocation, and 2) resuming the - GPU state after a suspend. - - Args: - is_init (bool): True if this is being called from `__init__()`. - """ - - # Only allocate tensors when not using unified memory at all (level 0), - # or for initial allocation during `__init__()`. For levels 1 and 2, we do - # not perform any explicit allocations or deallocations after the initial - # call to `__init__()`. 
- if self.unified_memory_level != 0 and not is_init: - return + def _allocate_memory_buffer(self): + """Allocate the KV cache memory buffer.""" + if self.cache_mla_latent: + self.memory_buffer = torch.empty( + ( + self.num_attention_layers, + self.block_allocator.total_count, + self.block_size_tokens, + self.kv_reduced_dim, + ), + dtype=self.params_dtype, + device=torch.cuda.current_device(), + ) + else: + self.memory_buffer = torch.empty( + ( + 2, # key and value + self.num_attention_layers, + self.block_allocator.total_count, + self.block_size_tokens, + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ), + dtype=self.params_dtype, + device=torch.cuda.current_device(), + ) + if ( + self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD + and not self._uses_torch_memory_saver + ): + assert self.unified_memory_level == 0 + self._offloadable_tensor_names.add("memory_buffer") + self._offloadable_cpu_backups["memory_buffer"] = torch.empty_like( + self.memory_buffer, device="cpu" + ).pin_memory() + + def _allocate_mamba_states(self): + """Allocate Mamba states for hybrid models.""" + if self.is_hybrid_model: + self.mamba_metadata = MambaMetadata( + max_requests=self.max_requests, max_tokens=self.max_tokens + ) + self.mamba_conv_states = torch.empty( + (self.num_mamba_layers, self.max_requests) + self.mamba_conv_states_shape, + dtype=self.params_dtype, + device=torch.cuda.current_device(), + ) + self.mamba_ssm_states = torch.empty( + (self.num_mamba_layers, self.max_requests) + self.mamba_ssm_states_shape, + dtype=self.params_dtype, + device=torch.cuda.current_device(), + ) + if ( + self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD + and not self._uses_torch_memory_saver + ): + assert self.unified_memory_level == 0 + self._offloadable_tensor_names.add("mamba_conv_states") + self._offloadable_cpu_backups["mamba_conv_states"] = torch.empty_like( + self.mamba_conv_states, device="cpu" + ).pin_memory() + 
self._offloadable_tensor_names.add("mamba_ssm_states") + self._offloadable_cpu_backups["mamba_ssm_states"] = torch.empty_like( + self.mamba_ssm_states, device="cpu" + ).pin_memory() + else: + self.mamba_metadata = None + def initialize_all_tensors(self) -> None: + """Allocate all GPU state during initial construction.""" # Mark allocated. if self.is_tensor_state_allocated: return @@ -547,7 +648,7 @@ def allocate_all_tensors(self, *, is_init: bool) -> None: for key in vars(self).keys(): value = getattr(self, key) assert not isinstance(value, torch.Tensor), ( - "All tensors should be allocated within `allocate_all_tensors()." + "All tensors should be allocated within `initialize_all_tensors()`. " f"Please move tensor '{key}'." ) @@ -593,103 +694,86 @@ def allocate_all_tensors(self, *, is_init: bool) -> None: self.token_to_position_in_request = torch.empty_like(self.token_to_input_ids) self.token_to_local_position_within_kv_block = torch.empty_like(self.token_to_input_ids) - # Memory buffer. - def allocate_memory_buffer(): - """Allocate the memory buffer. This function is called below within - `with ctx_manager:`.""" - if self.cache_mla_latent: - self.memory_buffer = torch.empty( - ( - self.num_attention_layers, - self.block_allocator.total_count, - self.block_size_tokens, - self.kv_reduced_dim, - ), - dtype=self.params_dtype, - device=torch.cuda.current_device(), - ) - else: - ctx = ( - torch_memory_saver.region(tag="kv_cache", enable_cpu_backup=True) - if HAVE_TORCH_MEMORY_SAVER and self.offload_kv_cache - else nullcontext() - ) - - with ctx: - self.memory_buffer = torch.empty( - ( - 2, # key and value - self.num_attention_layers, - self.block_allocator.total_count, - self.block_size_tokens, - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ), - dtype=self.params_dtype, - device=torch.cuda.current_device(), - ) - - # Optional state tensors for hybrid models - def allocate_mamba_states(): - """Allocate Mamba states. 
This function is called below within - `with ctx_manager:`.""" - if self.is_hybrid_model: - self.mamba_metadata = MambaMetadata( - max_requests=self.max_requests, max_tokens=self.max_tokens - ) - self.mamba_conv_states = torch.empty( - (self.num_mamba_layers, self.max_requests) + self.mamba_conv_states_shape, - dtype=self.params_dtype, - device=torch.cuda.current_device(), - ) - self.mamba_ssm_states = torch.empty( - (self.num_mamba_layers, self.max_requests) + self.mamba_ssm_states_shape, - dtype=self.params_dtype, - device=torch.cuda.current_device(), - ) - - else: - self.mamba_metadata = None - - # Allocate `ctx_manager`-managed buffers. (For currently unknown reasons, - # `ctx_manager` can only be used once.) - ctx_manager = ( - torch.cuda.use_mem_pool(self.unified_memory_mempool) - if self.unified_memory_level > 0 - else nullcontext() + # Allocate large non-graphed buffers. + need_static_addr = ( + self.static_kv_memory_pointers + and self.kv_cache_management_mode != KVCacheManagementMode.PERSIST ) + + ctx_manager = nullcontext() + if self.unified_memory_level != 0: + ctx_manager = torch.cuda.use_mem_pool(self.unified_memory_mempool) + elif HAVE_TORCH_MEMORY_SAVER and need_static_addr: + ctx_manager = torch_memory_saver.region( + tag="inference_context", + enable_cpu_backup=(self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD), + ) + self._uses_torch_memory_saver = True with ctx_manager: - allocate_memory_buffer() - allocate_mamba_states() + self._allocate_memory_buffer() + self._allocate_mamba_states() # Reset attention and Mamba state. self.reset_attention_state() self.reset_mamba_state() - def deallocate_all_tensors(self): - """Deallocate GPU state. + def reinitialize_inference_state_buffers(self): + """Restore large tensors (KV cache, Mamba states) after a suspend. - This method is used for suspending the dynamic engine. + Called by the engine during `resume()`. Initial allocation is in `initialize_all_tensors()`. 
""" + if self.is_tensor_state_allocated: + return + self.is_tensor_state_allocated = True - # Only deallocate tensors when not using unified memory at all (level 0). - # For levels 1 and 2, we do not perform any explicit allocations or - # deallocations after the initial call to `__init__()`. - if self.unified_memory_level != 0: + if self.kv_cache_management_mode == KVCacheManagementMode.PERSIST: + return + + if self.unified_memory_level != 0 or self._uses_torch_memory_saver: + if self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: + self.reset() + if self._uses_torch_memory_saver: + torch_memory_saver.resume("inference_context") return - # Mark deallocated. + if self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD: + for name, tensor in ((n, getattr(self, n)) for n in self._offloadable_tensor_names): + tensor.storage().resize_(self._offloadable_storage_sizes[name]) + tensor.copy_(self._offloadable_cpu_backups[name], non_blocking=True) + elif self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: + self.is_tensor_state_allocated = False + self.initialize_all_tensors() + + def deallocate_inference_state_buffers(self): + """Deallocate large tensors (KV cache, Mamba states) during suspend. + + Called by the engine during `suspend()`. Mirror to `reinitialize_inference_state_buffers()`. + """ if not self.is_tensor_state_allocated: return self.is_tensor_state_allocated = False - # Delete all tensor attributes. - # TODO(@lmcafee): check that device == 'cuda'? 
- keys = list(vars(self).keys()) - for key in keys: - value = getattr(self, key) - if isinstance(value, torch.Tensor): - delattr(self, key) + if self.kv_cache_management_mode == KVCacheManagementMode.PERSIST: + return + + if self.unified_memory_level != 0: + return + + if self._uses_torch_memory_saver: + torch_memory_saver.pause("inference_context") + return + + if self.kv_cache_management_mode == KVCacheManagementMode.OFFLOAD: + for name, tensor in ((n, getattr(self, n)) for n in self._offloadable_tensor_names): + self._offloadable_storage_sizes[name] = tensor.storage().size() + self._offloadable_cpu_backups[name].copy_(tensor, non_blocking=True) + tensor.storage().resize_(0) + elif self.kv_cache_management_mode == KVCacheManagementMode.RECOMPUTE: + # TODO(@lmcafee): check that device == 'cuda'? + for key in list(vars(self).keys()): + value = getattr(self, key) + if isinstance(value, torch.Tensor): + delattr(self, key) @classmethod def round_up_tokens(cls, value, tp_size=None): @@ -1098,9 +1182,7 @@ def add_dummy_requests_parallel( ) self.token_to_block_idx[token_slice] = dummy_block_idx - if self.is_hybrid_model: - torch.cuda.nvtx.range_push("allocate mamba states for dummy requests") for logical_idx, request_idx in enumerate(range(start_request_idx, end_request_idx)): mamba_idx = self.mamba_metadata.allocate_slot() if mamba_idx is None: @@ -1110,7 +1192,6 @@ def add_dummy_requests_parallel( self.mamba_conv_states[:, mamba_idx] = 0.0 self.mamba_ssm_states[:, mamba_idx] = 0.0 self.mamba_metadata.request_to_mamba_state_idx[request_idx] = mamba_idx - torch.cuda.nvtx.range_pop() self.active_token_count = token_end self.total_request_count = end_request_idx @@ -1179,12 +1260,8 @@ def num_decode_requests(self) -> int: """ return self.total_request_count - self.paused_request_count - self.num_prefill_requests - def initialize_attention_state( - self, - *, - construct_graph_dimensions: Optional[InferenceBatchDimensions] = None, - ep_dummy_batch_dimensions: 
Optional[InferenceBatchDimensions] = None, + self, *, construct_graph_dimensions: Optional[InferenceBatchDimensions] = None ) -> None: """Initialize attention state so that every layer can use it. @@ -1193,64 +1270,31 @@ def initialize_attention_state( Return: None. """ - # if in recording mode, add dummy requests for cuda graph capture - torch.cuda.nvtx.range_push("init attention state") - - if construct_graph_dimensions is not None: - assert ep_dummy_batch_dimensions is None - torch.cuda.nvtx.range_push("add dummy requests....") - rank = torch.distributed.get_rank() - logging.info(f"rank = {rank}: adding dummy requests.....!!!!!") + self.is_creating_cuda_graphs = construct_graph_dimensions is not None + + # If in CUDA graph creation mode, add dummy requests for CUDA graph capture + if self.is_creating_cuda_graphs: self.add_dummy_requests_for_cudagraph_capture(construct_graph_dimensions) - torch.cuda.nvtx.range_pop() - - if ep_dummy_batch_dimensions is not None: - batch_dimensions = ep_dummy_batch_dimensions - else: - batch_dimensions = InferenceBatchDimensions( - token_count=self.active_token_count, - prefill_req_count=self.num_prefill_requests, - decode_req_count=self.num_decode_requests, - has_explicit_chunked_prefill_req=self.has_explicit_chunked_prefill_req, - ) - + + batch_dimensions = InferenceBatchDimensions( + token_count=self.active_token_count, + prefill_req_count=self.num_prefill_requests, + decode_req_count=self.num_decode_requests, + ) self.batch_dimensions = batch_dimensions best_graph = CUDAGraphBatchDimensionBuilder.match_graph_config( batch_dimensions, self.cuda_graph_batch_dimensions_list, strict=self.is_hybrid_model, decode_only_cuda_graphs=(not self.use_cuda_graphs_for_non_decode_steps), + explicit_chunked_prefill=self.is_chunked_prefill_enabled() and self.is_hybrid_model, ep_group=self.expert_model_parallel_group, ) self._using_cuda_graph_this_step = best_graph is not None - # Track dummy forward history (last 5 steps) - is_dummy_forward = 
ep_dummy_batch_dimensions is not None - self._dummy_forward_history.append({ - 'is_dummy': is_dummy_forward, - 'ep_dims': str(ep_dummy_batch_dimensions) if is_dummy_forward else None, - 'active_tokens': self.active_token_count, - 'total_reqs': self.total_request_count, - 'using_graph': self._using_cuda_graph_this_step, - 'best_graph': str(best_graph) if best_graph else None, - }) - if len(self._dummy_forward_history) > 5: - self._dummy_forward_history.pop(0) - if self.using_cuda_graph_this_step(): self.padded_batch_dimensions = best_graph - if ep_dummy_batch_dimensions is not None: - self.total_request_count = ep_dummy_batch_dimensions.prefill_req_count + \ - ep_dummy_batch_dimensions.decode_req_count - # Zero out request_query_lengths to prevent stale data from causing - # out of bounds memory accesses in last_token_logits. - # When we move finished requests to the right, we never - # zero out their request lengths. - self.request_query_lengths[0:self.total_request_count].fill_(0) else: - if ep_dummy_batch_dimensions is not None: - return - padded_token_count = self.round_up_tokens(self.active_token_count) if self.is_decode_only(): padded_token_count = min( @@ -1271,7 +1315,6 @@ def initialize_attention_state( token_count=padded_token_count, prefill_req_count=padded_prefill_req_count, decode_req_count=padded_decode_req_count, - has_explicit_chunked_prefill_req=self.has_explicit_chunked_prefill_req, ) self.padded_active_token_count = self.padded_batch_dimensions.token_count self.padded_active_request_count = self.padded_batch_dimensions.req_count @@ -1302,8 +1345,6 @@ def initialize_attention_state( attn_dimensions = batch_dimensions if self.using_cuda_graph_this_step(): - assert not self.has_explicit_chunked_prefill_req - # Treat some decode requests as prefill requests to fit the cuda graph batch dimension. 
if batch_dimensions.decode_req_count > self.padded_batch_dimensions.decode_req_count: total_req = batch_dimensions.req_count @@ -1313,7 +1354,6 @@ def initialize_attention_state( token_count=batch_dimensions.token_count, prefill_req_count=adjusted_prefill_req_count, decode_req_count=adjusted_decode_req_count, - has_explicit_chunked_prefill_req=False, ) assert self.active_attn_metadata is not None @@ -1324,7 +1364,7 @@ def initialize_attention_state( batch_dimensions=attn_dimensions, padded_batch_dimensions=self.padded_batch_dimensions, ) - torch.cuda.nvtx.range_push("mamba metadata update") + if self.is_hybrid_model: active_mamba_indices_view = self.mamba_metadata.request_to_mamba_state_idx[active_slice] token_to_request_idx_view = self.token_to_request_idx[: self.active_token_count] @@ -1337,9 +1377,14 @@ def initialize_attention_state( cu_seqlens, batch_dimensions=attn_dimensions, padded_batch_dimensions=self.padded_batch_dimensions, + enable_chunked_prefill=self.is_chunked_prefill_enabled(), ) - torch.cuda.nvtx.range_pop() - torch.cuda.nvtx.range_pop() + + if self.moe_enable_routing_replay: + if self.using_cuda_graph_this_step(): + self.moe_routing_metadata.enable_static_buffer_recording() + else: + self.moe_routing_metadata.disable_static_buffer_recording() def reset(self) -> None: """Reset entire context. @@ -1395,9 +1440,9 @@ def reset(self) -> None: # Reset chunked prefill state self.chunked_prefill_request_id = -1 - self.has_explicit_chunked_prefill_req = False self.num_prefill_requests = 0 self._using_cuda_graph_this_step = False + self.is_creating_cuda_graphs = False self.padded_batch_dimensions = InferenceBatchDimensions( token_count=0, prefill_req_count=0, decode_req_count=0 ) @@ -1440,21 +1485,13 @@ def last_token_logits(self, logits: Tensor) -> Tensor: # Last token logits. 
logits = logits.squeeze(0) - query_lengths_slice = self.request_query_lengths[self.paused_request_count : self.total_request_count] - last_token_idxs = torch.cumsum(query_lengths_slice, dim=0) - 1 - - # Debug check for OOB - max_idx = last_token_idxs.max().item() if last_token_idxs.numel() > 0 else -1 - if max_idx >= logits.shape[0]: - print(f"OOB ERROR: max_idx={max_idx}, logits_dim={logits.shape[0]}") - print(f"query_lengths={query_lengths_slice}") - print(f"paused={self.paused_request_count}, total={self.total_request_count}") - print(f"active_token_count={self.active_token_count}, padded={self.padded_active_token_count}") - print(f"Dummy forward history (last 5 steps):") - for i, h in enumerate(self._dummy_forward_history): - print(f" Step -{len(self._dummy_forward_history)-i}: {h}") - raise RuntimeError(f"last_token_logits OOB: max_idx={max_idx} >= logits_dim={logits.shape[0]}") - + last_token_idxs = ( + torch.cumsum( + self.request_query_lengths[self.paused_request_count : self.total_request_count], + dim=0, + ) + - 1 + ) last_token_logits = logits[last_token_idxs, :] return last_token_logits @@ -1463,9 +1500,12 @@ def check_availability(self, req: DynamicInferenceRequest) -> Tuple[bool, bool, """ Check if the request can be added to the context. 
""" + # Note that for hybrid models checking the total request count is sufficient + # because we allocate a single set of Mamba state tensors for each request request_can_be_added = ( self.total_request_count < self.max_requests and self.paused_request_count == 0 ) + request_tokens_can_be_added = ( self.active_token_count + req.remaining_prompt_length <= self.max_tokens ) @@ -1671,6 +1711,12 @@ def get_index_of_chunked_prefill_request(self) -> int: """ return torch.where(self.request_ids == self.chunked_prefill_request_id)[0][0] + def is_chunked_prefill_enabled(self) -> bool: + """Returns whether chunked prefill is enabled.""" + if self.is_hybrid_model: + return self.enable_chunked_prefill and not self.is_creating_cuda_graphs + return self.enable_chunked_prefill + def release_memory_blocks_from_request_indexes(self, request_indexes) -> None: """Release memory blocks used by the given request idxs. @@ -1931,7 +1977,6 @@ def update_requests(self, active_requests_mask: Tensor, new_tokens: Tensor) -> T active_requests_mask[-1] = ( 1 # must keep this, next iteration will add a new chunk to it ) - self.has_explicit_chunked_prefill_req = False active_request_count = (active_requests_mask == 1).sum().item() finished_request_count = (active_requests_mask == 0).sum().item() diff --git a/megatron/core/inference/engines/async_zmq_communicator.py b/megatron/core/inference/engines/async_zmq_communicator.py index 155cb6d002f..7076bb283bd 100644 --- a/megatron/core/inference/engines/async_zmq_communicator.py +++ b/megatron/core/inference/engines/async_zmq_communicator.py @@ -85,7 +85,7 @@ async def all_reduce_max(self, local_val: int) -> int: msg = self.gather_sock.recv(flags=zmq.NOBLOCK) values.append(struct.unpack('!i', msg)[0]) except zmq.Again: - await asyncio.sleep(0.0001) # Yield to event loop + await asyncio.sleep(0.001) # Yield to event loop max_val = max(values) self.bcast_sock.send(struct.pack('!i', max_val)) @@ -100,7 +100,7 @@ async def all_reduce_max(self, local_val: 
int) -> int: msg = self.bcast_sock.recv(flags=zmq.NOBLOCK) return struct.unpack('!i', msg)[0] except zmq.Again: - await asyncio.sleep(0.0001) # Yield to event loop + await asyncio.sleep(0.001) # Yield to event loop def close(self): """ diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 8eceac104ba..f56e5b1c761 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -25,12 +25,14 @@ AbstractModelInferenceWrapper, ) from megatron.core.inference.sampling_params import SamplingParams -from megatron.core.inference.utils import get_attention_mask, set_decode_expert_padding, set_is_cuda_graphed_iteration_for_ep_inference +from megatron.core.inference.utils import get_attention_mask, set_decode_expert_padding from megatron.core.models.multimodal.llava_model import LLaVAModel +from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.moe.moe_layer import BaseMoELayer +from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction from megatron.core.transformer.utils import set_model_to_sequence_parallel -from megatron.core.utils import get_asyncio_loop, get_model_config, unwrap_model +from megatron.core.utils import get_asyncio_loop, get_model_config, get_pg_size, unwrap_model try: import transformer_engine as te # pylint: disable=unused-import @@ -494,7 +496,7 @@ def unpad_input_prompt_tokens( def _dynamic_step_context_init( self, construct_graph_dimensions: Optional[InferenceBatchDimensions] = None, - ep_dummy_batch_dimensions: Optional[InferenceBatchDimensions] = None, + is_dummy_forward: bool = False, ): """Initializes the inference context for dynamic batching. 
@@ -507,8 +509,6 @@ def _dynamic_step_context_init( input_ids (Tensor): The active input IDs. position_ids (Tensor): The active position IDs. """ - is_dummy_forward = ep_dummy_batch_dimensions is not None - context = self.inference_wrapped_model.inference_context active_request_slice = slice(context.paused_request_count, context.total_request_count) @@ -517,8 +517,7 @@ def _dynamic_step_context_init( model_config = get_model_config(unwrapped_model) # Initialize attention state. - context.initialize_attention_state(construct_graph_dimensions=construct_graph_dimensions, - ep_dummy_batch_dimensions=ep_dummy_batch_dimensions) + context.initialize_attention_state(construct_graph_dimensions=construct_graph_dimensions) # If using symmetric kernels and we are using using nccl # for prefill turn off symmetric kernels @@ -528,27 +527,17 @@ def _dynamic_step_context_init( moe_pad_experts_for_cuda_graph_inference = ( self.model_config.moe_pad_experts_for_cuda_graph_inference ) - is_inference_optimized = self.model_config.transformer_impl == "inference_optimized" - if is_inference_optimized: - assert not moe_pad_experts_for_cuda_graph_inference, ( - "moe_pad_experts_for_cuda_graph_inference cannot be True when " - "transformer_impl is 'inference_optimized'" - ) if moe_pad_experts_for_cuda_graph_inference: if context.using_cuda_graph_this_step(): capacity_factor = model_config.num_moe_experts / model_config.moe_router_topk set_decode_expert_padding(unwrapped_model, True, capacity_factor=capacity_factor) else: set_decode_expert_padding(unwrapped_model, False) - - if is_inference_optimized and model_config.expert_model_parallel_size > 1: - set_is_cuda_graphed_iteration_for_ep_inference(unwrapped_model, context.using_cuda_graph_this_step()) # initialize symmetric memory if needed if model_config.transformer_impl == "inference_optimized": context.maybe_initialize_symmetric_memory() - if nccl_all_reduce_for_prefill and symmetric_ar_type is not None: if context.is_decode_only(): # 
Turn on symmetric all reduce when in decode mode @@ -688,6 +677,66 @@ def _dynamic_step_log_probs_bookkeeping(self) -> Tuple[bool, bool]: return return_log_probs.any(), top_n_log_probs.any() + def _router_record_bookkeeping(self) -> Optional[Dict[int, Tensor]]: + """Collect and map routing indices per request for MoE router recording. + + This method retrieves recorded routing decisions and maps them to individual + requests using the context's request_ids and query_lengths. Uses the context's + routing_metadata when available (which handles CUDA graph static buffers automatically). + Must be called while context attributes are still valid (before request transitions). + + Returns: + Optional[Dict[int, Tensor]]: A dictionary mapping request_id to a tensor of + shape [num_tokens, num_layers, topk]. Returns None if routing replay is + disabled or no routing data was recorded. + """ + config = self.inference_wrapped_model.model.config + if not config.moe_enable_routing_replay: + return None + + # Get routing indices - use routing_metadata if available (handles CUDA graph static buffers) + context = self.inference_wrapped_model.inference_context + if context.moe_routing_metadata is None: + return None + + stacked_routing = context.moe_routing_metadata.get_routing_indices() + + if stacked_routing is None: + return None + + # Get active request info from context + active_request_slice = slice(context.paused_request_count, context.total_request_count) + active_request_ids = context.request_ids[active_request_slice].tolist() + active_query_lengths = context.request_query_lengths[active_request_slice].tolist() + active_token_count = context.active_token_count + + # Get TP group for all-gather if using sequence parallelism + # With sequence parallelism, each TP rank only sees a portion of the tokens, + # so we need to gather routing indices across all TP ranks. 
+ tp_group = self.inference_wrapped_model.tp_group + tp_size = get_pg_size(tp_group) + + # All-gather across TP group if using sequence parallelism (tp_size > 1) + if tp_size > 1 and get_model_config(self.inference_wrapped_model.model).sequence_parallel: + # gather_from_sequence_parallel_region gathers along dim 0 + # [local_token_count, num_layers, topk] -> [global_token_count, num_layers, topk] + stacked_routing = gather_from_sequence_parallel_region(stacked_routing, group=tp_group) + + # Slice to real tokens (remove CUDA padding) + stacked_routing = stacked_routing[:active_token_count] + + # Split by request along token dimension + # stacked_routing has shape [active_token_count, num_layers, topk] + routing_splits = stacked_routing.split(active_query_lengths, dim=0) + + # Map to request IDs + routing_indices_per_request = {} + for req_id, routing_split in zip(active_request_ids, routing_splits): + # routing_split has shape [num_tokens_for_request, num_layers, topk] + routing_indices_per_request[req_id] = routing_split + + return routing_indices_per_request + def _dynamic_step_calculate_log_probs(self, logits: Tensor) -> Optional[Tensor]: """Calculate log probs from logits.""" context = self.inference_wrapped_model.inference_context @@ -786,20 +835,10 @@ def _dynamic_step_calculate_top_n_logprobs( def dummy_forward(self): """Perform a dummy forward pass. This is used in expert model parallelism on ranks that do not have any real requests.""" - + context = self.inference_wrapped_model.inference_context - # no requests should exist in the system - # we will only have padding tokens - # and dummy block idxes. 
- # context.reset() # if no cuda graphs, directly use dummy forward if not context.cuda_graph_batch_dimensions_list: - # initialize symmetric memory if needed - unwrapped_model = unwrap_model(self.inference_wrapped_model.model) - model_config = get_model_config(unwrapped_model) - if model_config.transformer_impl == "inference_optimized": - context.maybe_initialize_symmetric_memory() - return self.inference_wrapped_model.dummy_forward() # attempt to use cuda-graph if possible @@ -807,7 +846,8 @@ def dummy_forward(self): # a dummy cuda graph. input_ids, position_ids = self._dynamic_step_context_init( # try to use the smallest cuda-graph config for dummy forward - ep_dummy_batch_dimensions=min(context.cuda_graph_batch_dimensions_list) + construct_graph_dimensions=min(context.cuda_graph_batch_dimensions_list), + is_dummy_forward=True, ) # _dynamic_step_context_init tries to find a cuda-graph that is compatible @@ -911,8 +951,16 @@ async def async_generate_output_tokens_dynamic_batch( context.padded_active_request_count if context.is_decode_only() else None ) + # Enable routing recording before forward pass if routing replay is enabled + config = self.inference_wrapped_model.model.config + if config.moe_enable_routing_replay: + RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD) + logits = self._dynamic_step_forward_logits(input_ids, position_ids) + # Collect routing indices per request (must be done before context transitions) + routing_indices_per_request = self._router_record_bookkeeping() + # This is the best place to yield control back to event loop. # At this point we have enqueued FW pass GPU kernels asynchronously. # While they are running, we can do other useful CPU work. 
@@ -944,6 +992,7 @@ async def async_generate_output_tokens_dynamic_batch( "sample": self._sampled_tokens_cuda[:active_request_count], "log_probs": log_probs, "top_n_logprobs": top_n_logprobs, + "routing_indices_per_request": routing_indices_per_request, "cuda_graph_request_count": cuda_graph_request_count, } ret.update(request_bookkeeping) @@ -1157,7 +1206,6 @@ def generate_all_output_tokens_static_batch( moe_pad_experts_for_cuda_graph_inference = ( self.model_config.moe_pad_experts_for_cuda_graph_inference ) - if moe_pad_experts_for_cuda_graph_inference: set_decode_expert_padding(unwrapped_model, False) diff --git a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/completions.py b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/completions.py index 97a509555a3..fdab97eb7b5 100644 --- a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/completions.py +++ b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/completions.py @@ -110,22 +110,25 @@ async def completions(): ) tasks.append(client.add_request(prompt_tokens, per_req_params)) - start_time = time.perf_counter() + if current_app.config['verbose']: + start_time = time.perf_counter() + try: batch_results = await asyncio.gather(*tasks) except Exception as e: return f"Error during inference: {e}", 500 - logger.info( - f"Batch of {len(tasks)} requests processed in {time.perf_counter() - start_time:.2f}s" - ) + if current_app.config['verbose']: + logger.info( + f"Batch of {len(tasks)} requests processed in " + f"{time.perf_counter() - start_time:.2f}s" + ) # --- 4. 
Format Response (matching old_completions.py) --- choices = [] request_idx = 0 for record in batch_results: - # for result in record.requests: result = record.merge() full_text = result.generated_text or "" text_output = (prompts_as_strings[request_idx] + full_text) if echo else full_text @@ -148,9 +151,7 @@ async def completions(): list(result.generated_tokens) if result.generated_tokens else [] ) generated_log_probs = getattr(result, 'generated_log_probs', None) or [] - generated_top_n_logprobs = ( - getattr(result, 'generated_top_n_logprobs', None) or [] - ) + generated_top_n_logprobs = getattr(result, 'generated_top_n_logprobs', None) or [] if echo: # When echo=True, include prompt tokens and their logprobs @@ -167,9 +168,7 @@ async def completions(): top_logprobs = None if prompt_top_n_logprobs or generated_top_n_logprobs: top_logprobs = ( - [None] - + list(prompt_top_n_logprobs) - + list(generated_top_n_logprobs) + [None] + list(prompt_top_n_logprobs) + list(generated_top_n_logprobs) ) # Calculate text_offset: cumulative character positions starting from 0 @@ -204,27 +203,18 @@ async def completions(): "top_logprobs": top_logprobs, } - choices.append( - {"index": request_idx, "text": text_output, "logprobs": logprobs_data} - ) + choices.append({"index": request_idx, "text": text_output, "logprobs": logprobs_data}) + if result.routing_indices is not None: + choices[-1]["moe_topk_indices"] = result.routing_indices.tolist() + prompt_length = len(result.prompt_tokens) if result.prompt_tokens is not None else 0 + if prompt_length: + choices[-1]["prompt_moe_topk_indices"] = result.routing_indices[ + :prompt_length + ].tolist() + request_idx += 1 - prompt_tokens_total = sum(len(p) for p in prompts_as_tokens) - completion_tokens_total = sum( - len(result.generated_tokens) - for record in batch_results - for result in record.requests - if result.generated_tokens is not None - ) - - return jsonify({ - "choices": choices, - "usage": { - "prompt_tokens": 
prompt_tokens_total, - "completion_tokens": completion_tokens_total, - "total_tokens": prompt_tokens_total + completion_tokens_total, - }, - }) + return jsonify({"choices": choices}) except ImportError as e: logger.warning(f"Could not import flask: {e}") diff --git a/megatron/core/ssm/mamba_block.py b/megatron/core/ssm/mamba_block.py index 6ff7001d63c..ef67983d4cf 100644 --- a/megatron/core/ssm/mamba_block.py +++ b/megatron/core/ssm/mamba_block.py @@ -16,6 +16,7 @@ from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding from megatron.core.enums import Fp8Recipe from megatron.core.extensions.transformer_engine import TENorm +from megatron.core.fp4_utils import get_fp4_context from megatron.core.fp8_utils import get_fp8_context from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.packed_seq_params import PackedSeqParams @@ -25,7 +26,7 @@ from megatron.core.transformer import TransformerConfig from megatron.core.transformer.enums import CudaGraphScope from megatron.core.transformer.identity_op import IdentityOp -from megatron.core.transformer.module import MegatronModule, GraphableMegatronModule +from megatron.core.transformer.module import GraphableMegatronModule, MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_layer import TransformerLayer from megatron.core.transformer.utils import sharded_state_dict_default @@ -42,6 +43,7 @@ class MambaStackSubmodules: attention_layer: Union[ModuleSpec, type] = IdentityOp mlp_layer: Union[ModuleSpec, type] = IdentityOp moe_layer: Union[ModuleSpec, type] = IdentityOp + mtp_block_spec: Optional[ModuleSpec] = None class MambaStack(GraphableMegatronModule, MegatronModule): @@ -85,12 +87,14 @@ def __init__( device=None, dtype=None, pg_collection: ProcessGroupCollection = None, + is_mtp_layer: bool = False, ) -> None: super().__init__(config=config) self.residual_in_fp32 = residual_in_fp32 
self.pre_process = pre_process self.post_layer_norm = post_layer_norm self.post_process = post_process + self.is_mtp_layer = is_mtp_layer assert pg_collection is not None, "pg_collection must be provided for MambaStack" @@ -103,24 +107,41 @@ def __init__( self.hybrid_attention_ratio = hybrid_attention_ratio self.hybrid_mlp_ratio = hybrid_mlp_ratio self.hybrid_override_pattern = hybrid_override_pattern + self.pg_collection = pg_collection + + # For MTP layers, always use pattern length (config.num_layers is for main decoder) + if self.is_mtp_layer: + num_layers_for_allocation = len(self.hybrid_override_pattern) + else: + num_layers_for_allocation = ( + self.config.num_layers + if self.config.num_layers is not None + else len(self.hybrid_override_pattern) + ) self.layer_type_list = allocate_layers( - self.config.num_layers, + num_layers_for_allocation, self.hybrid_attention_ratio, self.hybrid_mlp_ratio, self.hybrid_override_pattern, + silent=self.is_mtp_layer, ) pp_layer_offset = 0 - if self.pp_group.size() > 1: + if self.pp_group.size() > 1 and not self.is_mtp_layer: pp_layer_offset, self.layer_type_list = self._select_layers_for_pipeline_parallel( self.layer_type_list ) - + # Build main decoder layers using shared layer builder self.layers = nn.ModuleList() for i, layer_type in enumerate(self.layer_type_list): - fp8_init_context = get_fp8_context(self.config, i + pp_layer_offset, is_init=True) - with fp8_init_context: + if self.config.fp8: + quant_init_context = get_fp8_context(self.config, i + pp_layer_offset, is_init=True) + elif self.config.fp4: + quant_init_context = get_fp4_context(self.config, i + pp_layer_offset, is_init=True) + else: + quant_init_context = nullcontext() + with quant_init_context: if layer_type == LayerSymbols.MAMBA: layer = build_module( submodules.mamba_layer, @@ -137,9 +158,10 @@ def __init__( config=self.config, layer_number=i + 1, pg_collection=pg_collection, + is_mtp_layer=is_mtp_layer, ) elif layer_type == LayerSymbols.MLP: - # 
Transformer layers apply their own pp_layer_offset + # MLP layers apply their own pp_layer_offset layer = build_module( submodules.mlp_layer, config=self.config, @@ -147,7 +169,7 @@ def __init__( pg_collection=pg_collection, ) elif layer_type == LayerSymbols.MOE: - # Transformer layers apply their own pp_layer_offset + # MoE layers apply their own pp_layer_offset layer = build_module( submodules.moe_layer, config=self.config, @@ -170,15 +192,53 @@ def __init__( ) def _select_layers_for_pipeline_parallel(self, layer_type_list): - num_layers_per_pipeline_rank = self.config.num_layers // self.pp_group.size() - assert self.config.virtual_pipeline_model_parallel_size is None, ( "The Mamba hybrid model does not currently support " "virtual/interleaved pipeline parallelism" ) - offset = self.pp_group.rank() * num_layers_per_pipeline_rank - selected_list = layer_type_list[offset : offset + num_layers_per_pipeline_rank] + pp_rank = self.pp_group.rank() + pp_size = self.pp_group.size() + + num_layers_in_first = self.config.num_layers_in_first_pipeline_stage + num_layers_in_last = self.config.num_layers_in_last_pipeline_stage + + if num_layers_in_first is not None or num_layers_in_last is not None: + # Uneven pipeline parallelism: mirror the logic in + # get_transformer_layer_offset so that MambaStack and + # TransformerLayer agree on layer placement. 
+ first = 0 if num_layers_in_first is None else num_layers_in_first + last = 0 if num_layers_in_last is None else num_layers_in_last + middle_num_layers = self.config.num_layers - first - last + + middle_pipeline_stages = pp_size - sum( + 1 for x in (num_layers_in_first, num_layers_in_last) if x is not None + ) + + if middle_pipeline_stages > 0: + layers_per_middle = middle_num_layers // middle_pipeline_stages + else: + layers_per_middle = 0 + + is_first_stage = num_layers_in_first is not None and pp_rank == 0 + is_last_stage = num_layers_in_last is not None and pp_rank == pp_size - 1 + + if is_first_stage: + offset = 0 + num_layers_this_rank = first + elif is_last_stage: + offset = self.config.num_layers - last + num_layers_this_rank = last + else: + middle_rank = pp_rank if num_layers_in_first is None else pp_rank - 1 + offset = middle_rank * layers_per_middle + first + num_layers_this_rank = layers_per_middle + else: + num_layers_per_pipeline_rank = self.config.num_layers // pp_size + offset = pp_rank * num_layers_per_pipeline_rank + num_layers_this_rank = num_layers_per_pipeline_rank + + selected_list = layer_type_list[offset : offset + num_layers_this_rank] return offset, selected_list @@ -206,14 +266,15 @@ def _should_call_local_cudagraph(self, *args, **kwargs): """ Check if we should call the local cudagraph path. 
""" - if not self.training and ( - hasattr(self, 'cudagraph_manager') + if ( + not self.training + and hasattr(self, 'cudagraph_manager') and kwargs['attention_mask'] is None and ( kwargs.get('inference_context') is not None or kwargs.get('inference_params') is not None ) - and CudaGraphScope.full_iteration in self.config.cuda_graph_scope + and CudaGraphScope.full_iteration_inference in self.config.cuda_graph_scope ): if kwargs['inference_context'].is_static_batching(): using_cuda_graph = kwargs['inference_context'].is_decode_only() @@ -308,16 +369,29 @@ def forward( # control which layer will be fp8 or bf16 use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed + use_fp4_context = self.config.fp4 is not None outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext() + if use_inner_fp8_context: + + def get_inner_quant_context(config, layer_number): + return get_fp8_context(config, layer_number) + + elif use_fp4_context: + + def get_inner_quant_context(config, layer_number): + return get_fp4_context(config, layer_number) + + else: + + def get_inner_quant_context(config, layer_number): + return nullcontext() + with outer_fp8_context: for layer in self.layers: - inner_fp8_context = ( - get_fp8_context(self.config, layer.layer_number - 1) - if use_inner_fp8_context - else nullcontext() - ) - with inner_fp8_context: + # Layers have 1-indexed layer numbers attribute. 
+ inner_quant_context = get_inner_quant_context(self.config, layer.layer_number - 1) + with inner_quant_context: if isinstance(layer, TransformerLayer): hidden_states, _ = layer( hidden_states=hidden_states, @@ -328,7 +402,7 @@ def forward( packed_seq_params=packed_seq_params, padding_mask=padding_mask, ) - else: # MambaLayer + else: # MambaLayer, Expert, or MLP hidden_states = layer( hidden_states=hidden_states, attention_mask=attention_mask, @@ -348,7 +422,7 @@ def forward( # Ensure that the tensor passed between pipeline parallel stages is # viewless. See related notes in TransformerBlock and TransformerLayer - output = make_viewless_tensor( + hidden_states = make_viewless_tensor( inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True ) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 1247eb2f5f0..019c6fef396 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -34,6 +34,7 @@ from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.torch_norm import LayerNormBuilder from megatron.core.typed_torch import apply_module, not_none from megatron.core.utils import ( deprecate_inference_params, @@ -219,8 +220,8 @@ class SelfAttentionSubmodules: linear_qkv: LinearQkvBuilder core_attention: CoreAttentionBuilder linear_proj: Union[ModuleSpec, type] = None - q_layernorm: Union[ModuleSpec, type] = None - k_layernorm: Union[ModuleSpec, type] = None + q_layernorm: LayerNormBuilder | None = None + k_layernorm: LayerNormBuilder | None = None @dataclass @@ -1278,8 +1279,7 @@ def __init__( ) if submodules.q_layernorm is not None: - self.q_layernorm = build_module( - submodules.q_layernorm, + self.q_layernorm = submodules.q_layernorm( hidden_size=self.hidden_size_per_attention_head, 
config=self.config, eps=self.config.layernorm_epsilon, @@ -1288,8 +1288,7 @@ def __init__( self.q_layernorm = None if submodules.k_layernorm is not None: - self.k_layernorm = build_module( - submodules.k_layernorm, + self.k_layernorm = submodules.k_layernorm( hidden_size=self.hidden_size_per_attention_head, config=self.config, eps=self.config.layernorm_epsilon, @@ -1475,10 +1474,10 @@ def get_query_key_value_tensors( query = query[:, :, idx * size : (idx + 1) * size, :] if self.q_layernorm is not None: - query = self.q_layernorm(query) + query = apply_module(self.q_layernorm)(query) if self.k_layernorm is not None: - key = self.k_layernorm(key) + key = apply_module(self.k_layernorm)(key) if self.config.test_mode: self.run_realtime_tests() @@ -1697,4 +1696,4 @@ def get_query_key_value_tensors( ) query = query.view(*new_tensor_shape) - return query, key, value \ No newline at end of file + return query, key, value diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 7870ed2fb6f..48a023e6ddc 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -242,7 +242,7 @@ def _check_supported_type(meta): } assert meta.type in _SUPPORTED_TYPES or is_dataclass( meta.value - ), f"Cudagraphs recieved an arg of type {meta.type} which is not supported." + ), f"Cudagraphs received an arg of type {meta.type} which is not supported." 
def _determine_if_first_last_layer_of_this_vp_chunk(base_module): @@ -1228,12 +1228,6 @@ def _replace_with_meta(arg): return out[0] return tuple(out) - # def _get_cached_parameters_set_for_inference(self): - # """Return cached parameters for inference mode.""" - # if not hasattr(self, '_cached_parameters_set_for_inference'): - # self._cached_parameters_set_for_inference = tuple(self.parameters()) - # return self._cached_parameters_set_for_inference - def replay_graph_capture(self, is_first_microbatch, args, kwargs): """Replay the fwd cuda graph with autograd.""" @@ -1246,10 +1240,8 @@ def replay_graph_capture(self, is_first_microbatch, args, kwargs): raise AssertionError(error_msg) inp_tensors = self.get_tensors(args, kwargs, check_types=False) - is_inference_mode = 'inference_context' in kwargs.keys() and kwargs['inference_context'] - if self.grad_enabled or not is_inference_mode: + if self.grad_enabled: func_args = inp_tensors + self.params_to_backprop - assert kwargs["inference_context"] is None else: func_args = inp_tensors @@ -1562,9 +1554,7 @@ def __call__(self, megatron_module, args, kwargs): else: if is_inference_mode: # Inference generation mode creates graphs immediately - runner = self.get_cudagraph_runner(megatron_module, args, kwargs, True) - if not runner.fwd_graph_recorded: # Reuse graph input-output buffers for inference @@ -1595,12 +1585,12 @@ def __call__(self, megatron_module, args, kwargs): ) runner.fwd_graph_recorded = True runner.cudagraph_created = True + runner = runner.eval() # Record this to the global execution record _CudagraphGlobalRecord.cudagraph_inference_record.append( (runner, "fwd", args, kwargs) ) - runner = runner.eval() # Now replay the graph out = runner.replay_graph_capture(self.is_first_microbatch, args, kwargs) @@ -1748,7 +1738,7 @@ def __init__(self, model, config, seq_length, micro_batch_size, optimizers=[]): callables.append(layer) callables_is_mtp.append(False) for layer_number in range(num_mtp_layers): - layer = 
chunk_with_decoder.mtp.layers[layer_number].transformer_layer + layer = chunk_with_decoder.mtp.layers[layer_number].mtp_model_layer if _layer_is_graphable(layer, config): num_graphable_layers += 1 callables.append(layer) @@ -1865,7 +1855,7 @@ def _get_layer_static_inputs(layer, chunk_of_the_layer): Get the static inputs for a layer. """ assert layer in chunk_of_the_layer.decoder.layers or any( - layer is mtp_layer.transformer_layer for mtp_layer in chunk_of_the_layer.mtp.layers + layer is mtp_layer.mtp_model_layer for mtp_layer in chunk_of_the_layer.mtp.layers ), "Layer is not in the chunk" def get_rotary_pos_emb(transformer_module, transformer_input): diff --git a/megatron/core/transformer/moe/gpu_resident_ops.py b/megatron/core/transformer/moe/gpu_resident_ops.py deleted file mode 100644 index a137a47c7e0..00000000000 --- a/megatron/core/transformer/moe/gpu_resident_ops.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - -""" -GPU-resident operations for CUDA-graph compatible MoE inference. - -This module provides GPU-resident implementations of AlltoAll and GroupedGEMM -operations that accept device tensors for split sizes, eliminating host -synchronization points required for CUDA graph compatibility. -""" - -from typing import Optional - -import torch - - -def gpu_resident_all_to_all( - process_group, - input_tensor: torch.Tensor, - output_split_sizes: torch.Tensor, - input_split_sizes: torch.Tensor, -) -> torch.Tensor: - """ - GPU-resident AlltoAll that accepts device tensors for split sizes. - - This function eliminates the host synchronization bottleneck present in - the standard torch.distributed.all_to_all by accepting split sizes as - GPU tensors instead of CPU lists. - - Args: - process_group: The process group for communication - input_tensor: [sum(input_split_sizes), ...] 
tensor to send - output_split_sizes: [world_size] GPU tensor - number of elements to receive from each rank - input_split_sizes: [world_size] GPU tensor - number of elements to send to each rank - - Returns: - output_tensor: [sum(output_split_sizes), ...] received tensor - - Example: - >>> # Instead of CPU lists: - >>> # output_splits = [100, 200, 150] # CPU list - >>> # input_splits = [80, 120, 200] # CPU list - >>> # output = all_to_all(group, input, output_splits, input_splits) - >>> - >>> # Use GPU tensors: - >>> output_splits = torch.tensor([100, 200, 150], device='cuda') # GPU - >>> input_splits = torch.tensor([80, 120, 200], device='cuda') # GPU - >>> output = gpu_resident_all_to_all(group, input, output_splits, input_splits) - - Implementation notes: - - This is a placeholder for your GPU-resident AlltoAll implementation - - The actual implementation should avoid any .item(), .tolist(), or .cpu() calls - - Split sizes must remain on GPU throughout the operation - - Should support CUDA graph capture - """ - # TODO: Replace with actual GPU-resident AlltoAll implementation - # For now, this is a placeholder showing the expected interface - raise NotImplementedError( - "gpu_resident_all_to_all requires a custom implementation. " - "This placeholder shows the expected API: accepts GPU tensors for split sizes." - ) - - -def gpu_resident_grouped_gemm( - input: torch.Tensor, - weights: torch.Tensor, - tokens_per_expert: torch.Tensor, - use_fp8: bool = False, -) -> torch.Tensor: - """ - GPU-resident GroupedGEMM that accepts device tensor for expert splits. - - This function provides a CUDA-graph compatible grouped GEMM by accepting - tokens_per_expert as a GPU tensor and computing offsets on-device. 
- - Args: - input: [total_tokens, K] input tensor - weights: [num_experts, K, N] or [num_experts*K, N] weight tensor - tokens_per_expert: [num_experts] GPU tensor - token count per expert - use_fp8: Whether to use FP8 computation (if available) - - Returns: - output: [total_tokens, N] output tensor - - Example: - >>> # Instead of CPU tokens_per_expert: - >>> # tokens_per_expert_cpu = tokens_per_expert.cpu() # Sync! - >>> # offs = tokens_per_expert_cpu.cumsum(0).cuda() # Another sync! - >>> # output = torch._grouped_mm(input, weights, offs=offs) - >>> - >>> # Use GPU-resident version: - >>> output = gpu_resident_grouped_gemm(input, weights, tokens_per_expert) - - Implementation notes: - - This is a placeholder for your GPU-resident GroupedGEMM implementation - - Should compute cumsum(tokens_per_expert) on GPU without host sync - - Must keep all tensors GPU-resident throughout - - Should support CUDA graph capture - - Can wrap torch._grouped_mm or use custom kernel - """ - # TODO: Replace with actual GPU-resident GroupedGEMM implementation - # For now, this is a placeholder showing the expected interface - - # Example of what the implementation might look like: - # offs = tokens_per_expert.cumsum(0).to(torch.int32) # No .cuda() needed! - # return torch._grouped_mm(input, weights, offs=offs) - - raise NotImplementedError( - "gpu_resident_grouped_gemm requires a custom implementation. " - "This placeholder shows the expected API: accepts GPU tensor for tokens_per_expert." 
- ) From 4d654c2d00ab5c319d88e51896b9c1144027febd Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 24 Feb 2026 02:51:23 -0800 Subject: [PATCH 40/92] minor --- .../text_generation_controller.py | 18 +++++++++++++++++- megatron/core/transformer/moe/router.py | 2 +- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index c7a64c93e9c..2a9c4702838 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -25,7 +25,7 @@ AbstractModelInferenceWrapper, ) from megatron.core.inference.sampling_params import SamplingParams -from megatron.core.inference.utils import get_attention_mask, set_decode_expert_padding +from megatron.core.inference.utils import get_attention_mask, set_decode_expert_padding, set_is_cuda_graphed_iteration_for_ep_inference from megatron.core.models.multimodal.llava_model import LLaVAModel from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer.enums import CudaGraphScope @@ -530,6 +530,14 @@ def _dynamic_step_context_init( moe_pad_experts_for_cuda_graph_inference = ( self.model_config.moe_pad_experts_for_cuda_graph_inference ) + + is_inference_optimized = self.model_config.transformer_impl == "inference_optimized" + if is_inference_optimized: + assert not moe_pad_experts_for_cuda_graph_inference, ( + "moe_pad_experts_for_cuda_graph_inference cannot be True when " + "transformer_impl is 'inference_optimized'" + ) + if moe_pad_experts_for_cuda_graph_inference: if context.using_cuda_graph_this_step(): capacity_factor = model_config.num_moe_experts / model_config.moe_router_topk @@ -537,6 +545,9 @@ def _dynamic_step_context_init( else: set_decode_expert_padding(unwrapped_model, False) + 
if is_inference_optimized and model_config.expert_model_parallel_size > 1: + set_is_cuda_graphed_iteration_for_ep_inference(unwrapped_model, context.using_cuda_graph_this_step()) + # initialize symmetric memory if needed if model_config.transformer_impl == "inference_optimized": context.maybe_initialize_symmetric_memory() @@ -842,6 +853,11 @@ def dummy_forward(self): context = self.inference_wrapped_model.inference_context # if no cuda graphs, directly use dummy forward if not context.cuda_graph_batch_dimensions_list: + # initialize symmetric memory if needed + unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + model_config = get_model_config(unwrapped_model) + if model_config.transformer_impl == "inference_optimized": + context.maybe_initialize_symmetric_memory() return self.inference_wrapped_model.dummy_forward() # attempt to use cuda-graph if possible diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index 04a32d49667..c02c3b4b130 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -722,7 +722,7 @@ class InferenceTopKRouter(TopKRouter): """ def __init__( - self, config: TransformerConfig, pg_collection: Optional[ProcessGroupCollection] = None + self, config: TransformerConfig, pg_collection: Optional[ProcessGroupCollection] = None, is_mtp_layer: bool = False, ) -> None: """Initialize the specialized inference top-k router. 
From afd2ad8739225fd96c08b39b3f059a15cb72360e Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 24 Feb 2026 16:06:53 -0800 Subject: [PATCH 41/92] remove code from the dummy ep PR --- .../core/inference/batch_dimensions_utils.py | 23 +----- .../inference/contexts/dynamic_context.py | 70 +------------------ .../text_generation_controller.py | 10 +-- 3 files changed, 7 insertions(+), 96 deletions(-) diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py index 1303f61c9d2..1a202c35af5 100644 --- a/megatron/core/inference/batch_dimensions_utils.py +++ b/megatron/core/inference/batch_dimensions_utils.py @@ -133,7 +133,6 @@ def adjust_batch_dims_for_expert_parallelism( strict: bool, decode_only_cuda_graphs: bool, explicit_chunked_prefill: bool, - cuda_graph_mixed_prefill_count: int, ep_group: Optional[torch.distributed.ProcessGroup] = None, ) -> Optional["InferenceBatchDimensions"]: """Adjusted cuda graph batch dimensions for expert parallelism. @@ -158,7 +157,6 @@ def adjust_batch_dims_for_expert_parallelism( # all reduce local work across expert model parallel group is_non_decode = local_batch_dims.prefill_req_count > 0 - sync_tensor = torch.tensor( [ local_batch_dims.token_count, @@ -195,22 +193,12 @@ def adjust_batch_dims_for_expert_parallelism( adjusted_decode_req_count = ( int(sync_tensor[3].item()) if strict else local_batch_dims.decode_req_count ) - adjusted_token_count = int(sync_tensor[0].item()) - - # When any EP rank has prefill requests (non-strict mode), elevate - # the token count to be >= the smallest prefill/mixed cuda graph. - # This ensures decode-only ranks don't match a fine-grained decode - # graph while prefill ranks match a coarser mixed graph, which would - # produce inconsistent token counts across EP ranks. 
- if is_any_ep_rank_in_non_decode and not strict: - adjusted_token_count = max(adjusted_token_count, cuda_graph_mixed_prefill_count) adjusted_batch_dim = InferenceBatchDimensions( - token_count=adjusted_token_count, + token_count=int(sync_tensor[0].item()), prefill_req_count=adjusted_prefill_req_count, decode_req_count=adjusted_decode_req_count, ) - return adjusted_batch_dim @@ -372,12 +360,6 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int ): cuda_graph_max_tokens = max_tokens - assert cuda_graph_max_tokens == max_requests, ( - f"cuda_graph_max_tokens ({cuda_graph_max_tokens}) must equal max_requests " - f"({max_requests}). This is required for correctly syncing EP ranks: " - f"prefill and decode graph pools must have the same token count granularity." - ) - if num_cuda_graphs != -1: # if -1, no need to adjust. This will be taken care of in # the _calculate_cuda_graph_token_counts function where we will generate @@ -474,7 +456,6 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int def match_graph_config( real_batch_dim: InferenceBatchDimensions, cuda_graph_batch_dimensions_list: List[InferenceBatchDimensions], - cuda_graph_mixed_prefill_count: int, strict: bool = False, decode_only_cuda_graphs: bool = False, explicit_chunked_prefill: bool = False, @@ -509,7 +490,6 @@ def match_graph_config( decode_only_cuda_graphs=decode_only_cuda_graphs, explicit_chunked_prefill=explicit_chunked_prefill, ep_group=ep_group, - cuda_graph_mixed_prefill_count=cuda_graph_mixed_prefill_count, ) if adjusted_batch_dim is None: @@ -532,5 +512,4 @@ def match_graph_config( return None # then find the best batch dimension best_batch_dim = min(graph_batch_dims_applicable) - return best_batch_dim diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 2dc0fc2efe8..9f7556f1312 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ 
b/megatron/core/inference/contexts/dynamic_context.py @@ -543,7 +543,6 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC ) ) - self.cuda_graph_mixed_prefill_count = inference_config.cuda_graph_mixed_prefill_count self._using_cuda_graph_this_step = False # Deal with chunked prefill self.enable_chunked_prefill = inference_config.enable_chunked_prefill @@ -1261,83 +1260,28 @@ def num_decode_requests(self) -> int: """ return self.total_request_count - self.paused_request_count - self.num_prefill_requests - def add_dummy_requests_for_expert_parallel_step(self) -> None: - """Minimal context setup so an EP rank with no real requests can replay - an already-captured cuda graph without crashing or corrupting memory. - - This is the fast alternative to add_dummy_requests_for_cudagraph_capture - (which goes through the heavyweight add_dummy_requests_parallel path). - - We setup minimal state such the initialize_attention_state and the forward - pass can run without error. - - """ - smallest_cuda_graph_dimensions = min(self.cuda_graph_batch_dimensions_list) - # the smallest cuda graph is decode only. - assert smallest_cuda_graph_dimensions.prefill_req_count == 0 - - N = smallest_cuda_graph_dimensions.decode_req_count - dummy_block_idx = self.block_allocator.dummy_block_idx - - # 1. Request counts and token count (decode-only: 1 token per request). - self.total_request_count = N - self.active_token_count = N - self.num_prefill_requests = 0 - - # 2. Per-request state consumed by mha_metadata.update(). - self.request_query_lengths[0:N].fill_(1) - self.request_kv_length_offsets[0:N].fill_(0) - self.request_to_kv_block_ids[0:N, 0] = dummy_block_idx - - # 3. Token-level state consumed by the triton KV append kernel. - self.token_to_block_idx[0:N] = dummy_block_idx - self.token_to_local_position_within_kv_block[0:N] = 0 - - if self.is_hybrid_model: - # 4. token_to_request_idx: needed by mamba_metadata.update() for hybrid models. 
- self.token_to_request_idx[0:N] = torch.arange( - 0, N, device=self.token_to_request_idx.device, dtype=self.token_to_request_idx.dtype - ) - - # 5. Mamba state: allocate slots for dummy requests. - self.mamba_metadata.request_to_mamba_state_idx[0:N] = ( - self.mamba_metadata.batch_allocate_slots(N) - ) - def initialize_attention_state( - self, - *, - construct_graph_dimensions: Optional[InferenceBatchDimensions] = None, - is_expert_parallel_dummy_cuda_graph_step: bool = False, + self, *, construct_graph_dimensions: Optional[InferenceBatchDimensions] = None ) -> None: """Initialize attention state so that every layer can use it. Args: construct_graph_dimensions (Optional[InferenceBatchDimensions]): The graph config to use for constructing the cuda graphs. - is_expert_parallel_dummy_cuda_graph_step (bool): Whether this is a dummy expert model parallel step. Return: None. """ self.is_creating_cuda_graphs = construct_graph_dimensions is not None - assert not ( - self.is_creating_cuda_graphs and is_expert_parallel_dummy_cuda_graph_step - ), "Dummy expert model parallel steps should not be creating cuda graphs." 
# If in CUDA graph creation mode, add dummy requests for CUDA graph capture - if is_expert_parallel_dummy_cuda_graph_step: - self.add_dummy_requests_for_expert_parallel_step() - else: - if self.is_creating_cuda_graphs: - self.add_dummy_requests_for_cudagraph_capture(construct_graph_dimensions) + if self.is_creating_cuda_graphs: + self.add_dummy_requests_for_cudagraph_capture(construct_graph_dimensions) batch_dimensions = InferenceBatchDimensions( token_count=self.active_token_count, prefill_req_count=self.num_prefill_requests, decode_req_count=self.num_decode_requests, ) - self.batch_dimensions = batch_dimensions - best_graph = CUDAGraphBatchDimensionBuilder.match_graph_config( batch_dimensions, self.cuda_graph_batch_dimensions_list, @@ -1345,17 +1289,9 @@ def initialize_attention_state( decode_only_cuda_graphs=(not self.use_cuda_graphs_for_non_decode_steps), explicit_chunked_prefill=self.is_chunked_prefill_enabled() and self.is_hybrid_model, ep_group=self.expert_model_parallel_group, - cuda_graph_mixed_prefill_count=self.cuda_graph_mixed_prefill_count, ) self._using_cuda_graph_this_step = best_graph is not None - if is_expert_parallel_dummy_cuda_graph_step and not self.using_cuda_graph_this_step(): - # If we are here, this means that CUDAGraphBatchDimensionBuilder.match_graph_config - # could not find a compatible cuda graph for the dummy forward step. - # Now, we need not do the remaining setup. The controller - # will directly call the model forward pass with a single token. 
- return - if self.using_cuda_graph_this_step(): self.padded_batch_dimensions = best_graph else: diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 2a9c4702838..440a9d198b1 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -517,10 +517,7 @@ def _dynamic_step_context_init( model_config = get_model_config(unwrapped_model) # Initialize attention state. - context.initialize_attention_state( - construct_graph_dimensions=construct_graph_dimensions, - is_expert_parallel_dummy_cuda_graph_step=is_dummy_forward, - ) + context.initialize_attention_state(construct_graph_dimensions=construct_graph_dimensions) # If using symmetric kernels and we are using using nccl # for prefill turn off symmetric kernels @@ -530,14 +527,12 @@ def _dynamic_step_context_init( moe_pad_experts_for_cuda_graph_inference = ( self.model_config.moe_pad_experts_for_cuda_graph_inference ) - is_inference_optimized = self.model_config.transformer_impl == "inference_optimized" if is_inference_optimized: assert not moe_pad_experts_for_cuda_graph_inference, ( "moe_pad_experts_for_cuda_graph_inference cannot be True when " "transformer_impl is 'inference_optimized'" ) - if moe_pad_experts_for_cuda_graph_inference: if context.using_cuda_graph_this_step(): capacity_factor = model_config.num_moe_experts / model_config.moe_router_topk @@ -865,7 +860,8 @@ def dummy_forward(self): # a dummy cuda graph. 
input_ids, position_ids = self._dynamic_step_context_init( # try to use the smallest cuda-graph config for dummy forward - is_dummy_forward=True + construct_graph_dimensions=min(context.cuda_graph_batch_dimensions_list), + is_dummy_forward=True, ) # _dynamic_step_context_init tries to find a cuda-graph that is compatible From bf8f546298d8e44d9be289b7fc5f9d1c8955f8d6 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 24 Feb 2026 16:08:51 -0800 Subject: [PATCH 42/92] restore utils.py --- megatron/inference/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/inference/utils.py b/megatron/inference/utils.py index 447d7290acc..92d153755fe 100644 --- a/megatron/inference/utils.py +++ b/megatron/inference/utils.py @@ -303,7 +303,7 @@ def get_inference_config_from_model_and_args(model: MegatronModule, args): track_paused_request_events=args.inference_dynamic_batching_track_paused_request_events, enable_chunked_prefill=args.enable_chunked_prefill, metrics_writer=metrics_writer, - logging_step_interval=args.inference_logging_step_interval + logging_step_interval=args.inference_logging_step_interval, ) From c4091bd6418585d67f33e9e108d09c8932031e41 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 24 Feb 2026 16:13:59 -0800 Subject: [PATCH 43/92] remove torch grouped gemm kernels: we will add them in another PR --- .../core/transformer/moe/inference_kernels.py | 362 ------------------ .../moe/token_dispatcher_inference.py | 5 - 2 files changed, 367 deletions(-) delete mode 100644 megatron/core/transformer/moe/inference_kernels.py diff --git a/megatron/core/transformer/moe/inference_kernels.py b/megatron/core/transformer/moe/inference_kernels.py deleted file mode 100644 index 430482b55a2..00000000000 --- a/megatron/core/transformer/moe/inference_kernels.py +++ /dev/null @@ -1,362 +0,0 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - -""" -Triton kernels for MoE inference optimizations. 
-""" - -import torch -import triton -import triton.language as tl - - -@triton.jit -def shift_and_mark_indices_kernel( - topk_indices_ptr, # Input: [num_tokens, topk] - shifted_indices_ptr, # Output: [num_tokens, topk] - num_tokens: tl.constexpr, - topk: tl.constexpr, - local_start: tl.constexpr, # First local expert index - local_end: tl.constexpr, # Last local expert index - sentinel: tl.constexpr, # num_local_experts (sentinel for invalid) - BLOCK_SIZE: tl.constexpr, -): - """ - Shifts topk indices to local coordinate system and marks invalid indices. - - For each index: - - If index in [local_start, local_end]: shift to 0-based (index - local_start) - - Otherwise: mark as sentinel value - """ - # Each program handles one block of elements - pid = tl.program_id(0) - - # Calculate total elements - num_elements = num_tokens * topk - - # Process BLOCK_SIZE elements per program - offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = offset < num_elements - - # Load indices - indices = tl.load(topk_indices_ptr + offset, mask=mask, other=0) - - # Check if index is in local range - is_valid = (indices >= local_start) & (indices <= local_end) - - # Shift valid indices, mark invalid with sentinel - shifted = tl.where(is_valid, indices - local_start, sentinel) - - # Store result - tl.store(shifted_indices_ptr + offset, shifted, mask=mask) - - -def shift_topk_indices( - topk_indices: torch.Tensor, - local_start: int, - local_end: int, - num_local_experts: int, -) -> torch.Tensor: - """ - Shift topk indices to local coordinate system using Triton kernel. 
- - Args: - topk_indices: [num_tokens, topk] tensor of expert indices - local_start: First local expert global index - local_end: Last local expert global index - num_local_experts: Number of local experts - - Returns: - shifted_indices: [num_tokens, topk] with local indices or sentinel - """ - num_tokens, topk = topk_indices.shape - shifted_indices = torch.empty_like(topk_indices) - - num_elements = num_tokens * topk - BLOCK_SIZE = 1024 - grid = lambda meta: (triton.cdiv(num_elements, meta['BLOCK_SIZE']),) - - shift_and_mark_indices_kernel[grid]( - topk_indices, - shifted_indices, - num_tokens=num_tokens, - topk=topk, - local_start=local_start, - local_end=local_end, - sentinel=num_local_experts, - BLOCK_SIZE=BLOCK_SIZE, - ) - - return shifted_indices - - -@triton.jit -def permute_and_count_kernel( - # Input tensors - hidden_states_ptr, # [num_tokens, hidden_dim] - probs_ptr, # [num_tokens, topk] - expert_assignments_ptr, # [num_tokens * topk] - local expert index per token-k pair - permutation_ptr, # [num_tokens * topk] - argsort result - # Output tensors - permuted_hidden_ptr, # [max_out, hidden_dim] - permuted_probs_ptr, # [max_out] - tokens_per_expert_ptr, # [num_local_experts] - # Scalars - num_tokens: tl.constexpr, - topk: tl.constexpr, - hidden_dim: tl.constexpr, - num_local_experts: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - """ - Permute hidden states and probs according to permutation, count tokens per expert. - - Each program handles one output position. Skips sentinel values (expert == num_local_experts). 
- """ - # Each program handles one output position - pid = tl.program_id(0) - - # Total elements to process - num_elements = num_tokens * topk - - if pid >= num_elements: - return - - # Load the permutation index - where to read from - perm_idx = tl.load(permutation_ptr + pid) - - # Load the expert index for this position - expert_idx = tl.load(expert_assignments_ptr + perm_idx) - - # Skip if this is a sentinel value - if expert_idx >= num_local_experts: - return - - # Compute source token and k indices - # perm_idx tells us position in flattened [num_tokens * topk] array - token_idx = perm_idx // topk - k_idx = perm_idx % topk - - # Copy hidden state: load from [token_idx, :] and store to [pid, :] - for d in range(0, hidden_dim, BLOCK_SIZE): - offset = d + tl.arange(0, BLOCK_SIZE) - mask = offset < hidden_dim - - hidden_val = tl.load( - hidden_states_ptr + token_idx * hidden_dim + offset, - mask=mask, - other=0.0 - ) - tl.store( - permuted_hidden_ptr + pid * hidden_dim + offset, - hidden_val, - mask=mask - ) - - # Copy prob: load from [token_idx, k_idx] - prob_val = tl.load(probs_ptr + token_idx * topk + k_idx) - tl.store(permuted_probs_ptr + pid, prob_val) - - # Atomically increment tokens_per_expert[expert_idx] - tl.atomic_add(tokens_per_expert_ptr + expert_idx, 1) - - -def permute_tokens_and_probs( - hidden_states: torch.Tensor, - probs: torch.Tensor, - expert_assignments: torch.Tensor, - permutation: torch.Tensor, - num_local_experts: int, - max_tokens: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Permute hidden states and probs, count tokens per expert using Triton kernel. 
- - Args: - hidden_states: [num_tokens, hidden_dim] - probs: [num_tokens, topk] - expert_assignments: [num_tokens * topk] local expert index per token-k pair - permutation: [num_tokens * topk] argsort result - num_local_experts: Number of local experts - max_tokens: Maximum output size - - Returns: - permuted_hidden: [max_tokens, hidden_dim] - permuted_probs: [max_tokens] - tokens_per_expert: [num_local_experts] - """ - num_tokens, hidden_dim = hidden_states.shape - topk = probs.size(1) - - # Allocate outputs - permuted_hidden = torch.empty( - (max_tokens, hidden_dim), - dtype=hidden_states.dtype, - device=hidden_states.device - ) - permuted_probs = torch.empty( - max_tokens, - dtype=probs.dtype, - device=probs.device - ) - tokens_per_expert = torch.zeros( - num_local_experts, - dtype=torch.int32, - device=hidden_states.device - ) - - # Launch kernel - one program per output position - num_elements = num_tokens * topk - - # Adapt BLOCK_SIZE to hidden_dim for optimal memory access - # Use next power of 2 for better vectorization - BLOCK_SIZE = triton.next_power_of_2(hidden_dim) - # Cap at reasonable maximum to avoid register pressure - BLOCK_SIZE = min(BLOCK_SIZE, 2048) - - grid = (num_elements,) - - permute_and_count_kernel[grid]( - hidden_states, - probs, - expert_assignments, - permutation, - permuted_hidden, - permuted_probs, - tokens_per_expert, - num_tokens=num_tokens, - topk=topk, - hidden_dim=hidden_dim, - num_local_experts=num_local_experts, - BLOCK_SIZE=BLOCK_SIZE, - ) - - return permuted_hidden, permuted_probs, tokens_per_expert - - -@triton.jit -def unpermute_and_combine_kernel( - # Input tensors - permuted_hidden_ptr, # [max_out, hidden_dim] - expert outputs - permutation_ptr, # [num_tokens * topk] - argsort result (forward permutation) - expert_assignments_ptr, # [num_tokens * topk] - local expert index per token-k pair - # Output tensor - output_ptr, # [num_tokens, hidden_dim] - unpermuted output - # Scalars - num_tokens: tl.constexpr, - topk: 
tl.constexpr, - hidden_dim: tl.constexpr, - num_local_experts: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - """ - Unpermute expert outputs back to original token positions. - - Each program handles one position in the permutation array: - - Loads permutation[pid] to find source flat_pos (token_idx, k_idx) - - Loads expert output from permuted arrays at position pid - - Atomically accumulates output into output[token_idx] - - Note: Probability weighting is handled by the experts (via moe_apply_probs_on_input), - so this kernel only does unpermutation and accumulation. - """ - # Each program handles one permuted position - pid = tl.program_id(0) - - num_elements = num_tokens * topk - if pid >= num_elements: - return - - # Load source position from permutation - flat_pos = tl.load(permutation_ptr + pid) - - # Compute source token index - token_idx = flat_pos // topk - - # Load expert index to check validity - expert_idx = tl.load(expert_assignments_ptr + flat_pos) - - # Skip if sentinel (not a valid local expert) - if expert_idx >= num_local_experts: - return - - # Process each dimension chunk - for d in range(0, hidden_dim, BLOCK_SIZE): - offset = d + tl.arange(0, BLOCK_SIZE) - mask = offset < hidden_dim - - # Load expert output (already weighted by experts if configured) - hidden_val = tl.load( - permuted_hidden_ptr + pid * hidden_dim + offset, - mask=mask, - other=0.0 - ) - - # Atomically accumulate into output[token_idx] - tl.atomic_add(output_ptr + token_idx * hidden_dim + offset, hidden_val, mask=mask) - - -def unpermute_and_combine( - permuted_hidden: torch.Tensor, - expert_assignments: torch.Tensor, - permutation: torch.Tensor, - num_tokens: int, - topk: int, - num_local_experts: int, -) -> torch.Tensor: - """ - Unpermute expert outputs back to original token order. 
- - Args: - permuted_hidden: [max_out, hidden_dim] expert outputs (already weighted by experts) - expert_assignments: [num_tokens * topk] local expert index per token-k pair - permutation: [num_tokens * topk] argsort result from dispatch - num_tokens: Number of original tokens - topk: Number of experts per token - num_local_experts: Number of local experts - - Returns: - output: [num_tokens, hidden_dim] unpermuted output - - Note: The expert outputs should already be weighted by routing probabilities - if moe_apply_probs_on_input is enabled in the config. - """ - hidden_dim = permuted_hidden.size(1) - - # Allocate output (zeroed for atomic accumulation) - output = torch.zeros( - (num_tokens, hidden_dim), - dtype=permuted_hidden.dtype, - device=permuted_hidden.device - ) - - # Adapt BLOCK_SIZE to hidden_dim - BLOCK_SIZE = triton.next_power_of_2(hidden_dim) - BLOCK_SIZE = min(BLOCK_SIZE, 2048) - - # Launch kernel - one program per permuted position (same pattern as permute kernel) - num_elements = num_tokens * topk - grid = (num_elements,) - - unpermute_and_combine_kernel[grid]( - permuted_hidden, - permutation, - expert_assignments, - output, - num_tokens=num_tokens, - topk=topk, - hidden_dim=hidden_dim, - num_local_experts=num_local_experts, - BLOCK_SIZE=BLOCK_SIZE, - ) - - return output - - -def launch_fused_permute_and_probs(*args, **kwargs): - """Placeholder for future fused permute kernel.""" - raise NotImplementedError("launch_fused_permute_and_probs not yet implemented") - - -def launch_unpermute_kernel(*args, **kwargs): - """Placeholder for future unpermute kernel.""" - raise NotImplementedError("launch_unpermute_kernel not yet implemented") diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 76ba0d5d733..b3fbd15fa31 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -19,11 +19,6 
@@ ) from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.moe.inference_kernels import ( - shift_topk_indices, - permute_tokens_and_probs, - unpermute_and_combine, -) from megatron.core.tensor_parallel import gather_from_sequence_parallel_region from megatron.core.parallel_state import get_global_symmetric_memory_buffer_ep from megatron.core.inference.communication.torch_symm_triton import ( From 7c2b2ff6bc74b4f88f772da9978f098c434b3fe7 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 24 Feb 2026 16:14:56 -0800 Subject: [PATCH 44/92] remove mamba metadata changes --- .../attention_context/mamba_metadata.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/megatron/core/inference/contexts/attention_context/mamba_metadata.py b/megatron/core/inference/contexts/attention_context/mamba_metadata.py index bacaf882944..d7fcf7436a2 100644 --- a/megatron/core/inference/contexts/attention_context/mamba_metadata.py +++ b/megatron/core/inference/contexts/attention_context/mamba_metadata.py @@ -305,25 +305,6 @@ def allocate_slot(self) -> Optional[int]: return mamba_idx - def batch_allocate_slots(self, num_slots: int) -> Optional[torch.Tensor]: - """ - Allocates new slots for the given number of requests in the Mamba state buffers. - - Returns: - torch.Tensor: The indices of the allocated slots. - Returns None if not enough slots are available. - """ - if self.mamba_state_free_slot_count < num_slots: - return None - - # Get free slots - self.mamba_state_free_slot_count -= num_slots - mamba_idx = self.mamba_state_free_slots[ - self.mamba_state_free_slot_count : self.mamba_state_free_slot_count + num_slots - ] - - return mamba_idx - def free_slots(self, request_indices: torch.Tensor) -> None: """ Frees the Mamba state slots associated with the given request indices. 
From 2631c1a53194d601e5aa5f51bf0a3009a09f2446 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 24 Feb 2026 17:18:41 -0800 Subject: [PATCH 45/92] simplify hybrid spec call --- .../inference/contexts/dynamic_context.py | 1 + megatron/core/models/gpt/moe_module_specs.py | 88 +++++++++++-------- .../core/models/mamba/mamba_layer_specs.py | 10 +-- 3 files changed, 53 insertions(+), 46 deletions(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 9f7556f1312..fcb81e3cf29 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1290,6 +1290,7 @@ def initialize_attention_state( explicit_chunked_prefill=self.is_chunked_prefill_enabled() and self.is_hybrid_model, ep_group=self.expert_model_parallel_group, ) + self._using_cuda_graph_this_step = best_graph is not None if self.using_cuda_graph_this_step(): diff --git a/megatron/core/models/gpt/moe_module_specs.py b/megatron/core/models/gpt/moe_module_specs.py index 3b02e19bb2d..739889d8fce 100755 --- a/megatron/core/models/gpt/moe_module_specs.py +++ b/megatron/core/models/gpt/moe_module_specs.py @@ -17,24 +17,20 @@ def get_moe_module_spec( num_experts: Optional[int] = None, moe_grouped_gemm: Optional[bool] = False, moe_use_legacy_grouped_gemm: Optional[bool] = False, - inference_optimized: bool = False, ) -> ModuleSpec: - """Helper function to get module spec for MoE - + """Helper function to get module spec for MoE. + + Called by mamba_layer_specs.py for standard (non-inference) MoE specs. + The GPT layer specs call get_moe_module_spec_for_backend directly. + Args: use_te: Whether to use Transformer Engine. num_experts: Number of experts. moe_grouped_gemm: Whether to use grouped GEMM. moe_use_legacy_grouped_gemm: Whether to use legacy grouped GEMM. - inference_optimized: If True, use InferenceMoELayer for optimized inference. 
""" - # This function is called my mamba_layer_specs.py - # The GPT layer specs directly calls get_moe_module_spec_for_backend - if use_te is not None and use_te: backend: BackendSpecProvider = TESpecProvider() - elif inference_optimized: - backend = InferenceSpecProvider() else: backend = LocalSpecProvider() return get_moe_module_spec_for_backend( @@ -44,7 +40,6 @@ def get_moe_module_spec( moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, ) - def get_moe_module_spec_for_backend( backend: BackendSpecProvider, num_experts: Optional[int] = None, @@ -52,19 +47,9 @@ def get_moe_module_spec_for_backend( moe_use_legacy_grouped_gemm: Optional[bool] = False, use_te_activation_func: bool = False, ) -> ModuleSpec: - """Helper function to get module spec for MoE - - Args: - backend: Backend spec provider (TE or Local). - num_experts: Number of experts. - moe_grouped_gemm: Whether to use grouped GEMM. - moe_use_legacy_grouped_gemm: Whether to use legacy grouped GEMM. - use_te_activation_func: Whether to use TE activation function. - inference_optimized: If True, use InferenceMoELayer for optimized inference. 
- """ + """Helper function to get module spec for MoE""" assert num_experts is not None - inference_optimized: bool = isinstance(backend, InferenceSpecProvider) - + linear_fc1 = backend.column_parallel_linear() linear_fc2 = backend.row_parallel_linear() activation_func = backend.activation_func() @@ -85,20 +70,47 @@ def get_moe_module_spec_for_backend( # shared experts spec shared_experts = ModuleSpec(module=SharedExpertMLP, submodules=mlp) - # Select MoE layer class based on inference_optimized flag - if inference_optimized: - moe_module_spec = ModuleSpec( - module=InferenceMoELayer, - submodules=MoESubmodules(router=InferenceTopKRouter, - experts=experts, - shared_experts=shared_experts), - metainfo={"fuse_pre_mlp_layernorm": False}, - ) - else: - # MoE module spec - moe_module_spec = ModuleSpec( - module=MoELayer, - submodules=MoESubmodules(experts=experts, shared_experts=shared_experts), - metainfo={"fuse_pre_mlp_layernorm": False}, - ) + # MoE module spec + moe_module_spec = ModuleSpec( + module=MoELayer, + submodules=MoESubmodules(experts=experts, shared_experts=shared_experts), + metainfo={"fuse_pre_mlp_layernorm": False}, + ) return moe_module_spec + + + +def get_inference_optimized_moe_spec() -> ModuleSpec: + """MoE module spec for inference-optimized transformer impl. + + Uses InferenceSpecProvider to select inference-optimized modules: + InferenceMoELayer, InferenceTopKRouter, InferenceGroupedMLP. + + Called by mamba_layer_specs.py and gpt_layer_specs.py. 
+ """ + backend = InferenceSpecProvider() + activation_func = backend.activation_func() + + expert_module, expert_submodule = backend.grouped_mlp_modules(True, False) + if expert_submodule is not None: + expert_submodule.activation_func = activation_func + + experts = ModuleSpec(module=expert_module, submodules=expert_submodule) + shared_experts = ModuleSpec( + module=SharedExpertMLP, + submodules=MLPSubmodules( + linear_fc1=backend.column_parallel_linear(), + linear_fc2=backend.row_parallel_linear(), + activation_func=activation_func, + ), + ) + + return ModuleSpec( + module=InferenceMoELayer, + submodules=MoESubmodules( + router=InferenceTopKRouter, + experts=experts, + shared_experts=shared_experts, + ), + metainfo={"fuse_pre_mlp_layernorm": False}, + ) diff --git a/megatron/core/models/mamba/mamba_layer_specs.py b/megatron/core/models/mamba/mamba_layer_specs.py index 8bffd7e3285..791b63ad2eb 100755 --- a/megatron/core/models/mamba/mamba_layer_specs.py +++ b/megatron/core/models/mamba/mamba_layer_specs.py @@ -8,7 +8,7 @@ TERowParallelLinear, ) from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add -from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec +from megatron.core.models.gpt.moe_module_specs import get_inference_optimized_moe_spec, get_moe_module_spec from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules @@ -42,13 +42,7 @@ ) # Inference-optimized MoE spec -moe_inference = get_moe_module_spec( - use_te=False, - num_experts=8, # Can be any positive integer (must not be None). - moe_grouped_gemm=True, - moe_use_legacy_grouped_gemm=False, - inference_optimized=True, -) +moe_inference = get_inference_optimized_moe_spec() # MTP block spec for Mamba - provides norms and projection only. 
From 27c0f7c1557557919f61356423a1b634ef21c350 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 24 Feb 2026 17:20:20 -0800 Subject: [PATCH 46/92] restore dynamic context --- megatron/core/inference/contexts/dynamic_context.py | 1 - 1 file changed, 1 deletion(-) diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index fcb81e3cf29..9f7556f1312 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1290,7 +1290,6 @@ def initialize_attention_state( explicit_chunked_prefill=self.is_chunked_prefill_enabled() and self.is_hybrid_model, ep_group=self.expert_model_parallel_group, ) - self._using_cuda_graph_this_step = best_graph is not None if self.using_cuda_graph_this_step(): From b055cb661661c5138127b35005639792897048b5 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 24 Feb 2026 17:28:11 -0800 Subject: [PATCH 47/92] slight clean up of router --- megatron/core/transformer/moe/router.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index c02c3b4b130..b434eedd8b5 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -795,10 +795,8 @@ def forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = No - probs: Normalized routing probabilities [num_tokens, topk] - top_indices: Selected expert indices [num_tokens, topk] """ - # Maintain float32 expert bias (important for bf16/fp16) - self._maintain_float32_expert_bias() - - if not self.is_cuda_graphed_iteration: + + if self.training or not self.is_cuda_graphed_iteration: return super().forward(input, padding_mask) return self._forward(input, padding_mask) \ No newline at end of file From 3f24597cea046d97a11907869967dd96737a17b3 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 24 Feb 2026 17:42:41 
-0800 Subject: [PATCH 48/92] router cleanup --- megatron/core/transformer/moe/moe_utils.py | 17 +++++-- megatron/core/transformer/moe/router.py | 53 ++++++---------------- 2 files changed, 27 insertions(+), 43 deletions(-) diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index 47debdd27df..612fecaa58a 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -664,6 +664,7 @@ def topk_routing_with_score_function( expert_bias: Optional[torch.Tensor] = None, fused: bool = False, router_replay: Optional['RouterReplay'] = None, + dense_output: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute the routing probabilities and map for top-k selection with score function. @@ -686,14 +687,17 @@ def topk_routing_with_score_function( recorded routing sequence. Defaults to None. + dense_output (bool, optional): If True, return dense tensors [num_tokens, topk] instead of + sparse tensors [num_tokens, num_experts]. Defaults to False. Returns: Tuple[torch.Tensor, torch.Tensor]: - - routing_probs (torch.Tensor): A tensor of shape [num_tokens, num_experts] containing - the routing probabilities for each token to each expert. - - routing_map (torch.Tensor): A mask tensor of shape [num_tokens, num_experts] - indicating which experts were selected for each token. True values represent - the selected experts. + When dense_output=False (default): + - routing_probs (torch.Tensor): Shape [num_tokens, num_experts]. + - routing_map (torch.Tensor): Shape [num_tokens, num_experts]. + When dense_output=True: + - probs (torch.Tensor): Shape [num_tokens, topk]. + - top_indices (torch.Tensor): Shape [num_tokens, topk]. """ assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}." 
num_tokens, num_experts = logits.shape @@ -776,6 +780,9 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): if scaling_factor: probs = probs * scaling_factor + if dense_output: + return probs, top_indices + if torch.are_deterministic_algorithms_enabled(): # build [num_tokens, num_experts] from [num_tokens, topk] routing_probs = torch.zeros_like(logits) diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index b434eedd8b5..418d39c8c68 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -715,10 +715,6 @@ def _save_to_state_dict(self, *args, **kwargs): class InferenceTopKRouter(TopKRouter): """Specialized top-k router optimized for inference with specific constraints. - This router enforces: - - moe_router_num_groups: None (no group-limited routing) - - moe_router_score_function: sigmoid - - moe_router_enable_expert_bias: True """ def __init__( @@ -748,40 +744,21 @@ def set_is_cuda_graphed_iteration(self, set_to: bool): @torch.compile() def _forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): - logits = self.gating(input) # [num_tokens, num_experts] - - # Apply score function to get scores per expert - if self.score_function == "sigmoid": - # Sigmoid: independent scores per expert - scores = torch.sigmoid(logits.float()).type_as(logits) # [num_tokens, num_experts] - else: # softmax - # Softmax: normalized scores across all experts - scores = torch.softmax(logits.float(), dim=-1).type_as(logits) # [num_tokens, num_experts] - - # Add expert bias for topk selection if enabled (helps with load balancing) - if self.expert_bias is not None: - scores_for_routing = scores + self.expert_bias # [num_experts] broadcasted - else: - scores_for_routing = scores - - # Select top-k experts based on scores (with or without bias) - _, topk_indices = torch.topk(scores_for_routing, k=self.topk, dim=-1) # [num_tokens, topk] - - # Gather the original 
scores (without bias) for selected experts - topk_probs = torch.gather(scores, dim=-1, index=topk_indices) # [num_tokens, topk] - - # Normalize to get routing probabilities (sum to 1 per token) - if self.topk > 1: - topk_probs = topk_probs / (topk_probs.sum(dim=-1, keepdim=True) + 1e-20) - - # Apply scaling factor if configured - if self.config.moe_router_topk_scaling_factor: - topk_probs = topk_probs * self.config.moe_router_topk_scaling_factor - - # NOTE: Return format differs from parent class for efficiency: - # - Parent: Returns sparse tensors [num_tokens, num_experts] (routing_probs, routing_map) - # - This: Returns dense tensors [num_tokens, topk] (topk_probs, topk_indices) - return topk_probs.squeeze(1), topk_indices.squeeze(1) + logits = self.gating(input) # [num_tokens, 1, num_experts] + logits = logits.squeeze(1) # [num_tokens, num_experts] + # Reuse the shared routing logic with dense output for inference efficiency. + # Returns [num_tokens, topk] instead of sparse [num_tokens, num_experts]. + + probs, top_indices = topk_routing_with_score_function( + logits, + self.topk, + use_pre_softmax=True, + scaling_factor=self.config.moe_router_topk_scaling_factor, + score_function=self.score_function, + expert_bias=self.expert_bias, + dense_output=True, + ) + return probs.squeeze(1), top_indices.squeeze(1) def forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): """Simplified forward pass for inference - returns dense tensors only. 
From 1bbaf82de91a3331a29d58d0fddf994a44931e66 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 24 Feb 2026 17:55:26 -0800 Subject: [PATCH 49/92] more router cleanup --- megatron/core/transformer/moe/router.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index 418d39c8c68..ac751988e4d 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -713,8 +713,15 @@ def _save_to_state_dict(self, *args, **kwargs): class InferenceTopKRouter(TopKRouter): - """Specialized top-k router optimized for inference with specific constraints. + """Inference-only top-k router that strips out training-specific overhead. + A stripped-down version of TopKRouter that skips z-loss, auxiliary load + balancing losses, token dropping, and expert bias updates. The _forward() + method is @torch.compile()'d and returns dense [num_tokens, topk] tensors + instead of sparse [num_tokens, num_experts] for CUDA graph compatibility. + + Falls back to the parent TopKRouter.forward() for training or + non-CUDA-graphed inference iterations. """ def __init__( @@ -744,18 +751,23 @@ def set_is_cuda_graphed_iteration(self, set_to: bool): @torch.compile() def _forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): - logits = self.gating(input) # [num_tokens, 1, num_experts] - logits = logits.squeeze(1) # [num_tokens, num_experts] - # Reuse the shared routing logic with dense output for inference efficiency. - # Returns [num_tokens, topk] instead of sparse [num_tokens, num_experts]. + logits = self.gating(input).squeeze(1) # [num_tokens, 1, num_experts] + # Share the routing logic with the parent class to avoid code duplication. + # However, we pass dense_output=True to return dense [num_tokens, topk] tensors + # instead of sparse [num_tokens, num_experts]. 
+ probs, top_indices = topk_routing_with_score_function( logits, self.topk, - use_pre_softmax=True, + use_pre_softmax=self.config.moe_router_pre_softmax, + num_groups=self.config.moe_router_num_groups, + group_topk=self.config.moe_router_group_topk, scaling_factor=self.config.moe_router_topk_scaling_factor, score_function=self.score_function, expert_bias=self.expert_bias, + fused=self.config.moe_router_fusion, + router_replay=self.router_replay, dense_output=True, ) return probs.squeeze(1), top_indices.squeeze(1) From 5607f6f8e75c3c53fb4c41c38a932abc352af27e Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Tue, 24 Feb 2026 18:20:11 -0800 Subject: [PATCH 50/92] absorb inference layer into parent moe layer and more cleanup --- megatron/core/inference/utils.py | 2 +- megatron/core/models/gpt/moe_module_specs.py | 6 +- megatron/core/transformer/moe/moe_layer.py | 101 +++++- .../transformer/moe/moe_layer_inference.py | 206 ----------- .../inference/test_batch_dimension_utils.py | 340 ------------------ 5 files changed, 104 insertions(+), 551 deletions(-) delete mode 100644 megatron/core/transformer/moe/moe_layer_inference.py delete mode 100644 tests/unit_tests/inference/test_batch_dimension_utils.py diff --git a/megatron/core/inference/utils.py b/megatron/core/inference/utils.py index ad9638272f9..5a2939decc7 100644 --- a/megatron/core/inference/utils.py +++ b/megatron/core/inference/utils.py @@ -134,7 +134,7 @@ def set_decode_expert_padding(model, set_to: bool = False, capacity_factor: int def set_is_cuda_graphed_iteration_for_ep_inference(model, set_to: bool): """ Toggle CUDA graph compatibility for expert parallel inference. - This sets a boolean flag in all InferenceMoELayers to indicate whether + This sets a boolean flag in all MoELayers to indicate whether the current iteration is being captured/executed in a CUDA graph. 
This allows the dispatcher to adjust its behavior for CUDA graph compatibility, Args: diff --git a/megatron/core/models/gpt/moe_module_specs.py b/megatron/core/models/gpt/moe_module_specs.py index 739889d8fce..e076ed6a5bf 100755 --- a/megatron/core/models/gpt/moe_module_specs.py +++ b/megatron/core/models/gpt/moe_module_specs.py @@ -6,7 +6,6 @@ from megatron.core.models.backends import BackendSpecProvider, LocalSpecProvider, InferenceSpecProvider from megatron.core.transformer.mlp import MLPSubmodules from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules -from megatron.core.transformer.moe.moe_layer_inference import InferenceMoELayer from megatron.core.transformer.moe.shared_experts import SharedExpertMLP from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.moe.router import InferenceTopKRouter @@ -84,7 +83,8 @@ def get_inference_optimized_moe_spec() -> ModuleSpec: """MoE module spec for inference-optimized transformer impl. Uses InferenceSpecProvider to select inference-optimized modules: - InferenceMoELayer, InferenceTopKRouter, InferenceGroupedMLP. + InferenceTopKRouter, InferenceGroupedMLP. MoELayer detects inference mode + via config.transformer_impl and sets up the inference dispatcher internally. Called by mamba_layer_specs.py and gpt_layer_specs.py. 
""" @@ -106,7 +106,7 @@ def get_inference_optimized_moe_spec() -> ModuleSpec: ) return ModuleSpec( - module=InferenceMoELayer, + module=MoELayer, submodules=MoESubmodules( router=InferenceTopKRouter, experts=experts, diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 3d9d0b092aa..e892db9c148 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -246,10 +246,57 @@ def __init__( if self.shared_expert_overlap: self.token_dispatcher.set_shared_experts(self.shared_experts) + # Inference-optimized mode setup + if config.transformer_impl == "inference_optimized": + self._setup_inference_mode(pg_collection) + # Cudagraph tensor store for resuming the forward pass from the end of the cudagraph. self.cudagraph_tensor_store = MoECudaGraphTensorStore() self.fwd_execution_map = ["route", "expert_compute", "postprocess"] + + def _setup_inference_mode(self, pg_collection): + """Set up inference-optimized token dispatcher and state. + + Called from __init__ when config.transformer_impl == "inference_optimized". + Creates an InferenceAllGatherTokenDispatcher alongside the standard dispatcher, + which is swapped in during CUDA-graphed forward passes. 
+ """ + from megatron.core.transformer.moe.token_dispatcher_inference import ( + InferenceAllGatherTokenDispatcher, + ) + + assert self.config.moe_token_dispatcher_type == "alltoall", ( + f"Inference-optimized MoE requires 'alltoall' dispatcher, " + f"got '{self.config.moe_token_dispatcher_type}'" + ) + self.is_cuda_graphed_iteration = False + self._inference_token_dispatcher = InferenceAllGatherTokenDispatcher( + self.num_local_experts, + self.local_expert_indices, + config=self.config, + pg_collection=pg_collection, + ) + + def set_is_cuda_graphed_iteration(self, set_to: bool): + """Toggle CUDA-graphed iteration mode on this layer and its router.""" + self.is_cuda_graphed_iteration = set_to + if hasattr(self.router, 'set_is_cuda_graphed_iteration'): + self.router.set_is_cuda_graphed_iteration(set_to) + + def _activate_inference_token_dispatcher(self): + """Swap in the inference-optimized token dispatcher.""" + self._saved_token_dispatcher = self.token_dispatcher + self.token_dispatcher = self._inference_token_dispatcher + self._saved_shared_expert_overlap = self.shared_expert_overlap + self.shared_expert_overlap = False + + def _deactivate_inference_token_dispatcher(self): + """Restore the standard token dispatcher.""" + self.token_dispatcher = self._saved_token_dispatcher + self.shared_expert_overlap = self._saved_shared_expert_overlap + + @maybe_skip_or_early_return_by_cudagraph("route") def route(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): """Compute token routing for preprocessing. @@ -325,6 +372,9 @@ def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tenso for each expert. It then passes the tokens through the local experts. The output from the experts is preprocessed for the combine step. 
""" + if not self.training and self.is_cuda_graphed_iteration: + return self._fused_experts_compute(hidden_states, probs) + dispatched_input, tokens_per_expert, permuted_probs = ( self.token_dispatcher.dispatch_postprocess(hidden_states, probs) ) @@ -334,6 +384,43 @@ def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tenso return output, mlp_bias + def _fused_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tensor): + """FlashInfer fused MoE kernel for CUDA-graphed inference iterations.""" + import flashinfer.fused_moe as fused_moe + from flashinfer.fused_moe.core import ActivationType + + from megatron.core.activations import squared_relu + + assert not self.config.gated_linear_unit, ( + "FlashInfer MoE kernel currently only supports non-gated activations. " + f"Got gated_linear_unit={self.config.gated_linear_unit}" + ) + assert self.config.activation_func == squared_relu, ( + "FlashInfer MoE kernel currently only supports squared_relu activation. " + f"Got activation_func={self.config.activation_func}" + ) + + w1 = self.experts._fc1_weight + w2 = self.experts._fc2_weight + selected_experts = self.token_dispatcher.routing_map + ep_size = utils.get_pg_size(self.ep_group) + ep_rank = utils.get_pg_rank(self.ep_group) + + output = fused_moe.cutlass_fused_moe( + hidden_states, + selected_experts.to(torch.int), + probs.float(), + w1, + w2, + hidden_states.dtype, + quant_scales=None, + activation_type=ActivationType.Relu2, + ep_size=ep_size, + ep_rank=ep_rank, + )[0] + + return output, None + def combine(self, output: torch.Tensor): """Combines expert outputs via communication and adds shared expert output. 
@@ -393,6 +480,15 @@ def forward( if padding_mask is not None: padding_mask = padding_mask.transpose(0, 1).bool() + # Swap in inference-optimized dispatcher for CUDA-graphed iterations + _use_inference_dispatcher = ( + not self.training + and self.is_cuda_graphed_iteration + and self._inference_token_dispatcher is not None + ) + if _use_inference_dispatcher: + self._activate_inference_token_dispatcher() + # MoE forward: route -> dispatch -> compute -> combine def custom_forward(hidden_states, intermediate_tensors=None, padding_mask=None): try: @@ -437,7 +533,7 @@ def custom_forward(hidden_states, intermediate_tensors=None, padding_mask=None): return output, mlp_bias - if self.moe_layer_recompute: + if self.moe_layer_recompute and not _use_inference_dispatcher: if self.config.fp8 or self.config.fp4: outputs = te_checkpoint( custom_forward, @@ -455,6 +551,9 @@ def custom_forward(hidden_states, intermediate_tensors=None, padding_mask=None): else: outputs = custom_forward(hidden_states, intermediate_tensors, padding_mask) + if _use_inference_dispatcher: + self._deactivate_inference_token_dispatcher() + return outputs def backward_dw(self, routed_experts: bool = True, shared_experts: bool = False): diff --git a/megatron/core/transformer/moe/moe_layer_inference.py b/megatron/core/transformer/moe/moe_layer_inference.py deleted file mode 100644 index 90ad04c5299..00000000000 --- a/megatron/core/transformer/moe/moe_layer_inference.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - -""" -Inference-optimized MoE Layer with AlltoAll Token Dispatcher. - -This implementation inherits from MoELayer to ensure checkpoint compatibility, -while providing a simplified forward pass optimized for inference: -1. Strips out training-specific code (aux losses, recomputation, backward) -2. Uses a simple, linear forward flow -3. 
Is designed to be CUDA graph compatible (future work) - -The forward pass follows this flow: - Input [S, B, H] - ↓ Route (router gate → topk selection) - probs, routing_map - ↓ Permute (group tokens by expert) - permuted_tokens [num_selected_tokens, H] - ↓ EP AlltoAll (distribute to expert owners) - global_tokens [tokens_on_this_rank, H] - ↓ TP AllGather (if tp_size > 1) - gathered_tokens - ↓ Sort by local expert (if num_local_experts > 1) - sorted_tokens - ↓ Expert FFN (GroupedGEMM) - expert_output - ↓ Unsort by local expert - unsorted_output - ↓ TP ReduceScatter (if tp_size > 1) - scattered_output - ↓ EP AlltoAll (return to original ranks) - combined_output - ↓ Unpermute (restore original order) - Output [S, B, H] - -Usage: - # Load a trained MoELayer checkpoint directly: - inference_layer = InferenceMoELayer(config, submodules, layer_number, pg_collection) - inference_layer.load_state_dict(trained_moe_layer.state_dict()) - -TODO: Add unit test to verify that InferenceMoELayer.forward() and MoELayer.forward() - have aligned argument signatures (use inspect.signature to compare). -""" - -from typing import Optional - -import torch -import torch.nn.functional as F - -from megatron.core import utils -from megatron.core.activations import squared_relu -from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules -from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.transformer.moe.moe_utils import get_default_pg_collection -from megatron.core.transformer.moe.token_dispatcher_inference import InferenceAllGatherTokenDispatcher -import flashinfer.fused_moe as fused_moe -from flashinfer.fused_moe.core import ActivationType - -import logging - -class InferenceMoELayer(MoELayer): - """ - Inference-optimized MoE layer that inherits from MoELayer for checkpoint compatibility. 
- - This implementation: - - Inherits all weights/submodules from MoELayer (router, experts, token_dispatcher) - - Provides a simplified forward() optimized for inference - - Removes training overhead (aux losses, recomputation, gradient computation) - - Only supports AlltoAll dispatcher (most common for inference) - - Checkpoints trained with MoELayer can be loaded directly. - """ - - def __init__( - self, - config: TransformerConfig, - submodules: Optional[MoESubmodules] = None, - layer_number: Optional[int] = None, - pg_collection: Optional[ProcessGroupCollection] = None, - ): - """ - Initialize the inference MoE layer. - - Args are identical to MoELayer for checkpoint compatibility. - """ - # Initialize parent MoELayer (creates router, experts, token_dispatcher) - if pg_collection is None: - pg_collection = get_default_pg_collection() - - super().__init__( - config=config, - submodules=submodules, - layer_number=layer_number, - pg_collection=pg_collection, - ) - - # Validate dispatcher type - # todo: move this assert to arguments.py or transformer_config.py - if config.moe_token_dispatcher_type != "alltoall": - raise ValueError( - f"InferenceMoELayer only supports 'alltoall' dispatcher, " - f"got '{config.moe_token_dispatcher_type}'" - ) - - self.is_cuda_graphed_iteration = False - self.inference_token_dispatcher = InferenceAllGatherTokenDispatcher( - self.num_local_experts, - self.local_expert_indices, - config=self.config, - pg_collection=pg_collection, - ) - def set_is_cuda_graphed_iteration(self, set_to): - self.is_cuda_graphed_iteration = set_to - self.router.set_is_cuda_graphed_iteration(set_to) - - def activate_inference_token_dispatcher(self): - # replace the token dispatcher with the inference-optimized version - self.old_token_dispatcher = self.token_dispatcher - self.token_dispatcher = self.inference_token_dispatcher - - # disable shared expert overlap during inference as it is not - # supported in InferenceAllGatherTokenDispatcher - 
self.old_expert_overlap = self.shared_expert_overlap - self.shared_expert_overlap = False - - def deactivate_inference_token_dispatcher(self): - # restore the original token dispatcher - # and shared expert overlap setting - self.token_dispatcher = self.old_token_dispatcher - self.shared_expert_overlap = self.old_expert_overlap - - - # ==================== Simplified Forward Pass ==================== - def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): - """ - """ - if not self.is_cuda_graphed_iteration: - # Note: this will still call InferenceGroupedMLP.forward() - # and therefore optimized cutlass grouped gemms. - return super().forward(hidden_states, padding_mask) - - self.activate_inference_token_dispatcher() - assert self.token_dispatcher is self.inference_token_dispatcher - #logging.info("activated inference token dispatcher") - - forward_pass_output = super().forward(hidden_states, padding_mask) - - self.deactivate_inference_token_dispatcher() - assert self.token_dispatcher is not self.inference_token_dispatcher - #logging.info("deactivated inference token dispatcher") - - return forward_pass_output - - def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tensor): - """Computes the output of the routed experts on the dispatched tokens. - - This method first post-processes the dispatched input to get permuted tokens - for each expert. It then passes the tokens through the local experts. - The output from the experts is preprocessed for the combine step. - """ - if not self.is_cuda_graphed_iteration: - # todo: can we go down the flashinfer path even if not cuda graphed? - return super().routed_experts_compute(hidden_states, probs) - - # Currently only squared_relu (non-gated) is supported with FlashInfer - assert not self.config.gated_linear_unit, ( - "FlashInfer MoE kernel currently only supports non-gated activations. 
" - f"Got gated_linear_unit={self.config.gated_linear_unit}" - ) - assert self.config.activation_func == squared_relu, ( - "FlashInfer MoE kernel currently only supports squared_relu activation. " - f"Got activation_func={self.config.activation_func}" - ) - - # Get dtype from input - output_dtype = hidden_states.dtype - - # Get expert weights from self.experts (GroupedMLP) - w1 = self.experts._fc1_weight - w2 = self.experts._fc2_weight - - # Get routing information (stored from route() step) - selected_experts = self.token_dispatcher.routing_map - routing_weights = probs - - # Get EP attributes - ep_size = utils.get_pg_size(self.ep_group) - ep_rank = utils.get_pg_rank(self.ep_group) - - # Call FlashInfer fused MoE kernel with Relu2 (squared ReLU) - output = fused_moe.cutlass_fused_moe( - hidden_states, - selected_experts.to(torch.int), - routing_weights.float(), - w1, - w2, - output_dtype, - quant_scales=None, - activation_type=ActivationType.Relu2, - ep_size=ep_size, - ep_rank=ep_rank, - )[0] - - return output, None - - diff --git a/tests/unit_tests/inference/test_batch_dimension_utils.py b/tests/unit_tests/inference/test_batch_dimension_utils.py deleted file mode 100644 index d155bdf6d7f..00000000000 --- a/tests/unit_tests/inference/test_batch_dimension_utils.py +++ /dev/null @@ -1,340 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. - -""" -Unit tests for CUDAGraphBatchDimensionBuilder.match_graph_config with expert parallelism. 
-""" - -import pytest -import torch -import torch.distributed as dist - -from megatron.core import parallel_state as ps -from megatron.core.inference.batch_dimensions_utils import ( - CUDAGraphBatchDimensionBuilder, - InferenceBatchDimensions, -) -from tests.unit_tests.test_utilities import Utils - -BD = InferenceBatchDimensions - -# Common config shared across tests -MAX_REQUESTS = 256 -MAX_TOKENS = 2048 -MAX_SEQ_LEN = 4096 -TP_SIZE = 1 -MIXED_PREFILL_COUNT = 4 - - -def _generate_graphs(num_cuda_graphs, use_non_decode=True): - """Generate cuda graph batch dimensions using the builder.""" - graph_list, _ = CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list( - tp_size=TP_SIZE, - num_cuda_graphs=num_cuda_graphs, - cuda_graph_max_tokens=MAX_REQUESTS, - cuda_graph_mixed_prefill_count=MIXED_PREFILL_COUNT, - max_requests=MAX_REQUESTS, - max_tokens=MAX_TOKENS, - max_sequence_length=MAX_SEQ_LEN, - use_cuda_graphs_for_non_decode_steps=use_non_decode, - ) - return graph_list - - -def _match( - real, graph_list, ep_group, strict=False, decode_only=False, explicit_chunked_prefill=False -): - return CUDAGraphBatchDimensionBuilder.match_graph_config( - real_batch_dim=real, - cuda_graph_batch_dimensions_list=graph_list, - strict=strict, - decode_only_cuda_graphs=decode_only, - explicit_chunked_prefill=explicit_chunked_prefill, - ep_group=ep_group, - cuda_graph_mixed_prefill_count=MIXED_PREFILL_COUNT, - ) - - -def _assert_consistent_across_ranks(result, ep_group): - """Assert that the match result is the same on every EP rank. - - Either all ranks return None, or all ranks return a config with the - same token_count (which is what the all-reduce synchronises). - """ - if result is None: - flag = torch.zeros(1, dtype=torch.int32, device="cuda") - else: - flag = torch.ones(1, dtype=torch.int32, device="cuda") - - # If any rank got None, all must get None; if any rank got a match, all must. 
- flag_sum = flag.clone() - dist.all_reduce(flag_sum, op=dist.ReduceOp.SUM, group=ep_group) - ep_size = dist.get_world_size(ep_group) - assert ( - flag_sum.item() == 0 or flag_sum.item() == ep_size - ), f"Inconsistent match: {flag_sum.item()}/{ep_size} ranks got a match" - - if result is not None: - tc = torch.tensor([result.token_count], dtype=torch.int32, device="cuda") - tc_max = tc.clone() - tc_min = tc.clone() - dist.all_reduce(tc_max, op=dist.ReduceOp.MAX, group=ep_group) - dist.all_reduce(tc_min, op=dist.ReduceOp.MIN, group=ep_group) - assert ( - tc_max.item() == tc_min.item() - ), f"Token count mismatch across EP ranks: min={tc_min.item()}, max={tc_max.item()}" - - -class TestMatchGraphConfigWithEP: - """Tests for match_graph_config with expert parallelism. - - Uses the world group as the EP group (all 8 GPUs form one EP group). - """ - - def setup_method(self, method): - Utils.initialize_model_parallel( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - expert_model_parallel_size=Utils.world_size, - ) - - def teardown_method(self, method): - Utils.destroy_model_parallel() - - @staticmethod - def _get_ep_group(): - """Return the EP group created by initialize_model_parallel.""" - return ps.get_expert_model_parallel_group() - - # ------------------------------------------------------------------ # - # 1. 
All ranks same decode batch → consistent match - # ------------------------------------------------------------------ # - @pytest.mark.internal - @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) - def test_uniform_decode_batch(self, num_cuda_graphs): - """All EP ranks have the same decode-only batch → should all match the same graph.""" - ep_group = self._get_ep_group() - graph_list = _generate_graphs(num_cuda_graphs) - real = BD(token_count=32, prefill_req_count=0, decode_req_count=32) - - result = _match(real, graph_list, ep_group=ep_group) - _assert_consistent_across_ranks(result, ep_group) - assert result is not None, "Should find a matching graph for uniform decode batch" - assert result.token_count == 32 - - # ------------------------------------------------------------------ # - # 2. Different token counts across EP ranks → all-reduce takes max - # ------------------------------------------------------------------ # - @pytest.mark.internal - @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) - def test_varying_decode_token_counts(self, num_cuda_graphs): - """EP ranks have different decode token counts. The all-reduce - should take the max, and all ranks should match the same graph.""" - ep_group = self._get_ep_group() - graph_list = _generate_graphs(num_cuda_graphs) - rank = dist.get_rank() - - # Each rank gets a different token count: 8, 16, 24, ... - token_count = (rank + 1) * 8 - real = BD(token_count=token_count, prefill_req_count=0, decode_req_count=token_count) - - result = _match(real, graph_list, ep_group=ep_group) - _assert_consistent_across_ranks(result, ep_group) - assert result is not None - assert result.token_count == (ep_group.size() * 8) - - # ------------------------------------------------------------------ # - # 3. 
decode_only_cuda_graphs=True, some ranks have prefill → all None - # ------------------------------------------------------------------ # - @pytest.mark.internal - @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) - def test_decode_only_graphs_with_mixed_ranks(self, num_cuda_graphs): - """When decode_only_cuda_graphs=True and at least one EP rank has a - prefill request, ALL ranks should get None (eager mode).""" - ep_group = self._get_ep_group() - graph_list = _generate_graphs(num_cuda_graphs) - rank = dist.get_rank() - - # Rank 0 has a mixed batch (prefill + decode), all others decode-only - if rank == 0: - real = BD(token_count=64, prefill_req_count=2, decode_req_count=10) - else: - real = BD(token_count=32, prefill_req_count=0, decode_req_count=32) - - result = _match(real, graph_list, ep_group=ep_group, decode_only=True) - _assert_consistent_across_ranks(result, ep_group) - assert ( - result is None - ), "All ranks should run eager when decode_only=True and some rank has prefill" - - # ------------------------------------------------------------------ # - # 4. 
explicit_chunked_prefill=True, some ranks prefill → all None - # ------------------------------------------------------------------ # - @pytest.mark.internal - @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) - def test_explicit_chunked_prefill_with_mixed_ranks(self, num_cuda_graphs): - """When explicit_chunked_prefill=True and some EP rank has prefill, - ALL ranks should get None (eager mode).""" - ep_group = self._get_ep_group() - graph_list = _generate_graphs(num_cuda_graphs) - rank = dist.get_rank() - - if rank == 0: - real = BD(token_count=64, prefill_req_count=2, decode_req_count=10) - else: - real = BD(token_count=32, prefill_req_count=0, decode_req_count=32) - - result = _match(real, graph_list, ep_group=ep_group, explicit_chunked_prefill=True) - _assert_consistent_across_ranks(result, ep_group) - assert result is None, "All ranks should run eager with explicit_chunked_prefill" - - # ------------------------------------------------------------------ # - # 5. Mixed prefill graphs with strict matching - # ------------------------------------------------------------------ # - @pytest.mark.internal - @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) - def test_strict_matching_with_mixed_prefill(self, num_cuda_graphs): - """With strict matching, request counts are synced across EP ranks - via all-reduce. All ranks should still get a consistent result.""" - ep_group = self._get_ep_group() - graph_list = _generate_graphs(num_cuda_graphs) - rank = dist.get_rank() - - # Varying prefill/decode split across ranks - prefill = min(rank + 1, MIXED_PREFILL_COUNT) - decode = 16 - prefill - real = BD(token_count=64, prefill_req_count=prefill, decode_req_count=decode) - - result = _match(real, graph_list, ep_group=ep_group, strict=True) - _assert_consistent_across_ranks(result, ep_group) - - # ------------------------------------------------------------------ # - # 6. 
Non-strict matching with mixed prefill - # ------------------------------------------------------------------ # - @pytest.mark.internal - @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) - def test_non_strict_matching_with_mixed_prefill(self, num_cuda_graphs): - """Non-strict matching: prefill slots can serve decode. Token count - is synced across EP ranks; result must be consistent.""" - ep_group = self._get_ep_group() - graph_list = _generate_graphs(num_cuda_graphs) - rank = dist.get_rank() - - prefill = min(rank + 1, MIXED_PREFILL_COUNT) - decode = 16 - prefill - real = BD(token_count=64, prefill_req_count=prefill, decode_req_count=decode) - - result = _match(real, graph_list, ep_group=ep_group) - _assert_consistent_across_ranks(result, ep_group) - - # ------------------------------------------------------------------ # - # 7. Mixed decode/prefill across ranks — strict matching - # ------------------------------------------------------------------ # - @pytest.mark.internal - @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) - def test_mixed_decode_and_prefill_ranks_strict(self, num_cuda_graphs): - """Some EP ranks are pure decode, others have prefill requests. - With strict matching the all-reduce syncs request counts to the - max across ranks. Result must be consistent.""" - ep_group = self._get_ep_group() - graph_list = _generate_graphs(num_cuda_graphs) - rank = dist.get_rank() - - # Even ranks: pure decode (32 tokens) - # Odd ranks: mixed prefill (64 tokens, 2 prefill + 14 decode) - if rank % 2 == 0: - real = BD(token_count=32, prefill_req_count=0, decode_req_count=32) - else: - real = BD(token_count=64, prefill_req_count=2, decode_req_count=14) - - result = _match(real, graph_list, ep_group=ep_group, strict=True) - _assert_consistent_across_ranks(result, ep_group) - - # ------------------------------------------------------------------ # - # 8. 
Mixed decode/prefill across ranks — non-strict matching - # ------------------------------------------------------------------ # - @pytest.mark.internal - @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) - def test_mixed_decode_and_prefill_ranks_non_strict(self, num_cuda_graphs): - """Some EP ranks are pure decode, others have prefill requests. - Non-strict matching only syncs token counts (not request counts). - Result must be consistent.""" - ep_group = self._get_ep_group() - graph_list = _generate_graphs(num_cuda_graphs) - rank = dist.get_rank() - - # Even ranks: pure decode (32 tokens) - # Odd ranks: mixed prefill (64 tokens, 2 prefill + 14 decode) - if rank % 2 == 0: - real = BD(token_count=32, prefill_req_count=0, decode_req_count=32) - else: - real = BD(token_count=64, prefill_req_count=2, decode_req_count=14) - - result = _match(real, graph_list, ep_group=ep_group) - _assert_consistent_across_ranks(result, ep_group) - - # ------------------------------------------------------------------ # - # 9. All ranks decode-only with decode_only_cuda_graphs → should match - # ------------------------------------------------------------------ # - @pytest.mark.internal - @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) - def test_decode_only_graphs_all_decode(self, num_cuda_graphs): - """When all EP ranks are decode-only and decode_only_cuda_graphs=True, - a match should be found.""" - ep_group = self._get_ep_group() - graph_list = _generate_graphs(num_cuda_graphs) - rank = dist.get_rank() - - token_count = (rank + 1) * 4 - real = BD(token_count=token_count, prefill_req_count=0, decode_req_count=token_count) - - result = _match(real, graph_list, ep_group=ep_group, decode_only=True) - _assert_consistent_across_ranks(result, ep_group) - assert result is not None, "All-decode batch with decode_only_cuda_graphs should match" - - # ------------------------------------------------------------------ # - # 10. 
Real batch exceeds all graphs → None on all ranks - # ------------------------------------------------------------------ # - @pytest.mark.internal - @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) - def test_oversized_batch_returns_none(self, num_cuda_graphs): - """When the real batch is larger than any available graph, all ranks - should get None.""" - ep_group = self._get_ep_group() - graph_list = _generate_graphs(num_cuda_graphs) - - # Token count exceeds MAX_TOKENS on all ranks - real = BD( - token_count=MAX_TOKENS + 100, - prefill_req_count=0, - decode_req_count=min(MAX_TOKENS + 100, MAX_REQUESTS), - ) - - result = _match(real, graph_list, ep_group=ep_group) - _assert_consistent_across_ranks(result, ep_group) - assert result is None, "Oversized batch should not match any graph" - - # ------------------------------------------------------------------ # - # 11. One EP rank has huge batch → all-reduce lifts to max → no match - # ------------------------------------------------------------------ # - @pytest.mark.internal - @pytest.mark.parametrize("num_cuda_graphs", [16, 32, -1]) - def test_one_rank_oversized_forces_no_match(self, num_cuda_graphs): - """If one EP rank has a batch exceeding all graph capacities, the - all-reduce max lifts everyone → no match on any rank.""" - ep_group = self._get_ep_group() - graph_list = _generate_graphs(num_cuda_graphs) - rank = dist.get_rank() - - if rank == 0: - # This rank has a batch that exceeds all graphs - real = BD( - token_count=MAX_TOKENS + 100, - prefill_req_count=0, - decode_req_count=min(MAX_TOKENS + 100, MAX_REQUESTS), - ) - else: - real = BD(token_count=8, prefill_req_count=0, decode_req_count=8) - - result = _match(real, graph_list, ep_group=ep_group) - _assert_consistent_across_ranks(result, ep_group) - assert result is None, "All-reduce max from oversized rank should cause no match" From 834656be35bdf05281742b80d2f9c13f576f9048 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 25 Feb 2026 
11:04:39 -0800 Subject: [PATCH 51/92] more cleanup --- .../torch_symm_triton/__init__.py | 2 +- .../torch_symm_triton/collectives.py | 339 +++++++----------- .../moe/token_dispatcher_inference.py | 92 ++--- megatron/core/utils.py | 38 ++ 4 files changed, 199 insertions(+), 272 deletions(-) diff --git a/megatron/core/inference/communication/torch_symm_triton/__init__.py b/megatron/core/inference/communication/torch_symm_triton/__init__.py index 282c98008f0..586e913541e 100644 --- a/megatron/core/inference/communication/torch_symm_triton/__init__.py +++ b/megatron/core/inference/communication/torch_symm_triton/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -from .collectives import multimem_all_gather, multimem_all_gather_3, multimem_reduce_scatter +from .collectives import multimem_all_gather, multimem_all_gather_fused, multimem_reduce_scatter from .fused_collectives import fused_multimem_rs_add_norm_ag diff --git a/megatron/core/inference/communication/torch_symm_triton/collectives.py b/megatron/core/inference/communication/torch_symm_triton/collectives.py index 6c482d39395..eb48dae7d0f 100644 --- a/megatron/core/inference/communication/torch_symm_triton/collectives.py +++ b/megatron/core/inference/communication/torch_symm_triton/collectives.py @@ -25,126 +25,23 @@ from .multimem_asm import ld_128, st_128 from .utils import get_flat_tid, sync_threads +# ── Triton kernels ───────────────────────────────────────────────────────── @triton.jit -def _multimem_all_gather_kernel( - local_ptr, - multicast_ptr, - signal_pad_ptrs, - numel, - byte_offset, - BLOCK_SIZE: tl.constexpr, - NUMEL_PER_THREAD: tl.constexpr, - RANK: tl.constexpr, - WORLD_SIZE: tl.constexpr, -): - """ - Triton kernel to perform multicast all-gather over nvlink using multimem instructions. - - Args: - byte_offset: Byte offset into the multicast buffer where this tensor starts. 
+def _ag_phase(local_ptr, multicast_ptr, byte_offset, numel, BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE): """ - # an all-gather is simply a multicast store operation - # we only need a barrier at the end to ensure visibility of writes - - pid = tl.program_id(axis=0) - tid = get_flat_tid() - - # From this point on, we pretend each element is 128-bit - numel = numel // NUMEL_PER_THREAD - numel_per_rank = tl.cdiv(numel, WORLD_SIZE) - block_start = pid * BLOCK_SIZE - - while block_start < numel_per_rank: - offsets = block_start + tid - mask = offsets < numel_per_rank - - # Each pointer points to a 128-bit bit pack - # byte_offset // 8 -> converts byte offset to uint64 offset - # RANK * numel_per_rank -> brings us to the start of our rank's segment - # offsets -> brings us to the right offset within our rank's segment - # * 2 -> each 128-bit pack is 2 uint64s - multicast_ptrs = ( - multicast_ptr.to(tl.pointer_type(tl.uint64)) + byte_offset // 8 + (RANK * numel_per_rank + offsets) * 2 - ) - local_ptrs = local_ptr.to(tl.pointer_type(tl.uint64)) + offsets * 2 - (x, y, z, w) = ld_128(local_ptrs, mask=mask, multicast_op=False) - st_128(multicast_ptrs, x, y, z, w, mask=mask, multicast_op=True) - - block_start += tl.num_programs(axis=0) * BLOCK_SIZE - - sync_threads() - symm_mem_sync( - signal_pad_ptrs, - None, - RANK, - WORLD_SIZE, - hasPreviousMemAccess=True, - hasSubsequentMemAccess=True, - ) + Core all-gather phase: load from local memory, multicast-store to symmetric buffer. + This is the building block for both single-tensor and fused multi-tensor all-gathers. + Each thread handles 128-bit (NUMEL_PER_THREAD elements) at a time. + byte_offset locates the tensor within the multicast buffer. -def multimem_all_gather( - output_tensor: torch.Tensor, - input_tensor: torch.Tensor, - symm_mem_hdl: _SymmetricMemory, - byte_offset: int = 0, - **kwargs, -) -> torch.Tensor: - """ - Calls a multicast all-gather triton kernel on the given tensor. 
- Output tensor must be a symmetric memory buffer. - Input tensor can be a regular torch tensor - Arguments: - output_tensor: torch.Tensor - output tensor to be all-gathered into - input_tensor: torch.Tensor - input tensor to be all-gathered from - symm_mem_hdl: _SymmetricMemory - handle to the symmetric memory buffer for output_tensor - byte_offset: int - byte offset into the multicast buffer where output_tensor starts - Returns: - torch.Tensor - all-gathered tensor, which is output_tensor + NOTE: When numel is not divisible by (NUMEL_PER_THREAD * WORLD_SIZE), the kernel + rounds up via cdiv and may read/write up to 15 bytes past the logical tensor end. + This is safe because PyTorch's CUDA caching allocator guarantees a minimum block + size of 512 bytes (kMinBlockSize in CUDACachingAllocator.cpp), so small tensors + always have sufficient backing memory. """ - assert HAVE_TRITON, "Triton is required for multimem all-gather." - - config = { - "max_num_blocks": kwargs.get("max_num_blocks", 128), - "num_warps": kwargs.get("num_warps", 32), - "BLOCK_SIZE": kwargs.get("BLOCK_SIZE", 1024), - } - # assert input_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." - # assert output_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." - numel_per_thread = 128 // (input_tensor.element_size() * 8) - - assert ( - output_tensor.numel() % numel_per_thread == 0 - ), "The number of elements must be 128-bit aligned." 
- - num_threads = triton.cdiv(output_tensor.numel() // numel_per_thread, symm_mem_hdl.world_size) - num_blocks = min(triton.cdiv(num_threads, config["BLOCK_SIZE"]), config["max_num_blocks"]) - - _multimem_all_gather_kernel[(num_blocks, 1, 1)]( - input_tensor.data_ptr(), - symm_mem_hdl.multicast_ptr, - symm_mem_hdl.signal_pad_ptrs_dev, - numel=output_tensor.numel(), - byte_offset=byte_offset, - BLOCK_SIZE=config["BLOCK_SIZE"], - NUMEL_PER_THREAD=numel_per_thread, - RANK=symm_mem_hdl.rank, - WORLD_SIZE=symm_mem_hdl.world_size, - num_warps=config["num_warps"], - ) - - return output_tensor - - -# ── Fused 3-tensor all-gather ─────────────────────────────────────────────── -# Processes routing_map, probs, and hidden_states in a single kernel launch -# with a single barrier, eliminating 2 kernel launches + 2 barriers. - - -@triton.jit -def _ag_phase(local_ptr, multicast_ptr, byte_offset, numel, BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE): - """One all-gather phase: load from local memory, multicast-store to symmetric buffer.""" pid = tl.program_id(axis=0) tid = get_flat_tid() @@ -156,6 +53,9 @@ def _ag_phase(local_ptr, multicast_ptr, byte_offset, numel, BLOCK_SIZE, NUMEL_PE offsets = block_start + tid mask = offsets < numel_per_rank + # byte_offset // 8 -> converts byte offset to uint64 offset + # RANK * numel_per_rank -> start of our rank's segment + # * 2 -> each 128-bit pack is 2 uint64s multicast_ptrs = ( multicast_ptr.to(tl.pointer_type(tl.uint64)) + byte_offset // 8 @@ -168,6 +68,26 @@ def _ag_phase(local_ptr, multicast_ptr, byte_offset, numel, BLOCK_SIZE, NUMEL_PE block_start += tl.num_programs(axis=0) * BLOCK_SIZE +@triton.jit +def _multimem_all_gather_kernel( + local_ptr, + multicast_ptr, + signal_pad_ptrs, + numel, + byte_offset, + BLOCK_SIZE: tl.constexpr, + NUMEL_PER_THREAD: tl.constexpr, + RANK: tl.constexpr, + WORLD_SIZE: tl.constexpr, +): + """Single-tensor multicast all-gather kernel.""" + _ag_phase(local_ptr, multicast_ptr, byte_offset, numel, + 
BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) + sync_threads() + symm_mem_sync(signal_pad_ptrs, None, RANK, WORLD_SIZE, + hasPreviousMemAccess=True, hasSubsequentMemAccess=True) + + @triton.jit def _multimem_all_gather_3_kernel( local_ptr_0, local_ptr_1, local_ptr_2, @@ -186,76 +106,15 @@ def _multimem_all_gather_3_kernel( then synchronizes once, eliminating 2 kernel launches and 2 barriers compared to three separate multimem_all_gather calls. """ - # Phase 1: routing_map _ag_phase(local_ptr_0, multicast_ptr, byte_offset_0, numel_0, BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) - - # Phase 2: probs _ag_phase(local_ptr_1, multicast_ptr, byte_offset_1, numel_1, BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) - - # Phase 3: hidden_states _ag_phase(local_ptr_2, multicast_ptr, byte_offset_2, numel_2, BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) - - # Single barrier for all three tensors sync_threads() - symm_mem_sync( - signal_pad_ptrs, - None, - RANK, - WORLD_SIZE, - hasPreviousMemAccess=True, - hasSubsequentMemAccess=True, - ) - - -def multimem_all_gather_3( - output_0: torch.Tensor, input_0: torch.Tensor, byte_offset_0: int, - output_1: torch.Tensor, input_1: torch.Tensor, byte_offset_1: int, - output_2: torch.Tensor, input_2: torch.Tensor, byte_offset_2: int, - symm_mem_hdl: _SymmetricMemory, - **kwargs, -) -> None: - """ - Fused 3-tensor multicast all-gather. Equivalent to calling multimem_all_gather - three times but with a single kernel launch and a single barrier. - - All tensors must share the same symmetric memory handle and be BF16. - """ - assert HAVE_TRITON, "Triton is required for multimem all-gather." - - config = { - "max_num_blocks": kwargs.get("max_num_blocks", 128), - "num_warps": kwargs.get("num_warps", 32), - "BLOCK_SIZE": kwargs.get("BLOCK_SIZE", 1024), - } - - numel_per_thread = 128 // (input_0.element_size() * 8) - - assert output_0.numel() % numel_per_thread == 0, "Tensor 0 must be 128-bit aligned." 
- assert output_1.numel() % numel_per_thread == 0, "Tensor 1 must be 128-bit aligned." - assert output_2.numel() % numel_per_thread == 0, "Tensor 2 must be 128-bit aligned." - - # Size grid to the largest tensor - max_numel = max(output_0.numel(), output_1.numel(), output_2.numel()) - num_threads = triton.cdiv(max_numel // numel_per_thread, symm_mem_hdl.world_size) - num_blocks = min(triton.cdiv(num_threads, config["BLOCK_SIZE"]), config["max_num_blocks"]) - - _multimem_all_gather_3_kernel[(num_blocks, 1, 1)]( - input_0.data_ptr(), input_1.data_ptr(), input_2.data_ptr(), - symm_mem_hdl.multicast_ptr, - symm_mem_hdl.signal_pad_ptrs_dev, - numel_0=output_0.numel(), byte_offset_0=byte_offset_0, - numel_1=output_1.numel(), byte_offset_1=byte_offset_1, - numel_2=output_2.numel(), byte_offset_2=byte_offset_2, - BLOCK_SIZE=config["BLOCK_SIZE"], - NUMEL_PER_THREAD=numel_per_thread, - RANK=symm_mem_hdl.rank, - WORLD_SIZE=symm_mem_hdl.world_size, - num_warps=config["num_warps"], - ) - + symm_mem_sync(signal_pad_ptrs, None, RANK, WORLD_SIZE, + hasPreviousMemAccess=True, hasSubsequentMemAccess=True) @triton.jit def _multimem_reduce_scatter_kernel( @@ -303,48 +162,118 @@ def _multimem_reduce_scatter_kernel( block_start += tl.num_programs(axis=0) * BLOCK_SIZE +# ── Python wrappers ───────────────────────────────────────────────────────── -def multimem_reduce_scatter( +_DEFAULT_KERNEL_CONFIG = { + "max_num_blocks": 128, + "num_warps": 32, + "BLOCK_SIZE": 1024, +} + + +def _kernel_launch_config(element_size: int, max_numel: int, world_size: int, **kwargs): + """Compute kernel launch config shared by all collective wrappers. + + Args: + element_size: bytes per element (e.g. 2 for bf16). + max_numel: largest tensor numel (determines grid size). + world_size: number of ranks. + + Returns: + (numel_per_thread, num_blocks, config) tuple. 
+ """ + config = {k: kwargs.get(k, v) for k, v in _DEFAULT_KERNEL_CONFIG.items()} + numel_per_thread = 128 // (element_size * 8) + num_threads = triton.cdiv(max_numel // numel_per_thread, world_size) + num_blocks = min(triton.cdiv(num_threads, config["BLOCK_SIZE"]), config["max_num_blocks"]) + return numel_per_thread, num_blocks, config + + +def multimem_all_gather( output_tensor: torch.Tensor, input_tensor: torch.Tensor, symm_mem_hdl: _SymmetricMemory, + byte_offset: int = 0, **kwargs, ) -> torch.Tensor: """ - Calls a multicast reduce-scatter triton kernel on the given tensor. - Input tensor must be a symmetric memory buffer. - Output tensor can be a regular torch tensor - Arguments: - output_tensor: torch.Tensor - output tensor to be reduce-scattered into - input_tensor: torch.Tensor - input tensor to be reduce-scattered from - symm_mem_hdl: _SymmetricMemory - handle to the symmetric memory buffer for input_tensor - **kwargs: Additional keyword arguments for kernel configuration: - max_num_blocks (int, optional): The maximum number of blocks to launch. - num_warps (int, optional): The number of warps per block. - BLOCK_SIZE (int, optional): The BLOCK_SIZE parameter for the kernel. - Returns: - torch.Tensor - reduce-scattered tensor, which is output_tensor + Multicast all-gather for a single tensor. + Output tensor must be a symmetric memory buffer. + Input tensor can be a regular torch tensor. """ + assert HAVE_TRITON, "Triton is required for multimem all-gather." - assert HAVE_TRITON, "Triton is required for multimem reduce-scatter." 
+ numel_per_thread, num_blocks, config = _kernel_launch_config( + input_tensor.element_size(), output_tensor.numel(), symm_mem_hdl.world_size, **kwargs, + ) + _multimem_all_gather_kernel[(num_blocks, 1, 1)]( + input_tensor.data_ptr(), + symm_mem_hdl.multicast_ptr, + symm_mem_hdl.signal_pad_ptrs_dev, + numel=output_tensor.numel(), + byte_offset=byte_offset, + BLOCK_SIZE=config["BLOCK_SIZE"], + NUMEL_PER_THREAD=numel_per_thread, + RANK=symm_mem_hdl.rank, + WORLD_SIZE=symm_mem_hdl.world_size, + num_warps=config["num_warps"], + ) - config = { - "max_num_blocks": kwargs.get("max_num_blocks", 128), - "num_warps": kwargs.get("num_warps", 32), - "BLOCK_SIZE": kwargs.get("BLOCK_SIZE", 1024), - } + return output_tensor - assert input_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." - assert output_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." - numel_per_thread = 128 // (output_tensor.element_size() * 8) - assert ( - input_tensor.numel() % numel_per_thread == 0 - ), "The number of elements must be 128-bit aligned." +def multimem_all_gather_fused( + output_0: torch.Tensor, input_0: torch.Tensor, byte_offset_0: int, + output_1: torch.Tensor, input_1: torch.Tensor, byte_offset_1: int, + output_2: torch.Tensor, input_2: torch.Tensor, byte_offset_2: int, + symm_mem_hdl: _SymmetricMemory, + **kwargs, +) -> None: + """ + Fused 3-tensor multicast all-gather. Equivalent to calling multimem_all_gather + three times but with a single kernel launch and a single barrier. - num_threads = triton.cdiv(input_tensor.numel() // numel_per_thread, symm_mem_hdl.world_size) - num_blocks = min(triton.cdiv(num_threads, config["BLOCK_SIZE"]), config["max_num_blocks"]) + All tensors must share the same symmetric memory handle. + """ + assert HAVE_TRITON, "Triton is required for multimem all-gather." 
+ max_numel = max(output_0.numel(), output_1.numel(), output_2.numel()) + numel_per_thread, num_blocks, config = _kernel_launch_config( + input_0.element_size(), max_numel, symm_mem_hdl.world_size, **kwargs, + ) + _multimem_all_gather_3_kernel[(num_blocks, 1, 1)]( + input_0.data_ptr(), input_1.data_ptr(), input_2.data_ptr(), + symm_mem_hdl.multicast_ptr, + symm_mem_hdl.signal_pad_ptrs_dev, + numel_0=output_0.numel(), byte_offset_0=byte_offset_0, + numel_1=output_1.numel(), byte_offset_1=byte_offset_1, + numel_2=output_2.numel(), byte_offset_2=byte_offset_2, + BLOCK_SIZE=config["BLOCK_SIZE"], + NUMEL_PER_THREAD=numel_per_thread, + RANK=symm_mem_hdl.rank, + WORLD_SIZE=symm_mem_hdl.world_size, + num_warps=config["num_warps"], + ) + + +def multimem_reduce_scatter( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + symm_mem_hdl: _SymmetricMemory, + **kwargs, +) -> torch.Tensor: + """ + Multicast reduce-scatter for a single tensor. + Input tensor must be a symmetric memory buffer. + Output tensor can be a regular torch tensor. + """ + assert HAVE_TRITON, "Triton is required for multimem reduce-scatter." + assert input_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." + assert output_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." 
+ + numel_per_thread, num_blocks, config = _kernel_launch_config( + output_tensor.element_size(), input_tensor.numel(), symm_mem_hdl.world_size, **kwargs, + ) _multimem_reduce_scatter_kernel[(num_blocks, 1, 1)]( output_tensor.data_ptr(), symm_mem_hdl.multicast_ptr, diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index b3fbd15fa31..4a0499b33d8 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -22,7 +22,7 @@ from megatron.core.tensor_parallel import gather_from_sequence_parallel_region from megatron.core.parallel_state import get_global_symmetric_memory_buffer_ep from megatron.core.inference.communication.torch_symm_triton import ( - multimem_all_gather_3, + multimem_all_gather_fused, multimem_reduce_scatter, ) @@ -87,79 +87,39 @@ def _maybe_allocate_ag_buffers( Allocate a single symmetric memory buffer for all-gather outputs of routing_map, probs and hidden_states. Returns sliced views for each. 
- All tensors are gathered from ep_size ranks, so output shapes are: - - routing_map: [local_tokens * ep_size, num_experts] - - probs: [local_tokens * ep_size, num_experts] - - hidden_states: [local_tokens * ep_size, hidden_dim] - Returns dict with: - "handle": symmetric memory handle (or None if unavailable) - - "routing_map": view for routing_map output - - "routing_map_offset": byte offset of routing_map in the symmetric buffer - - "probs": view for probs output - - "probs_offset": byte offset of probs in the symmetric buffer - - "hidden_states": view for hidden_states output - - "hidden_states_offset": byte offset of hidden_states in the symmetric buffer - """ - symm_buffer_mgr = get_global_symmetric_memory_buffer_ep() - if symm_buffer_mgr.symm_mem_hdl is None: - return { - "handle": None, - "routing_map": None, "routing_map_offset": 0, - "probs": None, "probs_offset": 0, - "hidden_states": None, "hidden_states_offset": 0, - } - - # Calculate output shapes after all-gather + - "routing_map" / "routing_map_offset": raw byte view and byte offset + - "probs" / "probs_offset": raw byte view and byte offset + - "hidden_states" / "hidden_states_offset": raw byte view and byte offset + """ + _NONE = { + "handle": None, + "routing_map": None, "routing_map_offset": 0, + "probs": None, "probs_offset": 0, + "hidden_states": None, "hidden_states_offset": 0, + } + local_tokens = probs.size(0) global_tokens = local_tokens * self.ep_size topk = probs.size(-1) hidden_dim = hidden_states.size(-1) - # Calculate bytes needed for each tensor (with 16-byte alignment) - def aligned_bytes(numel, dtype): - elem_size = torch.tensor([], dtype=dtype).element_size() - raw_bytes = numel * elem_size - # Align to 16 bytes for 128-bit access - return ((raw_bytes + 15) // 16) * 16 - - routing_map_bytes = aligned_bytes(global_tokens * topk, routing_map.dtype) - probs_bytes = aligned_bytes(global_tokens * topk, probs.dtype) - hidden_states_bytes = aligned_bytes(global_tokens * hidden_dim, 
hidden_states.dtype) - total_bytes = routing_map_bytes + probs_bytes + hidden_states_bytes - - # Check if buffer has enough space - if total_bytes > symm_buffer_mgr.symm_buffer.numel(): - return { - "handle": None, - "routing_map": None, "routing_map_offset": 0, - "probs": None, "probs_offset": 0, - "hidden_states": None, "hidden_states_offset": 0, - } - - # Slice the raw buffer and create views, tracking byte offsets - # [routing_map_bytes | probs_bytes | hidden_states_bytes] - # offset=0 offset=rm offset=rm+probs - - raw_buffer = symm_buffer_mgr.symm_buffer - - routing_map_offset = 0 - routing_map_buffer = raw_buffer[routing_map_offset : routing_map_offset + routing_map_bytes] - - probs_offset = routing_map_bytes - probs_buffer = raw_buffer[probs_offset : probs_offset + probs_bytes] + result = get_global_symmetric_memory_buffer_ep().maybe_get_tensors([ + (global_tokens * topk, routing_map.dtype), + (global_tokens * topk, probs.dtype), + (global_tokens * hidden_dim, hidden_states.dtype), + ]) - hidden_states_offset = probs_offset + probs_bytes - hidden_states_buffer = raw_buffer[hidden_states_offset : hidden_states_offset + hidden_states_bytes] + if result["handle"] is None: + return _NONE + (rm_buf, rm_off), (p_buf, p_off), (hs_buf, hs_off) = result["tensors"] return { - "handle": symm_buffer_mgr.symm_mem_hdl, - "routing_map": routing_map_buffer, - "routing_map_offset": routing_map_offset, - "probs": probs_buffer, - "probs_offset": probs_offset, - "hidden_states": hidden_states_buffer, - "hidden_states_offset": hidden_states_offset, + "handle": result["handle"], + "routing_map": rm_buf, "routing_map_offset": rm_off, + "probs": p_buf, "probs_offset": p_off, + "hidden_states": hs_buf, "hidden_states_offset": hs_off, } def _maybe_allocate_rs_buffer(self, x: torch.Tensor) -> dict: @@ -177,7 +137,7 @@ def token_dispatch(self, hidden_states, probs): Gathers tokens from all EP ranks using AllGather. 
Uses latency-optimized NVLS multimem_all_gather for routing_map, probs and hidden_states - on Hopper+ GPUs with BF16. Falls back to NCCL via superclass otherwise. + on Hopper+ GPUs with BF16. Falls back to NCCL otherwise. """ if self.ep_size == 1: return hidden_states, probs @@ -203,7 +163,7 @@ def token_dispatch(self, hidden_states, probs): hidden_dtype = hidden_states.dtype # Fused NVLS all-gather: single kernel launch + single barrier for all 3 tensors - multimem_all_gather_3( + multimem_all_gather_fused( ag_buffers["routing_map"].view(torch.bfloat16), self.routing_map.view(torch.bfloat16), ag_buffers["routing_map_offset"], diff --git a/megatron/core/utils.py b/megatron/core/utils.py index 636c76f2a84..fc132f0f97e 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -689,6 +689,44 @@ def _allocate(self, numel, dtype) -> torch.Tensor: required_bytes = numel * torch.tensor([], dtype=dtype).element_size() return self.symm_buffer[0:required_bytes].view(dtype).view(numel) + def maybe_get_tensors(self, tensor_specs, alignment=16): + """ + Pack multiple tensors contiguously in the symmetric buffer with alignment. + + Each tensor's starting offset is aligned to `alignment` bytes (default 16 + for 128-bit multimem access). + + Args: + tensor_specs: list of (numel, dtype) tuples. + alignment: byte alignment for each tensor's start offset (default 16). + + Returns: + {"handle": None, "tensors": None} if unavailable or insufficient space. + {"handle": symm_mem_hdl, "tensors": [(raw_byte_view, byte_offset), ...]} + on success, where raw_byte_view is a uint8 slice of the buffer. 
+ """ + _NONE_RESULT = {"handle": None, "tensors": None} + if self.symm_mem_hdl is None: + return _NONE_RESULT + + # Compute aligned byte sizes and running offsets + slices = [] + current_offset = 0 + for numel, dtype in tensor_specs: + nbytes = numel * torch.tensor([], dtype=dtype).element_size() + aligned_nbytes = ((nbytes + alignment - 1) // alignment) * alignment + slices.append((current_offset, nbytes)) + current_offset += aligned_nbytes + + if not self._can_allocate(current_offset, torch.uint8): + return _NONE_RESULT + + tensors = [] + for offset, nbytes in slices: + tensors.append((self.symm_buffer[offset : offset + nbytes], offset)) + + return {"handle": self.symm_mem_hdl, "tensors": tensors} + def maybe_get_tensor(self, tensor_shape, dtype): """ Returns (potentially) a sub-tensor from the self.symm_buffer for the given shape. From 8163d187e32e91169690f3037d8347c88c2c85d9 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 25 Feb 2026 12:29:40 -0800 Subject: [PATCH 52/92] fallback to NCCL as the triton collectives do not work for non 128-bit aligned tensors --- .../moe/token_dispatcher_inference.py | 208 ++---------------- 1 file changed, 21 insertions(+), 187 deletions(-) diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 4a0499b33d8..6290b9e8105 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -1,13 +1,10 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. """ -Inference-optimized AlltoAll Token Dispatcher with GPU-resident metadata. +Inference-optimized AllGather Token Dispatcher with GPU-resident metadata. This implementation keeps tokens_per_expert GPU-resident to enable use of torch._grouped_mm without host synchronization. 
- -Supports latency-optimized NVLS collectives (multimem all-gather/reduce-scatter) -on Hopper+ GPUs with BF16, with automatic fallback to NCCL via superclass methods. """ import torch @@ -19,15 +16,11 @@ ) from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.tensor_parallel import gather_from_sequence_parallel_region -from megatron.core.parallel_state import get_global_symmetric_memory_buffer_ep -from megatron.core.inference.communication.torch_symm_triton import ( - multimem_all_gather_fused, - multimem_reduce_scatter, +from megatron.core.tensor_parallel import ( + gather_from_sequence_parallel_region, + reduce_scatter_to_sequence_parallel_region, ) -import logging - class InferenceAllGatherTokenDispatcher(MoEAllGatherTokenDispatcher): """ Inference-optimized AllGather token dispatcher. @@ -48,15 +41,6 @@ def __init__( config: TransformerConfig, pg_collection: Optional[ProcessGroupCollection] = None, ) -> None: - """ - Initialize the inference AllGather token dispatcher. - - Args: - num_local_experts: Number of experts on this rank. - local_expert_indices: Global indices of experts on this rank. - config: Transformer configuration. - pg_collection: Process group collection for distributed ops. - """ super().__init__( num_local_experts=num_local_experts, local_expert_indices=local_expert_indices, @@ -65,189 +49,39 @@ def __init__( ) self.topk = config.moe_router_topk - # Cache for NVLS eligibility - self._nvls_eligible = None - - def _check_nvls_eligibility(self, x: torch.Tensor) -> bool: - """ - Check if we can use NVLS (latency-optimized) collectives. - Requirements: BF16 dtype, Hopper+ GPU (SM >= 9). 
- """ - is_bf16 = x.dtype == torch.bfloat16 - is_hopper_or_newer = torch.cuda.get_device_properties(x.device).major >= 9 - return is_bf16 and is_hopper_or_newer - - def _maybe_allocate_ag_buffers( - self, - routing_map: torch.Tensor, - probs: torch.Tensor, - hidden_states: torch.Tensor, - ) -> dict: - """ - Allocate a single symmetric memory buffer for all-gather outputs of - routing_map, probs and hidden_states. Returns sliced views for each. - - Returns dict with: - - "handle": symmetric memory handle (or None if unavailable) - - "routing_map" / "routing_map_offset": raw byte view and byte offset - - "probs" / "probs_offset": raw byte view and byte offset - - "hidden_states" / "hidden_states_offset": raw byte view and byte offset - """ - _NONE = { - "handle": None, - "routing_map": None, "routing_map_offset": 0, - "probs": None, "probs_offset": 0, - "hidden_states": None, "hidden_states_offset": 0, - } - - local_tokens = probs.size(0) - global_tokens = local_tokens * self.ep_size - topk = probs.size(-1) - hidden_dim = hidden_states.size(-1) - - result = get_global_symmetric_memory_buffer_ep().maybe_get_tensors([ - (global_tokens * topk, routing_map.dtype), - (global_tokens * topk, probs.dtype), - (global_tokens * hidden_dim, hidden_states.dtype), - ]) - - if result["handle"] is None: - return _NONE - - (rm_buf, rm_off), (p_buf, p_off), (hs_buf, hs_off) = result["tensors"] - return { - "handle": result["handle"], - "routing_map": rm_buf, "routing_map_offset": rm_off, - "probs": p_buf, "probs_offset": p_off, - "hidden_states": hs_buf, "hidden_states_offset": hs_off, - } - - def _maybe_allocate_rs_buffer(self, x: torch.Tensor) -> dict: - """ - Allocate symmetric memory buffer for reduce-scatter input. - Input shape matches x (the unpermuted hidden states). 
- """ - symm_mem_buffer = get_global_symmetric_memory_buffer_ep().maybe_get_tensor( - list(x.size()), dtype=x.dtype - ) - return symm_mem_buffer - def token_dispatch(self, hidden_states, probs): """ - Gathers tokens from all EP ranks using AllGather. - - Uses latency-optimized NVLS multimem_all_gather for routing_map, probs and hidden_states - on Hopper+ GPUs with BF16. Falls back to NCCL otherwise. + Gathers tokens from all EP ranks using NCCL AllGather. """ if self.ep_size == 1: return hidden_states, probs - - # Check NVLS eligibility - nvls_eligible = self._check_nvls_eligibility(hidden_states) - ag_buffers = None - - if nvls_eligible: - ag_buffers = self._maybe_allocate_ag_buffers(self.routing_map, probs, hidden_states) - - can_use_nvls = nvls_eligible and ag_buffers["handle"] is not None - - if can_use_nvls: - # Capture shapes for reshaping after all-gather - # Output shape: [local_tokens * ep_size, dim] - local_tokens = probs.size(0) - global_tokens = local_tokens * self.ep_size - topk = probs.size(1) - hidden_dim = hidden_states.size(1) - routing_map_dtype = self.routing_map.dtype - probs_dtype = probs.dtype - hidden_dtype = hidden_states.dtype - - # Fused NVLS all-gather: single kernel launch + single barrier for all 3 tensors - multimem_all_gather_fused( - ag_buffers["routing_map"].view(torch.bfloat16), - self.routing_map.view(torch.bfloat16), - ag_buffers["routing_map_offset"], - ag_buffers["probs"].view(torch.bfloat16), - probs.view(torch.bfloat16), - ag_buffers["probs_offset"], - ag_buffers["hidden_states"].view(torch.bfloat16), - hidden_states.view(torch.bfloat16), - ag_buffers["hidden_states_offset"], - ag_buffers["handle"], - ) - self.routing_map = ag_buffers["routing_map"].view(routing_map_dtype).view(global_tokens, topk) - probs = ag_buffers["probs"].view(probs_dtype).view(global_tokens, topk) - hidden_states = ag_buffers["hidden_states"].view(hidden_dtype).view(global_tokens, hidden_dim) - else: - # Fallback to NCCL for all tensors - with 
torch.no_grad(): - self.routing_map = gather_from_sequence_parallel_region( - self.routing_map, group=self.tp_ep_group - ) - probs = gather_from_sequence_parallel_region(probs, group=self.tp_ep_group) - hidden_states = gather_from_sequence_parallel_region( - hidden_states, group=self.tp_ep_group - ) + + self.routing_map = gather_from_sequence_parallel_region( + self.routing_map, group=self.tp_ep_group + ) + probs = gather_from_sequence_parallel_region(probs, group=self.tp_ep_group) + hidden_states = gather_from_sequence_parallel_region( + hidden_states, group=self.tp_ep_group + ) return hidden_states, probs - def dispatch_postprocess(self, hidden_states, probs): - """ - No op for flashinfer - """ + """No op for flashinfer.""" raise NotImplementedError - - - def combine_preprocess(self, permuted_expert_outputs): - """ - Reverses token permutation to restore original ordering. - Uses cached permutation and expert_assignments from dispatch_postprocess. - Note: Probability weighting is handled by experts via moe_apply_probs_on_input. - """ - raise NotImplementedError + def combine_preprocess(self, permuted_expert_outputs): + """No op for flashinfer.""" + raise NotImplementedError def token_combine(self, hidden_states): """ - Combines expert outputs using Reduce-Scatter. - - Uses latency-optimized NVLS multimem_reduce_scatter on Hopper+ GPUs with BF16 - when symmetric memory is available. Falls back to NCCL via superclass otherwise. - - Args: - hidden_states: [global_tokens, hidden_dim] tensor to reduce-scatter - - Returns: - [local_tokens, hidden_dim] tensor after reduce-scatter + Combines expert outputs using NCCL Reduce-Scatter. 
""" if self.ep_size == 1: return hidden_states - # Check NVLS eligibility and try to allocate symmetric memory - nvls_eligible = self._check_nvls_eligibility(hidden_states) - rs_buffer = None - - if nvls_eligible: - rs_buffer = self._maybe_allocate_rs_buffer(hidden_states) - - can_use_nvls = nvls_eligible and rs_buffer["handle"] is not None - - if can_use_nvls: - # Copy input to symmetric memory for reduce-scatter - rs_buffer["tensor"].copy_(hidden_states) - - # Allocate output tensor - output_shape = list(hidden_states.size()) - output_shape[0] = hidden_states.size(0) // self.ep_size - output = torch.empty( - output_shape, dtype=hidden_states.dtype, device=hidden_states.device - ) - - # Use latency-optimized NVLS reduce-scatter - multimem_reduce_scatter(output, rs_buffer["tensor"], rs_buffer["handle"]) - return output - else: - # Fallback to NCCL via superclass - return super().token_combine(hidden_states) + return reduce_scatter_to_sequence_parallel_region( + hidden_states, group=self.tp_ep_group + ) From b89600b93c662fbd07575c8a2394dc1023b753d1 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 25 Feb 2026 12:35:09 -0800 Subject: [PATCH 53/92] remove changes related to symm mem comms --- .../torch_symm_triton/__init__.py | 2 +- .../torch_symm_triton/collectives.py | 266 +++++++----------- megatron/core/parallel_state.py | 42 +-- .../core/tensor_parallel/inference_layers.py | 6 +- megatron/core/utils.py | 38 --- 5 files changed, 120 insertions(+), 234 deletions(-) diff --git a/megatron/core/inference/communication/torch_symm_triton/__init__.py b/megatron/core/inference/communication/torch_symm_triton/__init__.py index 586e913541e..ca58663d9ec 100644 --- a/megatron/core/inference/communication/torch_symm_triton/__init__.py +++ b/megatron/core/inference/communication/torch_symm_triton/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-from .collectives import multimem_all_gather, multimem_all_gather_fused, multimem_reduce_scatter +from .collectives import multimem_all_gather, multimem_reduce_scatter from .fused_collectives import fused_multimem_rs_add_norm_ag diff --git a/megatron/core/inference/communication/torch_symm_triton/collectives.py b/megatron/core/inference/communication/torch_symm_triton/collectives.py index eb48dae7d0f..4bc4dbde42b 100644 --- a/megatron/core/inference/communication/torch_symm_triton/collectives.py +++ b/megatron/core/inference/communication/torch_symm_triton/collectives.py @@ -25,41 +25,41 @@ from .multimem_asm import ld_128, st_128 from .utils import get_flat_tid, sync_threads -# ── Triton kernels ───────────────────────────────────────────────────────── @triton.jit -def _ag_phase(local_ptr, multicast_ptr, byte_offset, numel, BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE): +def _multimem_all_gather_kernel( + local_ptr, + multicast_ptr, + signal_pad_ptrs, + numel, + BLOCK_SIZE: tl.constexpr, + NUMEL_PER_THREAD: tl.constexpr, + RANK: tl.constexpr, + WORLD_SIZE: tl.constexpr, +): """ - Core all-gather phase: load from local memory, multicast-store to symmetric buffer. - This is the building block for both single-tensor and fused multi-tensor all-gathers. - - Each thread handles 128-bit (NUMEL_PER_THREAD elements) at a time. - byte_offset locates the tensor within the multicast buffer. - - NOTE: When numel is not divisible by (NUMEL_PER_THREAD * WORLD_SIZE), the kernel - rounds up via cdiv and may read/write up to 15 bytes past the logical tensor end. - This is safe because PyTorch's CUDA caching allocator guarantees a minimum block - size of 512 bytes (kMinBlockSize in CUDACachingAllocator.cpp), so small tensors - always have sufficient backing memory. + Triton kernel to perform multicast all-gather over nvlink using multimem instructions. 
""" + # an all-gather is simply a multicast store operation + # we only need a barrier at the end to ensure visibility of writes + pid = tl.program_id(axis=0) tid = get_flat_tid() - numel_128 = numel // NUMEL_PER_THREAD - numel_per_rank = tl.cdiv(numel_128, WORLD_SIZE) + # From this point on, we pretend each element is 128-bit + numel = numel // NUMEL_PER_THREAD + numel_per_rank = tl.cdiv(numel, WORLD_SIZE) block_start = pid * BLOCK_SIZE while block_start < numel_per_rank: offsets = block_start + tid mask = offsets < numel_per_rank - # byte_offset // 8 -> converts byte offset to uint64 offset - # RANK * numel_per_rank -> start of our rank's segment - # * 2 -> each 128-bit pack is 2 uint64s + # Each pointer points to a 128-bit bit pack + # RANK * numel_per_rank -> brings us to the start of our rank's segment + # offsets -> brings us to the right offset within our rank's segment multicast_ptrs = ( - multicast_ptr.to(tl.pointer_type(tl.uint64)) - + byte_offset // 8 - + (RANK * numel_per_rank + offsets) * 2 + multicast_ptr.to(tl.pointer_type(tl.uint64)) + (RANK * numel_per_rank + offsets) * 2 ) local_ptrs = local_ptr.to(tl.pointer_type(tl.uint64)) + offsets * 2 (x, y, z, w) = ld_128(local_ptrs, mask=mask, multicast_op=False) @@ -67,54 +67,66 @@ def _ag_phase(local_ptr, multicast_ptr, byte_offset, numel, BLOCK_SIZE, NUMEL_PE block_start += tl.num_programs(axis=0) * BLOCK_SIZE - -@triton.jit -def _multimem_all_gather_kernel( - local_ptr, - multicast_ptr, - signal_pad_ptrs, - numel, - byte_offset, - BLOCK_SIZE: tl.constexpr, - NUMEL_PER_THREAD: tl.constexpr, - RANK: tl.constexpr, - WORLD_SIZE: tl.constexpr, -): - """Single-tensor multicast all-gather kernel.""" - _ag_phase(local_ptr, multicast_ptr, byte_offset, numel, - BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) sync_threads() - symm_mem_sync(signal_pad_ptrs, None, RANK, WORLD_SIZE, - hasPreviousMemAccess=True, hasSubsequentMemAccess=True) + symm_mem_sync( + signal_pad_ptrs, + None, + RANK, + WORLD_SIZE, + 
hasPreviousMemAccess=True, + hasSubsequentMemAccess=True, + ) -@triton.jit -def _multimem_all_gather_3_kernel( - local_ptr_0, local_ptr_1, local_ptr_2, - multicast_ptr, - signal_pad_ptrs, - numel_0, byte_offset_0, - numel_1, byte_offset_1, - numel_2, byte_offset_2, - BLOCK_SIZE: tl.constexpr, - NUMEL_PER_THREAD: tl.constexpr, - RANK: tl.constexpr, - WORLD_SIZE: tl.constexpr, -): +def multimem_all_gather( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + symm_mem_hdl: _SymmetricMemory, + **kwargs, +) -> torch.Tensor: """ - Fused 3-tensor multicast all-gather. Processes three tensors in sequence - then synchronizes once, eliminating 2 kernel launches and 2 barriers - compared to three separate multimem_all_gather calls. + Calls a multicast all-gather triton kernel on the given tensor. + Output tensor must be a symmetric memory buffer. + Input tensor can be a regular torch tensor + Arguments: + output_tensor: torch.Tensor - output tensor to be all-gathered into + input_tensor: torch.Tensor - input tensor to be all-gathered from + symm_mem_hdl: _SymmetricMemory - handle to the symmetric memory buffer for output_tensor + Returns: + torch.Tensor - all-gathered tensor, which is output_tensor """ - _ag_phase(local_ptr_0, multicast_ptr, byte_offset_0, numel_0, - BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) - _ag_phase(local_ptr_1, multicast_ptr, byte_offset_1, numel_1, - BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) - _ag_phase(local_ptr_2, multicast_ptr, byte_offset_2, numel_2, - BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) - sync_threads() - symm_mem_sync(signal_pad_ptrs, None, RANK, WORLD_SIZE, - hasPreviousMemAccess=True, hasSubsequentMemAccess=True) + assert HAVE_TRITON, "Triton is required for multimem all-gather." 
+ + config = { + "max_num_blocks": kwargs.get("max_num_blocks", 24), + "num_warps": kwargs.get("num_warps", 32), + "BLOCK_SIZE": kwargs.get("BLOCK_SIZE", 1024), + } + assert input_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." + assert output_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." + numel_per_thread = 128 // (input_tensor.element_size() * 8) + + assert ( + output_tensor.numel() % numel_per_thread == 0 + ), "The number of elements must be 128-bit aligned." + + num_threads = triton.cdiv(output_tensor.numel() // numel_per_thread, symm_mem_hdl.world_size) + num_blocks = min(triton.cdiv(num_threads, config["BLOCK_SIZE"]), config["max_num_blocks"]) + + _multimem_all_gather_kernel[(num_blocks, 1, 1)]( + input_tensor.data_ptr(), + symm_mem_hdl.multicast_ptr, + symm_mem_hdl.signal_pad_ptrs_dev, + numel=output_tensor.numel(), + BLOCK_SIZE=config["BLOCK_SIZE"], + NUMEL_PER_THREAD=numel_per_thread, + RANK=symm_mem_hdl.rank, + WORLD_SIZE=symm_mem_hdl.world_size, + num_warps=config["num_warps"], + ) + + return output_tensor + @triton.jit def _multimem_reduce_scatter_kernel( @@ -162,99 +174,6 @@ def _multimem_reduce_scatter_kernel( block_start += tl.num_programs(axis=0) * BLOCK_SIZE -# ── Python wrappers ───────────────────────────────────────────────────────── - -_DEFAULT_KERNEL_CONFIG = { - "max_num_blocks": 128, - "num_warps": 32, - "BLOCK_SIZE": 1024, -} - - -def _kernel_launch_config(element_size: int, max_numel: int, world_size: int, **kwargs): - """Compute kernel launch config shared by all collective wrappers. - - Args: - element_size: bytes per element (e.g. 2 for bf16). - max_numel: largest tensor numel (determines grid size). - world_size: number of ranks. - - Returns: - (numel_per_thread, num_blocks, config) tuple. 
- """ - config = {k: kwargs.get(k, v) for k, v in _DEFAULT_KERNEL_CONFIG.items()} - numel_per_thread = 128 // (element_size * 8) - num_threads = triton.cdiv(max_numel // numel_per_thread, world_size) - num_blocks = min(triton.cdiv(num_threads, config["BLOCK_SIZE"]), config["max_num_blocks"]) - return numel_per_thread, num_blocks, config - - -def multimem_all_gather( - output_tensor: torch.Tensor, - input_tensor: torch.Tensor, - symm_mem_hdl: _SymmetricMemory, - byte_offset: int = 0, - **kwargs, -) -> torch.Tensor: - """ - Multicast all-gather for a single tensor. - Output tensor must be a symmetric memory buffer. - Input tensor can be a regular torch tensor. - """ - assert HAVE_TRITON, "Triton is required for multimem all-gather." - - numel_per_thread, num_blocks, config = _kernel_launch_config( - input_tensor.element_size(), output_tensor.numel(), symm_mem_hdl.world_size, **kwargs, - ) - _multimem_all_gather_kernel[(num_blocks, 1, 1)]( - input_tensor.data_ptr(), - symm_mem_hdl.multicast_ptr, - symm_mem_hdl.signal_pad_ptrs_dev, - numel=output_tensor.numel(), - byte_offset=byte_offset, - BLOCK_SIZE=config["BLOCK_SIZE"], - NUMEL_PER_THREAD=numel_per_thread, - RANK=symm_mem_hdl.rank, - WORLD_SIZE=symm_mem_hdl.world_size, - num_warps=config["num_warps"], - ) - - return output_tensor - - -def multimem_all_gather_fused( - output_0: torch.Tensor, input_0: torch.Tensor, byte_offset_0: int, - output_1: torch.Tensor, input_1: torch.Tensor, byte_offset_1: int, - output_2: torch.Tensor, input_2: torch.Tensor, byte_offset_2: int, - symm_mem_hdl: _SymmetricMemory, - **kwargs, -) -> None: - """ - Fused 3-tensor multicast all-gather. Equivalent to calling multimem_all_gather - three times but with a single kernel launch and a single barrier. - - All tensors must share the same symmetric memory handle. - """ - assert HAVE_TRITON, "Triton is required for multimem all-gather." 
- - max_numel = max(output_0.numel(), output_1.numel(), output_2.numel()) - numel_per_thread, num_blocks, config = _kernel_launch_config( - input_0.element_size(), max_numel, symm_mem_hdl.world_size, **kwargs, - ) - _multimem_all_gather_3_kernel[(num_blocks, 1, 1)]( - input_0.data_ptr(), input_1.data_ptr(), input_2.data_ptr(), - symm_mem_hdl.multicast_ptr, - symm_mem_hdl.signal_pad_ptrs_dev, - numel_0=output_0.numel(), byte_offset_0=byte_offset_0, - numel_1=output_1.numel(), byte_offset_1=byte_offset_1, - numel_2=output_2.numel(), byte_offset_2=byte_offset_2, - BLOCK_SIZE=config["BLOCK_SIZE"], - NUMEL_PER_THREAD=numel_per_thread, - RANK=symm_mem_hdl.rank, - WORLD_SIZE=symm_mem_hdl.world_size, - num_warps=config["num_warps"], - ) - def multimem_reduce_scatter( output_tensor: torch.Tensor, @@ -263,17 +182,40 @@ def multimem_reduce_scatter( **kwargs, ) -> torch.Tensor: """ - Multicast reduce-scatter for a single tensor. + Calls a multicast reduce-scatter triton kernel on the given tensor. Input tensor must be a symmetric memory buffer. - Output tensor can be a regular torch tensor. + Output tensor can be a regular torch tensor + Arguments: + output_tensor: torch.Tensor - output tensor to be reduce-scattered into + input_tensor: torch.Tensor - input tensor to be reduce-scattered from + symm_mem_hdl: _SymmetricMemory - handle to the symmetric memory buffer for input_tensor + **kwargs: Additional keyword arguments for kernel configuration: + max_num_blocks (int, optional): The maximum number of blocks to launch. + num_warps (int, optional): The number of warps per block. + BLOCK_SIZE (int, optional): The BLOCK_SIZE parameter for the kernel. + Returns: + torch.Tensor - reduce-scattered tensor, which is output_tensor """ + assert HAVE_TRITON, "Triton is required for multimem reduce-scatter." 
+ + config = { + "max_num_blocks": kwargs.get("max_num_blocks", 24), + "num_warps": kwargs.get("num_warps", 32), + "BLOCK_SIZE": kwargs.get("BLOCK_SIZE", 1024), + } + assert input_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." assert output_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." + numel_per_thread = 128 // (output_tensor.element_size() * 8) + + assert ( + input_tensor.numel() % numel_per_thread == 0 + ), "The number of elements must be 128-bit aligned." + + num_threads = triton.cdiv(input_tensor.numel() // numel_per_thread, symm_mem_hdl.world_size) + num_blocks = min(triton.cdiv(num_threads, config["BLOCK_SIZE"]), config["max_num_blocks"]) - numel_per_thread, num_blocks, config = _kernel_launch_config( - output_tensor.element_size(), input_tensor.numel(), symm_mem_hdl.world_size, **kwargs, - ) _multimem_reduce_scatter_kernel[(num_blocks, 1, 1)]( output_tensor.data_ptr(), symm_mem_hdl.multicast_ptr, diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index b571a357fad..087cbe7e152 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -140,9 +140,8 @@ # Memory buffers to avoid dynamic memory allocation _GLOBAL_MEMORY_BUFFER = None -# Global symmetric memory buffers for inference -_GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = None -_GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = None +# Global symmetric memory buffer for inference +_GLOBAL_SYMMETRIC_MEMORY_BUFFER = None # List of all process groups # Used for updating the timeout for all process groups @@ -2017,20 +2016,14 @@ def _set_global_memory_buffer(): def _set_global_symmetric_memory_buffer(): """Initialize global buffer.""" - global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP, _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP - assert _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP is None, "global symmetric memory buffer for TP is already initialized" - assert _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP is None, "global symmetric memory buffer for EP is already 
initialized" + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER + assert _GLOBAL_SYMMETRIC_MEMORY_BUFFER is None, "global memory buffer is already initialized" - _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = GlobalSymmetricMemoryBuffer( + _GLOBAL_SYMMETRIC_MEMORY_BUFFER = GlobalSymmetricMemoryBuffer( size_in_mb=256, # todo: set from an argument? process_group=get_tensor_model_parallel_group(), ) - _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = GlobalSymmetricMemoryBuffer( - size_in_mb=256, # todo: set from an argument? - process_group=get_expert_model_parallel_group(), - ) - def get_global_memory_buffer(): """Return the global GlobalMemoryBuffer object""" @@ -2038,19 +2031,12 @@ def get_global_memory_buffer(): return _GLOBAL_MEMORY_BUFFER -def get_global_symmetric_memory_buffer_tp(): - """Return the global GlobalSymmetricMemoryBuffer object""" - assert ( - _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP is not None - ), "global symmetric memory buffer is not initialized" - return _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP - -def get_global_symmetric_memory_buffer_ep(): +def get_global_symmetric_memory_buffer(): """Return the global GlobalSymmetricMemoryBuffer object""" assert ( - _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP is not None + _GLOBAL_SYMMETRIC_MEMORY_BUFFER is not None ), "global symmetric memory buffer is not initialized" - return _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP + return _GLOBAL_SYMMETRIC_MEMORY_BUFFER def destroy_global_memory_buffer(): @@ -2061,9 +2047,8 @@ def destroy_global_memory_buffer(): def destroy_global_symmetric_memory_buffer(): """Sets the global symmetric memory buffer to None""" - global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP, _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP - _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = None - _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = None + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER + _GLOBAL_SYMMETRIC_MEMORY_BUFFER = None def get_all_ranks(): @@ -2144,11 +2129,8 @@ def destroy_model_parallel(): global _GLOBAL_MEMORY_BUFFER _GLOBAL_MEMORY_BUFFER = None - global 
_GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP - _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = None - - global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP - _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = None + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER + _GLOBAL_SYMMETRIC_MEMORY_BUFFER = None global _DATA_PARALLEL_GROUP_GLOO if ( diff --git a/megatron/core/tensor_parallel/inference_layers.py b/megatron/core/tensor_parallel/inference_layers.py index fcf882b6818..0addc64a65f 100644 --- a/megatron/core/tensor_parallel/inference_layers.py +++ b/megatron/core/tensor_parallel/inference_layers.py @@ -16,7 +16,7 @@ from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor from megatron.core.inference.quantization.utils import mm_mxfp8 from megatron.core.model_parallel_config import ModelParallelConfig -from megatron.core.parallel_state import get_global_symmetric_memory_buffer_tp +from megatron.core.parallel_state import get_global_symmetric_memory_buffer from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import get_tensor_model_parallel_group_if_none @@ -120,7 +120,7 @@ def _maybe_allocate_symmetric_buffer(self, x: torch.Tensor): """ symm_mem_buffer_dims = list(x.size()) symm_mem_buffer_dims[0] *= self.tp_size - symm_mem_buffer = get_global_symmetric_memory_buffer_tp().maybe_get_tensor( + symm_mem_buffer = get_global_symmetric_memory_buffer().maybe_get_tensor( symm_mem_buffer_dims, dtype=x.dtype ) return symm_mem_buffer @@ -245,7 +245,7 @@ def _matmul_reduce_scatter(self, x, residual=None): # Remove batch dimension for FlashInfer mxfp8 del symm_mem_buffer_dims[1] symm_mem_buffer_dims[-1] = self.weight.size(0) - symm_mem_buffer = get_global_symmetric_memory_buffer_tp().maybe_get_tensor( + symm_mem_buffer = get_global_symmetric_memory_buffer().maybe_get_tensor( symm_mem_buffer_dims, dtype=x.dtype ) has_enough_symmetric_memory = symm_mem_buffer["handle"] is not None diff --git a/megatron/core/utils.py b/megatron/core/utils.py index 
ed31b77ba04..c0533bd1fab 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -728,44 +728,6 @@ def _allocate(self, numel, dtype) -> torch.Tensor: required_bytes = numel * torch.tensor([], dtype=dtype).element_size() return self.symm_buffer[0:required_bytes].view(dtype).view(numel) - def maybe_get_tensors(self, tensor_specs, alignment=16): - """ - Pack multiple tensors contiguously in the symmetric buffer with alignment. - - Each tensor's starting offset is aligned to `alignment` bytes (default 16 - for 128-bit multimem access). - - Args: - tensor_specs: list of (numel, dtype) tuples. - alignment: byte alignment for each tensor's start offset (default 16). - - Returns: - {"handle": None, "tensors": None} if unavailable or insufficient space. - {"handle": symm_mem_hdl, "tensors": [(raw_byte_view, byte_offset), ...]} - on success, where raw_byte_view is a uint8 slice of the buffer. - """ - _NONE_RESULT = {"handle": None, "tensors": None} - if self.symm_mem_hdl is None: - return _NONE_RESULT - - # Compute aligned byte sizes and running offsets - slices = [] - current_offset = 0 - for numel, dtype in tensor_specs: - nbytes = numel * torch.tensor([], dtype=dtype).element_size() - aligned_nbytes = ((nbytes + alignment - 1) // alignment) * alignment - slices.append((current_offset, nbytes)) - current_offset += aligned_nbytes - - if not self._can_allocate(current_offset, torch.uint8): - return _NONE_RESULT - - tensors = [] - for offset, nbytes in slices: - tensors.append((self.symm_buffer[offset : offset + nbytes], offset)) - - return {"handle": self.symm_mem_hdl, "tensors": tensors} - def maybe_get_tensor(self, tensor_shape, dtype): """ Returns (potentially) a sub-tensor from the self.symm_buffer for the given shape. 
From ba396b1ccb672b1c2bf12822d8bc4f3a233c1b2b Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 25 Feb 2026 12:48:51 -0800 Subject: [PATCH 54/92] refactor --- .../text_generation_controller.py | 4 ++-- megatron/core/inference/utils.py | 4 ++-- megatron/core/transformer/moe/moe_layer.py | 14 +++++++------- megatron/core/transformer/moe/router.py | 8 ++++---- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 5d8e00d67bf..6e36dac1f2a 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -25,7 +25,7 @@ AbstractModelInferenceWrapper, ) from megatron.core.inference.sampling_params import SamplingParams -from megatron.core.inference.utils import get_attention_mask, set_decode_expert_padding, set_is_cuda_graphed_iteration_for_ep_inference +from megatron.core.inference.utils import get_attention_mask, set_decode_expert_padding, set_is_inference_cuda_graphed_iteration_for_ep_inference from megatron.core.models.multimodal.llava_model import LLaVAModel from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer.enums import CudaGraphScope @@ -544,7 +544,7 @@ def _dynamic_step_context_init( set_decode_expert_padding(unwrapped_model, False) if is_inference_optimized and model_config.expert_model_parallel_size > 1: - set_is_cuda_graphed_iteration_for_ep_inference(unwrapped_model, context.using_cuda_graph_this_step()) + set_is_inference_cuda_graphed_iteration_for_ep_inference(unwrapped_model, context.using_cuda_graph_this_step()) # initialize symmetric memory if needed if model_config.transformer_impl == "inference_optimized": diff --git a/megatron/core/inference/utils.py 
b/megatron/core/inference/utils.py index 5a2939decc7..42ac9577868 100644 --- a/megatron/core/inference/utils.py +++ b/megatron/core/inference/utils.py @@ -131,7 +131,7 @@ def set_decode_expert_padding(model, set_to: bool = False, capacity_factor: int router.config.moe_expert_capacity_factor = capacity_factor router.config.moe_pad_expert_input_to_capacity = bool(set_to) -def set_is_cuda_graphed_iteration_for_ep_inference(model, set_to: bool): +def set_is_inference_cuda_graphed_iteration_for_ep_inference(model, set_to: bool): """ Toggle CUDA graph compatibility for expert parallel inference. This sets a boolean flag in all MoELayers to indicate whether @@ -145,7 +145,7 @@ def set_is_cuda_graphed_iteration_for_ep_inference(model, set_to: bool): _init_moe_expert_cache(model) for moe_layer in moe_layer_cache: - moe_layer.set_is_cuda_graphed_iteration(set_to) + moe_layer.set_is_inference_cuda_graphed_iteration(set_to) def tensor_swap(x, src_idxs, dst_idxs): """ diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index e892db9c148..57c224bdf99 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -270,7 +270,7 @@ def _setup_inference_mode(self, pg_collection): f"Inference-optimized MoE requires 'alltoall' dispatcher, " f"got '{self.config.moe_token_dispatcher_type}'" ) - self.is_cuda_graphed_iteration = False + self.is_inference_cuda_graphed_iteration = False self._inference_token_dispatcher = InferenceAllGatherTokenDispatcher( self.num_local_experts, self.local_expert_indices, @@ -278,11 +278,11 @@ def _setup_inference_mode(self, pg_collection): pg_collection=pg_collection, ) - def set_is_cuda_graphed_iteration(self, set_to: bool): + def set_is_inference_cuda_graphed_iteration(self, set_to: bool): """Toggle CUDA-graphed iteration mode on this layer and its router.""" - self.is_cuda_graphed_iteration = set_to - if hasattr(self.router, 'set_is_cuda_graphed_iteration'): - 
self.router.set_is_cuda_graphed_iteration(set_to) + self.is_inference_cuda_graphed_iteration = set_to + if hasattr(self.router, 'set_is_inference_cuda_graphed_iteration'): + self.router.set_is_inference_cuda_graphed_iteration(set_to) def _activate_inference_token_dispatcher(self): """Swap in the inference-optimized token dispatcher.""" @@ -372,7 +372,7 @@ def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tenso for each expert. It then passes the tokens through the local experts. The output from the experts is preprocessed for the combine step. """ - if not self.training and self.is_cuda_graphed_iteration: + if not self.training and self.is_inference_cuda_graphed_iteration: return self._fused_experts_compute(hidden_states, probs) dispatched_input, tokens_per_expert, permuted_probs = ( @@ -483,7 +483,7 @@ def forward( # Swap in inference-optimized dispatcher for CUDA-graphed iterations _use_inference_dispatcher = ( not self.training - and self.is_cuda_graphed_iteration + and self.is_inference_cuda_graphed_iteration and self._inference_token_dispatcher is not None ) if _use_inference_dispatcher: diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index ac751988e4d..26c730f8bf9 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -744,10 +744,10 @@ def __init__( super().__init__(config=config, pg_collection=pg_collection) - self.is_cuda_graphed_iteration = False + self.is_inference_cuda_graphed_iteration = False - def set_is_cuda_graphed_iteration(self, set_to: bool): - self.is_cuda_graphed_iteration = set_to + def set_is_inference_cuda_graphed_iteration(self, set_to: bool): + self.is_inference_cuda_graphed_iteration = set_to @torch.compile() def _forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): @@ -785,7 +785,7 @@ def forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = No - top_indices: Selected 
expert indices [num_tokens, topk] """ - if self.training or not self.is_cuda_graphed_iteration: + if self.training or not self.is_inference_cuda_graphed_iteration: return super().forward(input, padding_mask) return self._forward(input, padding_mask) \ No newline at end of file From 462ed8a5895eb4378e52b40e4099979d65feec69 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 25 Feb 2026 13:02:46 -0800 Subject: [PATCH 55/92] more cleanup and add warnings if flashinfer-jit and cubin are not installed --- megatron/core/transformer/moe/moe_layer.py | 36 +++++++++++++++++----- megatron/core/transformer/moe/router.py | 1 - 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 57c224bdf99..ed783a7a5cc 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -2,6 +2,7 @@ from __future__ import annotations +import warnings from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Optional, Protocol, Union @@ -28,6 +29,25 @@ from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.typed_torch import apply_module from megatron.core.utils import internal_api +from megatron.core.transformer.moe.token_dispatcher_inference import ( + InferenceAllGatherTokenDispatcher, +) +try: + import flashinfer.fused_moe as fused_moe + from flashinfer.fused_moe.core import ActivationType + HAVE_FLASHINFER = True +except ImportError: + HAVE_FLASHINFER = False + +if HAVE_FLASHINFER: + try: + import flashinfer_cubin + import flashinfer_jit_cache + HAVE_FLASHINFER_CUBIN_AND_JIT_CACHE = True + except ImportError: + HAVE_FLASHINFER_CUBIN_AND_JIT_CACHE = False + +from megatron.core.activations import squared_relu try: import transformer_engine as te # pylint: disable=unused-import @@ -248,8 +268,15 @@ def __init__( # Inference-optimized mode setup if config.transformer_impl == 
"inference_optimized": + assert HAVE_FLASHINFER, "flashinfer-python is required for inference-optimized MoE implementation." + if not HAVE_FLASHINFER_CUBIN_AND_JIT_CACHE: + warnings.warn( + "flashinfer-cubin and/or flashinfer-jit-cache not found. " + "The FlashInfer cutlass kernel will be JIT compiled, which may take a long time." + ) self._setup_inference_mode(pg_collection) + # Cudagraph tensor store for resuming the forward pass from the end of the cudagraph. self.cudagraph_tensor_store = MoECudaGraphTensorStore() self.fwd_execution_map = ["route", "expert_compute", "postprocess"] @@ -262,10 +289,7 @@ def _setup_inference_mode(self, pg_collection): Creates an InferenceAllGatherTokenDispatcher alongside the standard dispatcher, which is swapped in during CUDA-graphed forward passes. """ - from megatron.core.transformer.moe.token_dispatcher_inference import ( - InferenceAllGatherTokenDispatcher, - ) - + assert self.config.moe_token_dispatcher_type == "alltoall", ( f"Inference-optimized MoE requires 'alltoall' dispatcher, " f"got '{self.config.moe_token_dispatcher_type}'" @@ -386,10 +410,6 @@ def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tenso def _fused_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tensor): """FlashInfer fused MoE kernel for CUDA-graphed inference iterations.""" - import flashinfer.fused_moe as fused_moe - from flashinfer.fused_moe.core import ActivationType - - from megatron.core.activations import squared_relu assert not self.config.gated_linear_unit, ( "FlashInfer MoE kernel currently only supports non-gated activations. 
" diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index 26c730f8bf9..d21f895f229 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -23,7 +23,6 @@ ) from megatron.core.transformer.moe.router_replay import RouterReplay from megatron.core.transformer.transformer_config import TransformerConfig -import logging class Router(ABC, MegatronModule): """Base Router class""" From e3311a0a2be1cbd88bc8571bc649b3b700159452 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 25 Feb 2026 14:26:37 -0800 Subject: [PATCH 56/92] more refactor --- megatron/core/transformer/moe/experts.py | 81 ++++++++++++++----- megatron/core/transformer/moe/moe_layer.py | 55 ++----------- .../moe/token_dispatcher_inference.py | 10 +-- 3 files changed, 73 insertions(+), 73 deletions(-) diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index e6cf7f3fb6c..461bd8e01ea 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -67,6 +67,14 @@ HAVE_TE = False +try: + import flashinfer.fused_moe as fused_moe + from flashinfer.fused_moe.core import ActivationType + + HAVE_FLASHINFER = True +except ImportError: + HAVE_FLASHINFER = False + logger = logging.getLogger(__name__) @@ -913,11 +921,13 @@ def backward_dw(self): class InferenceGroupedMLP(TEGroupedMLP): - """Inference-optimized GroupedMLP using torch._grouped_mm with GPU-resident offsets. + """Inference-optimized GroupedMLP with GPU-resident offsets. Inherits from TEGroupedMLP to reuse weight initialization and checkpoint compatibility. - Overrides forward() to use torch._grouped_mm instead of TE's grouped linear, - keeping tokens_per_expert on GPU to avoid host synchronization. 
+ Supports three forward paths: + - Training: delegates to parent TEGroupedMLP + - Inference + CUDA graphed: FlashInfer cutlass_fused_moe (fused permute + GEMM) + - Inference + eager: torch._grouped_mm with GPU-resident cumsum offsets """ def __init__( @@ -940,8 +950,11 @@ def __init__( # torch._grouped_mm expects shape [num_experts, out_features, in_features] self._build_concatenated_weights() - # Register hook to rebuild concatenated weights after load_state_dict - # self._register_load_state_dict_post_hook(self._rebuild_weights_hook) + self.is_inference_cuda_graphed_iteration = False + + def set_is_inference_cuda_graphed_iteration(self, set_to: bool): + """Toggle CUDA-graphed iteration mode.""" + self.is_inference_cuda_graphed_iteration = set_to def _build_concatenated_weights(self): """Create big contiguous weight tensors with per-expert views for checkpoint compatibility. @@ -992,22 +1005,24 @@ def _build_concatenated_weights(self): self.register_buffer('_fc1_weight', _fc1_weight, persistent=False) self.register_buffer('_fc2_weight', _fc2_weight, persistent=False) - def forward( - self, - permuted_local_hidden_states: torch.Tensor, - tokens_per_expert: torch.Tensor, - permuted_probs: torch.Tensor, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Forward pass using torch._grouped_mm with GPU-resident offsets. - - Args: - permuted_local_hidden_states: [total_tokens, hidden_size] input tensor - tokens_per_expert: [num_local_experts] GPU tensor with token counts per expert - permuted_probs: [total_tokens] routing probabilities - - Returns: - Tuple of (output, None) for interface compatibility - """ + def _flashinfer_forward(self, hidden_states, routing_map, probs): + """FlashInfer fused MoE kernel for CUDA-graphed inference iterations.""" + assert HAVE_FLASHINFER, "flashinfer-python is required for FlashInfer forward path." 
+ output = fused_moe.cutlass_fused_moe( + hidden_states, + routing_map.to(torch.int), + probs.float(), + self._fc1_weight, + self._fc2_weight, + hidden_states.dtype, + quant_scales=None, + activation_type=ActivationType.Relu2, + ep_size=self.ep_group.size(), + ep_rank=self.ep_group.rank(), + )[0] + return output, None + + def _torch_grouped_mm_forward(self, permuted_local_hidden_states, tokens_per_expert, permuted_probs): permuted_probs = permuted_probs.unsqueeze(-1) #assert tokens_per_expert.is_cuda, "tokens_per_expert must be on GPU" if not tokens_per_expert.is_cuda: @@ -1104,6 +1119,30 @@ def glu(x): return fc2_output, None + def forward( + self, + permuted_local_hidden_states: torch.Tensor, + tokens_per_expert: torch.Tensor, + permuted_probs: torch.Tensor, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Forward pass with three modes: + - Training: delegates to parent TEGroupedMLP + - Inference + CUDA graphed: FlashInfer cutlass_fused_moe + - Inference + eager: torch._grouped_mm with GPU-resident offsets + """ + if self.training: + return super().forward(permuted_local_hidden_states, tokens_per_expert, permuted_probs) + + elif self.is_inference_cuda_graphed_iteration: + return self._flashinfer_forward( + permuted_local_hidden_states, tokens_per_expert, permuted_probs + ) + + else: + return self._torch_grouped_mm_forward(permuted_local_hidden_states, tokens_per_expert, permuted_probs) + + + class SequentialMLP(MegatronModule): """An implementation of the Experts layer using a sequence of MLP layers. 
diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index ed783a7a5cc..81681ee3963 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -33,22 +33,19 @@ InferenceAllGatherTokenDispatcher, ) try: - import flashinfer.fused_moe as fused_moe - from flashinfer.fused_moe.core import ActivationType + import flashinfer HAVE_FLASHINFER = True except ImportError: HAVE_FLASHINFER = False if HAVE_FLASHINFER: - try: - import flashinfer_cubin - import flashinfer_jit_cache + try: + import flashinfer_cubin + import flashinfer_jit_cache HAVE_FLASHINFER_CUBIN_AND_JIT_CACHE = True except ImportError: HAVE_FLASHINFER_CUBIN_AND_JIT_CACHE = False -from megatron.core.activations import squared_relu - try: import transformer_engine as te # pylint: disable=unused-import @@ -303,10 +300,10 @@ def _setup_inference_mode(self, pg_collection): ) def set_is_inference_cuda_graphed_iteration(self, set_to: bool): - """Toggle CUDA-graphed iteration mode on this layer and its router.""" + """Toggle CUDA-graphed iteration mode on this layer, its router, and its experts.""" self.is_inference_cuda_graphed_iteration = set_to - if hasattr(self.router, 'set_is_inference_cuda_graphed_iteration'): - self.router.set_is_inference_cuda_graphed_iteration(set_to) + self.router.set_is_inference_cuda_graphed_iteration(set_to) + self.experts.set_is_inference_cuda_graphed_iteration(set_to) def _activate_inference_token_dispatcher(self): """Swap in the inference-optimized token dispatcher.""" @@ -396,9 +393,6 @@ def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tenso for each expert. It then passes the tokens through the local experts. The output from the experts is preprocessed for the combine step. 
""" - if not self.training and self.is_inference_cuda_graphed_iteration: - return self._fused_experts_compute(hidden_states, probs) - dispatched_input, tokens_per_expert, permuted_probs = ( self.token_dispatcher.dispatch_postprocess(hidden_states, probs) ) @@ -408,39 +402,6 @@ def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tenso return output, mlp_bias - def _fused_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tensor): - """FlashInfer fused MoE kernel for CUDA-graphed inference iterations.""" - - assert not self.config.gated_linear_unit, ( - "FlashInfer MoE kernel currently only supports non-gated activations. " - f"Got gated_linear_unit={self.config.gated_linear_unit}" - ) - assert self.config.activation_func == squared_relu, ( - "FlashInfer MoE kernel currently only supports squared_relu activation. " - f"Got activation_func={self.config.activation_func}" - ) - - w1 = self.experts._fc1_weight - w2 = self.experts._fc2_weight - selected_experts = self.token_dispatcher.routing_map - ep_size = utils.get_pg_size(self.ep_group) - ep_rank = utils.get_pg_rank(self.ep_group) - - output = fused_moe.cutlass_fused_moe( - hidden_states, - selected_experts.to(torch.int), - probs.float(), - w1, - w2, - hidden_states.dtype, - quant_scales=None, - activation_type=ActivationType.Relu2, - ep_size=ep_size, - ep_rank=ep_rank, - )[0] - - return output, None - def combine(self, output: torch.Tensor): """Combines expert outputs via communication and adds shared expert output. 
@@ -500,7 +461,7 @@ def forward( if padding_mask is not None: padding_mask = padding_mask.transpose(0, 1).bool() - # Swap in inference-optimized dispatcher for CUDA-graphed iterations + # Swap in inference-optimized dispatcher for CUDA-graphed inference iterations _use_inference_dispatcher = ( not self.training and self.is_inference_cuda_graphed_iteration diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 6290b9e8105..24c4a672327 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -67,12 +67,12 @@ def token_dispatch(self, hidden_states, probs): return hidden_states, probs def dispatch_postprocess(self, hidden_states, probs): - """No op for flashinfer.""" - raise NotImplementedError + """Pass-through: returns unpermuted inputs and routing_map for InferenceGroupedMLP.""" + return hidden_states, self.routing_map, probs - def combine_preprocess(self, permuted_expert_outputs): - """No op for flashinfer.""" - raise NotImplementedError + def combine_preprocess(self, expert_output): + """Pass-through: InferenceGroupedMLP already produces unpermuted output.""" + return expert_output def token_combine(self, hidden_states): """ From afb807b1352597c61b3cd98fb37075f45c875ac3 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 25 Feb 2026 17:37:26 -0800 Subject: [PATCH 57/92] make qwen3 work without CGs --- megatron/core/models/gpt/gpt_layer_specs.py | 14 +------------- megatron/core/transformer/moe/moe_layer.py | 9 ++++++--- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index 5fc9fc082b0..9e711a92fda 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -149,21 +149,9 @@ def get_gpt_layer_with_inference_submodules( L2Norm if qk_l2_norm 
else (qk_norm if qk_layernorm else IdentityOp) ), ), - self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, - mlp=mlp, - mlp_bda=get_bias_dropout_add, - sharded_state_dict_keys_map={ - "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight", - "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias", - "mlp.1.basic_ops.0.weight": "mlp.linear_fc1.weight", - "mlp.1.basic_ops.1.bias": "mlp.linear_fc1.bias", - "mlp.3.basic_ops.0.weight": "mlp.linear_fc2.weight", - "mlp.3.basic_ops.1.bias": "mlp.linear_fc2.bias", - }, ), self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=IdentityOp, + pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, mlp=mlp, mlp_bda=get_bias_dropout_add, sharded_state_dict_keys_map={ diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 81681ee3963..42f06ef9850 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -302,8 +302,10 @@ def _setup_inference_mode(self, pg_collection): def set_is_inference_cuda_graphed_iteration(self, set_to: bool): """Toggle CUDA-graphed iteration mode on this layer, its router, and its experts.""" self.is_inference_cuda_graphed_iteration = set_to - self.router.set_is_inference_cuda_graphed_iteration(set_to) - self.experts.set_is_inference_cuda_graphed_iteration(set_to) + if hasattr(self.router, "set_is_inference_cuda_graphed_iteration"): + self.router.set_is_inference_cuda_graphed_iteration(set_to) + if hasattr(self.experts, "set_is_inference_cuda_graphed_iteration"): + self.experts.set_is_inference_cuda_graphed_iteration(set_to) def _activate_inference_token_dispatcher(self): """Swap in the inference-optimized token dispatcher.""" @@ -463,7 +465,8 @@ def forward( # Swap in inference-optimized dispatcher for CUDA-graphed inference iterations _use_inference_dispatcher = ( - not self.training + self.config.transformer_impl == "inference_optimized" + and 
not self.training and self.is_inference_cuda_graphed_iteration and self._inference_token_dispatcher is not None ) From 51c383d6a9e355c56002994b19e750d8fca1e390 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 25 Feb 2026 17:38:20 -0800 Subject: [PATCH 58/92] remove comment --- gpt_builders.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gpt_builders.py b/gpt_builders.py index bb273211080..06101d300b2 100644 --- a/gpt_builders.py +++ b/gpt_builders.py @@ -53,7 +53,6 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_ ) ) elif args.num_experts: - #assert not (config.transformer_impl == "inference_optimized") # Define the decoder block spec transformer_layer_spec = get_gpt_decoder_block_spec( config, From c761a0d4625a4e7414cc0e3d9ec0a4c9e2299309 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 25 Feb 2026 17:46:39 -0800 Subject: [PATCH 59/92] refactor --- .../core/transformer/transformer_config.py | 33 ++++++------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 356db9de26c..f1bd5758202 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1088,31 +1088,18 @@ def __post_init__(self): self.expert_tensor_parallel_size > 1 ): raise ValueError( - "Inference-optimized MoE layers currently only support data parallelism " - "(expert_model_parallel_size=1 and expert_tensor_parallel_size=1). " - "Multi-GPU support is planned for future work." + "Inference-optimized MoE layers does not support expert tensor parallelism." ) - if self.transformer_impl == "inference_optimized" and ( - self.moe_expert_capacity_factor is not None - or self.moe_router_padding_for_quantization - ): - raise ValueError( - "Inference-optimized MoE layers only support dropless MoE " - "(moe_expert_capacity_factor=None and moe_router_padding_for_quantization=False). 
" - ) - - # if self.transformer_impl == "inference_optimized" and self.num_moe_experts is not None: - # if not self.moe_permute_fusion: - # raise ValueError( - # "Inference-optimized MoE layers require moe_permute_fusion=True " - # "to use TE fused kernels that support GPU-resident metadata." - # ) - # # if not self.moe_router_fusion: - # # raise ValueError( - # # "Inference-optimized MoE layers require moe_router_fusion=True " - # # "to use TE fused router kernels." - # # ) + if self.transformer_impl == "inference_optimized": + if self.moe_expert_capacity_factor is not None: + raise ValueError( + "Inference-optimized MoE layers only support dropless MoE " + ) + if self.moe_router_padding_for_quantization: + raise ValueError( + "Inference-optimized MoE layers do not support padded routing map for quantization." + ) if self.num_moe_experts is not None and self.num_moe_experts <= 0: raise ValueError("num_moe_experts must be non-negative.") From d9f17121407f7fe5fa44cba3a0fa9e4a072717dd Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 25 Feb 2026 23:58:13 -0800 Subject: [PATCH 60/92] Revert "fallback to NCCL as the triton collectives do not work for non 128-bit aligned tensors" This reverts commit 8163d187e32e91169690f3037d8347c88c2c85d9. --- .../moe/token_dispatcher_inference.py | 192 ++++++++++++++++-- 1 file changed, 175 insertions(+), 17 deletions(-) diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 24c4a672327..aace1827be8 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -1,10 +1,13 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. """ -Inference-optimized AllGather Token Dispatcher with GPU-resident metadata. +Inference-optimized AlltoAll Token Dispatcher with GPU-resident metadata. 
This implementation keeps tokens_per_expert GPU-resident to enable use of torch._grouped_mm without host synchronization. + +Supports latency-optimized NVLS collectives (multimem all-gather/reduce-scatter) +on Hopper+ GPUs with BF16, with automatic fallback to NCCL via superclass methods. """ import torch @@ -16,11 +19,15 @@ ) from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.tensor_parallel import ( - gather_from_sequence_parallel_region, - reduce_scatter_to_sequence_parallel_region, +from megatron.core.tensor_parallel import gather_from_sequence_parallel_region +from megatron.core.parallel_state import get_global_symmetric_memory_buffer_ep +from megatron.core.inference.communication.torch_symm_triton import ( + multimem_all_gather_fused, + multimem_reduce_scatter, ) +import logging + class InferenceAllGatherTokenDispatcher(MoEAllGatherTokenDispatcher): """ Inference-optimized AllGather token dispatcher. @@ -41,6 +48,15 @@ def __init__( config: TransformerConfig, pg_collection: Optional[ProcessGroupCollection] = None, ) -> None: + """ + Initialize the inference AllGather token dispatcher. + + Args: + num_local_experts: Number of experts on this rank. + local_expert_indices: Global indices of experts on this rank. + config: Transformer configuration. + pg_collection: Process group collection for distributed ops. + """ super().__init__( num_local_experts=num_local_experts, local_expert_indices=local_expert_indices, @@ -49,23 +65,133 @@ def __init__( ) self.topk = config.moe_router_topk + # Cache for NVLS eligibility + self._nvls_eligible = None + + def _check_nvls_eligibility(self, x: torch.Tensor) -> bool: + """ + Check if we can use NVLS (latency-optimized) collectives. + Requirements: BF16 dtype, Hopper+ GPU (SM >= 9). 
+ """ + is_bf16 = x.dtype == torch.bfloat16 + is_hopper_or_newer = torch.cuda.get_device_properties(x.device).major >= 9 + return is_bf16 and is_hopper_or_newer + + def _maybe_allocate_ag_buffers( + self, + routing_map: torch.Tensor, + probs: torch.Tensor, + hidden_states: torch.Tensor, + ) -> dict: + """ + Allocate a single symmetric memory buffer for all-gather outputs of + routing_map, probs and hidden_states. Returns sliced views for each. + + Returns dict with: + - "handle": symmetric memory handle (or None if unavailable) + - "routing_map" / "routing_map_offset": raw byte view and byte offset + - "probs" / "probs_offset": raw byte view and byte offset + - "hidden_states" / "hidden_states_offset": raw byte view and byte offset + """ + _NONE = { + "handle": None, + "routing_map": None, "routing_map_offset": 0, + "probs": None, "probs_offset": 0, + "hidden_states": None, "hidden_states_offset": 0, + } + + local_tokens = probs.size(0) + global_tokens = local_tokens * self.ep_size + topk = probs.size(-1) + hidden_dim = hidden_states.size(-1) + + result = get_global_symmetric_memory_buffer_ep().maybe_get_tensors([ + (global_tokens * topk, routing_map.dtype), + (global_tokens * topk, probs.dtype), + (global_tokens * hidden_dim, hidden_states.dtype), + ]) + + if result["handle"] is None: + return _NONE + + (rm_buf, rm_off), (p_buf, p_off), (hs_buf, hs_off) = result["tensors"] + return { + "handle": result["handle"], + "routing_map": rm_buf, "routing_map_offset": rm_off, + "probs": p_buf, "probs_offset": p_off, + "hidden_states": hs_buf, "hidden_states_offset": hs_off, + } + + def _maybe_allocate_rs_buffer(self, x: torch.Tensor) -> dict: + """ + Allocate symmetric memory buffer for reduce-scatter input. + Input shape matches x (the unpermuted hidden states). 
+ """ + symm_mem_buffer = get_global_symmetric_memory_buffer_ep().maybe_get_tensor( + list(x.size()), dtype=x.dtype + ) + return symm_mem_buffer + def token_dispatch(self, hidden_states, probs): """ - Gathers tokens from all EP ranks using NCCL AllGather. + Gathers tokens from all EP ranks using AllGather. + + Uses latency-optimized NVLS multimem_all_gather for routing_map, probs and hidden_states + on Hopper+ GPUs with BF16. Falls back to NCCL otherwise. """ if self.ep_size == 1: return hidden_states, probs - - self.routing_map = gather_from_sequence_parallel_region( - self.routing_map, group=self.tp_ep_group - ) - probs = gather_from_sequence_parallel_region(probs, group=self.tp_ep_group) - hidden_states = gather_from_sequence_parallel_region( - hidden_states, group=self.tp_ep_group - ) + + # Check NVLS eligibility + nvls_eligible = self._check_nvls_eligibility(hidden_states) + ag_buffers = None + + if nvls_eligible: + ag_buffers = self._maybe_allocate_ag_buffers(self.routing_map, probs, hidden_states) + + can_use_nvls = nvls_eligible and ag_buffers["handle"] is not None + + if can_use_nvls: + # Capture shapes for reshaping after all-gather + # Output shape: [local_tokens * ep_size, dim] + local_tokens = probs.size(0) + global_tokens = local_tokens * self.ep_size + topk = probs.size(1) + hidden_dim = hidden_states.size(1) + routing_map_dtype = self.routing_map.dtype + probs_dtype = probs.dtype + hidden_dtype = hidden_states.dtype + + # Fused NVLS all-gather: single kernel launch + single barrier for all 3 tensors + multimem_all_gather_fused( + ag_buffers["routing_map"].view(torch.bfloat16), + self.routing_map.view(torch.bfloat16), + ag_buffers["routing_map_offset"], + ag_buffers["probs"].view(torch.bfloat16), + probs.view(torch.bfloat16), + ag_buffers["probs_offset"], + ag_buffers["hidden_states"].view(torch.bfloat16), + hidden_states.view(torch.bfloat16), + ag_buffers["hidden_states_offset"], + ag_buffers["handle"], + ) + self.routing_map = 
ag_buffers["routing_map"].view(routing_map_dtype).view(global_tokens, topk) + probs = ag_buffers["probs"].view(probs_dtype).view(global_tokens, topk) + hidden_states = ag_buffers["hidden_states"].view(hidden_dtype).view(global_tokens, hidden_dim) + else: + # Fallback to NCCL for all tensors + with torch.no_grad(): + self.routing_map = gather_from_sequence_parallel_region( + self.routing_map, group=self.tp_ep_group + ) + probs = gather_from_sequence_parallel_region(probs, group=self.tp_ep_group) + hidden_states = gather_from_sequence_parallel_region( + hidden_states, group=self.tp_ep_group + ) return hidden_states, probs + def dispatch_postprocess(self, hidden_states, probs): """Pass-through: returns unpermuted inputs and routing_map for InferenceGroupedMLP.""" return hidden_states, self.routing_map, probs @@ -76,12 +202,44 @@ def combine_preprocess(self, expert_output): def token_combine(self, hidden_states): """ - Combines expert outputs using NCCL Reduce-Scatter. + Combines expert outputs using Reduce-Scatter. + + Uses latency-optimized NVLS multimem_reduce_scatter on Hopper+ GPUs with BF16 + when symmetric memory is available. Falls back to NCCL via superclass otherwise. 
+ + Args: + hidden_states: [global_tokens, hidden_dim] tensor to reduce-scatter + + Returns: + [local_tokens, hidden_dim] tensor after reduce-scatter """ if self.ep_size == 1: return hidden_states - return reduce_scatter_to_sequence_parallel_region( - hidden_states, group=self.tp_ep_group - ) + # Check NVLS eligibility and try to allocate symmetric memory + nvls_eligible = self._check_nvls_eligibility(hidden_states) + rs_buffer = None + + if nvls_eligible: + rs_buffer = self._maybe_allocate_rs_buffer(hidden_states) + + can_use_nvls = nvls_eligible and rs_buffer["handle"] is not None + + if can_use_nvls: + # Copy input to symmetric memory for reduce-scatter + rs_buffer["tensor"].copy_(hidden_states) + + # Allocate output tensor + output_shape = list(hidden_states.size()) + output_shape[0] = hidden_states.size(0) // self.ep_size + output = torch.empty( + output_shape, dtype=hidden_states.dtype, device=hidden_states.device + ) + + # Use latency-optimized NVLS reduce-scatter + multimem_reduce_scatter(output, rs_buffer["tensor"], rs_buffer["handle"]) + return output + else: + # Fallback to NCCL via superclass + return super().token_combine(hidden_states) From 3deb50ce15eea9edf4ee7d4633384348dd23fa8c Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Wed, 25 Feb 2026 23:58:44 -0800 Subject: [PATCH 61/92] Revert "remove changes related to symm mem comms" This reverts commit b89600b93c662fbd07575c8a2394dc1023b753d1. 
--- .../torch_symm_triton/__init__.py | 2 +- .../torch_symm_triton/collectives.py | 266 +++++++++++------- megatron/core/parallel_state.py | 42 ++- .../core/tensor_parallel/inference_layers.py | 6 +- megatron/core/utils.py | 38 +++ 5 files changed, 234 insertions(+), 120 deletions(-) diff --git a/megatron/core/inference/communication/torch_symm_triton/__init__.py b/megatron/core/inference/communication/torch_symm_triton/__init__.py index ca58663d9ec..586e913541e 100644 --- a/megatron/core/inference/communication/torch_symm_triton/__init__.py +++ b/megatron/core/inference/communication/torch_symm_triton/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -from .collectives import multimem_all_gather, multimem_reduce_scatter +from .collectives import multimem_all_gather, multimem_all_gather_fused, multimem_reduce_scatter from .fused_collectives import fused_multimem_rs_add_norm_ag diff --git a/megatron/core/inference/communication/torch_symm_triton/collectives.py b/megatron/core/inference/communication/torch_symm_triton/collectives.py index 4bc4dbde42b..eb48dae7d0f 100644 --- a/megatron/core/inference/communication/torch_symm_triton/collectives.py +++ b/megatron/core/inference/communication/torch_symm_triton/collectives.py @@ -25,41 +25,41 @@ from .multimem_asm import ld_128, st_128 from .utils import get_flat_tid, sync_threads +# ── Triton kernels ───────────────────────────────────────────────────────── @triton.jit -def _multimem_all_gather_kernel( - local_ptr, - multicast_ptr, - signal_pad_ptrs, - numel, - BLOCK_SIZE: tl.constexpr, - NUMEL_PER_THREAD: tl.constexpr, - RANK: tl.constexpr, - WORLD_SIZE: tl.constexpr, -): - """ - Triton kernel to perform multicast all-gather over nvlink using multimem instructions. 
+def _ag_phase(local_ptr, multicast_ptr, byte_offset, numel, BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE): """ - # an all-gather is simply a multicast store operation - # we only need a barrier at the end to ensure visibility of writes + Core all-gather phase: load from local memory, multicast-store to symmetric buffer. + This is the building block for both single-tensor and fused multi-tensor all-gathers. + + Each thread handles 128-bit (NUMEL_PER_THREAD elements) at a time. + byte_offset locates the tensor within the multicast buffer. + NOTE: When numel is not divisible by (NUMEL_PER_THREAD * WORLD_SIZE), the kernel + rounds up via cdiv and may read/write up to 15 bytes past the logical tensor end. + This is safe because PyTorch's CUDA caching allocator guarantees a minimum block + size of 512 bytes (kMinBlockSize in CUDACachingAllocator.cpp), so small tensors + always have sufficient backing memory. + """ pid = tl.program_id(axis=0) tid = get_flat_tid() - # From this point on, we pretend each element is 128-bit - numel = numel // NUMEL_PER_THREAD - numel_per_rank = tl.cdiv(numel, WORLD_SIZE) + numel_128 = numel // NUMEL_PER_THREAD + numel_per_rank = tl.cdiv(numel_128, WORLD_SIZE) block_start = pid * BLOCK_SIZE while block_start < numel_per_rank: offsets = block_start + tid mask = offsets < numel_per_rank - # Each pointer points to a 128-bit bit pack - # RANK * numel_per_rank -> brings us to the start of our rank's segment - # offsets -> brings us to the right offset within our rank's segment + # byte_offset // 8 -> converts byte offset to uint64 offset + # RANK * numel_per_rank -> start of our rank's segment + # * 2 -> each 128-bit pack is 2 uint64s multicast_ptrs = ( - multicast_ptr.to(tl.pointer_type(tl.uint64)) + (RANK * numel_per_rank + offsets) * 2 + multicast_ptr.to(tl.pointer_type(tl.uint64)) + + byte_offset // 8 + + (RANK * numel_per_rank + offsets) * 2 ) local_ptrs = local_ptr.to(tl.pointer_type(tl.uint64)) + offsets * 2 (x, y, z, w) = 
ld_128(local_ptrs, mask=mask, multicast_op=False) @@ -67,66 +67,54 @@ def _multimem_all_gather_kernel( block_start += tl.num_programs(axis=0) * BLOCK_SIZE + +@triton.jit +def _multimem_all_gather_kernel( + local_ptr, + multicast_ptr, + signal_pad_ptrs, + numel, + byte_offset, + BLOCK_SIZE: tl.constexpr, + NUMEL_PER_THREAD: tl.constexpr, + RANK: tl.constexpr, + WORLD_SIZE: tl.constexpr, +): + """Single-tensor multicast all-gather kernel.""" + _ag_phase(local_ptr, multicast_ptr, byte_offset, numel, + BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) sync_threads() - symm_mem_sync( - signal_pad_ptrs, - None, - RANK, - WORLD_SIZE, - hasPreviousMemAccess=True, - hasSubsequentMemAccess=True, - ) + symm_mem_sync(signal_pad_ptrs, None, RANK, WORLD_SIZE, + hasPreviousMemAccess=True, hasSubsequentMemAccess=True) -def multimem_all_gather( - output_tensor: torch.Tensor, - input_tensor: torch.Tensor, - symm_mem_hdl: _SymmetricMemory, - **kwargs, -) -> torch.Tensor: +@triton.jit +def _multimem_all_gather_3_kernel( + local_ptr_0, local_ptr_1, local_ptr_2, + multicast_ptr, + signal_pad_ptrs, + numel_0, byte_offset_0, + numel_1, byte_offset_1, + numel_2, byte_offset_2, + BLOCK_SIZE: tl.constexpr, + NUMEL_PER_THREAD: tl.constexpr, + RANK: tl.constexpr, + WORLD_SIZE: tl.constexpr, +): """ - Calls a multicast all-gather triton kernel on the given tensor. - Output tensor must be a symmetric memory buffer. - Input tensor can be a regular torch tensor - Arguments: - output_tensor: torch.Tensor - output tensor to be all-gathered into - input_tensor: torch.Tensor - input tensor to be all-gathered from - symm_mem_hdl: _SymmetricMemory - handle to the symmetric memory buffer for output_tensor - Returns: - torch.Tensor - all-gathered tensor, which is output_tensor + Fused 3-tensor multicast all-gather. Processes three tensors in sequence + then synchronizes once, eliminating 2 kernel launches and 2 barriers + compared to three separate multimem_all_gather calls. 
""" - assert HAVE_TRITON, "Triton is required for multimem all-gather." - - config = { - "max_num_blocks": kwargs.get("max_num_blocks", 24), - "num_warps": kwargs.get("num_warps", 32), - "BLOCK_SIZE": kwargs.get("BLOCK_SIZE", 1024), - } - assert input_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." - assert output_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." - numel_per_thread = 128 // (input_tensor.element_size() * 8) - - assert ( - output_tensor.numel() % numel_per_thread == 0 - ), "The number of elements must be 128-bit aligned." - - num_threads = triton.cdiv(output_tensor.numel() // numel_per_thread, symm_mem_hdl.world_size) - num_blocks = min(triton.cdiv(num_threads, config["BLOCK_SIZE"]), config["max_num_blocks"]) - - _multimem_all_gather_kernel[(num_blocks, 1, 1)]( - input_tensor.data_ptr(), - symm_mem_hdl.multicast_ptr, - symm_mem_hdl.signal_pad_ptrs_dev, - numel=output_tensor.numel(), - BLOCK_SIZE=config["BLOCK_SIZE"], - NUMEL_PER_THREAD=numel_per_thread, - RANK=symm_mem_hdl.rank, - WORLD_SIZE=symm_mem_hdl.world_size, - num_warps=config["num_warps"], - ) - - return output_tensor - + _ag_phase(local_ptr_0, multicast_ptr, byte_offset_0, numel_0, + BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) + _ag_phase(local_ptr_1, multicast_ptr, byte_offset_1, numel_1, + BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) + _ag_phase(local_ptr_2, multicast_ptr, byte_offset_2, numel_2, + BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) + sync_threads() + symm_mem_sync(signal_pad_ptrs, None, RANK, WORLD_SIZE, + hasPreviousMemAccess=True, hasSubsequentMemAccess=True) @triton.jit def _multimem_reduce_scatter_kernel( @@ -174,48 +162,118 @@ def _multimem_reduce_scatter_kernel( block_start += tl.num_programs(axis=0) * BLOCK_SIZE +# ── Python wrappers ───────────────────────────────────────────────────────── -def multimem_reduce_scatter( +_DEFAULT_KERNEL_CONFIG = { + "max_num_blocks": 128, + "num_warps": 32, + "BLOCK_SIZE": 1024, +} + + 
+def _kernel_launch_config(element_size: int, max_numel: int, world_size: int, **kwargs): + """Compute kernel launch config shared by all collective wrappers. + + Args: + element_size: bytes per element (e.g. 2 for bf16). + max_numel: largest tensor numel (determines grid size). + world_size: number of ranks. + + Returns: + (numel_per_thread, num_blocks, config) tuple. + """ + config = {k: kwargs.get(k, v) for k, v in _DEFAULT_KERNEL_CONFIG.items()} + numel_per_thread = 128 // (element_size * 8) + num_threads = triton.cdiv(max_numel // numel_per_thread, world_size) + num_blocks = min(triton.cdiv(num_threads, config["BLOCK_SIZE"]), config["max_num_blocks"]) + return numel_per_thread, num_blocks, config + + +def multimem_all_gather( output_tensor: torch.Tensor, input_tensor: torch.Tensor, symm_mem_hdl: _SymmetricMemory, + byte_offset: int = 0, **kwargs, ) -> torch.Tensor: """ - Calls a multicast reduce-scatter triton kernel on the given tensor. - Input tensor must be a symmetric memory buffer. - Output tensor can be a regular torch tensor - Arguments: - output_tensor: torch.Tensor - output tensor to be reduce-scattered into - input_tensor: torch.Tensor - input tensor to be reduce-scattered from - symm_mem_hdl: _SymmetricMemory - handle to the symmetric memory buffer for input_tensor - **kwargs: Additional keyword arguments for kernel configuration: - max_num_blocks (int, optional): The maximum number of blocks to launch. - num_warps (int, optional): The number of warps per block. - BLOCK_SIZE (int, optional): The BLOCK_SIZE parameter for the kernel. - Returns: - torch.Tensor - reduce-scattered tensor, which is output_tensor + Multicast all-gather for a single tensor. + Output tensor must be a symmetric memory buffer. + Input tensor can be a regular torch tensor. """ + assert HAVE_TRITON, "Triton is required for multimem all-gather." - assert HAVE_TRITON, "Triton is required for multimem reduce-scatter." 
+ numel_per_thread, num_blocks, config = _kernel_launch_config( + input_tensor.element_size(), output_tensor.numel(), symm_mem_hdl.world_size, **kwargs, + ) + _multimem_all_gather_kernel[(num_blocks, 1, 1)]( + input_tensor.data_ptr(), + symm_mem_hdl.multicast_ptr, + symm_mem_hdl.signal_pad_ptrs_dev, + numel=output_tensor.numel(), + byte_offset=byte_offset, + BLOCK_SIZE=config["BLOCK_SIZE"], + NUMEL_PER_THREAD=numel_per_thread, + RANK=symm_mem_hdl.rank, + WORLD_SIZE=symm_mem_hdl.world_size, + num_warps=config["num_warps"], + ) - config = { - "max_num_blocks": kwargs.get("max_num_blocks", 24), - "num_warps": kwargs.get("num_warps", 32), - "BLOCK_SIZE": kwargs.get("BLOCK_SIZE", 1024), - } + return output_tensor - assert input_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." - assert output_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." - numel_per_thread = 128 // (output_tensor.element_size() * 8) - assert ( - input_tensor.numel() % numel_per_thread == 0 - ), "The number of elements must be 128-bit aligned." +def multimem_all_gather_fused( + output_0: torch.Tensor, input_0: torch.Tensor, byte_offset_0: int, + output_1: torch.Tensor, input_1: torch.Tensor, byte_offset_1: int, + output_2: torch.Tensor, input_2: torch.Tensor, byte_offset_2: int, + symm_mem_hdl: _SymmetricMemory, + **kwargs, +) -> None: + """ + Fused 3-tensor multicast all-gather. Equivalent to calling multimem_all_gather + three times but with a single kernel launch and a single barrier. + + All tensors must share the same symmetric memory handle. + """ + assert HAVE_TRITON, "Triton is required for multimem all-gather." 
+ + max_numel = max(output_0.numel(), output_1.numel(), output_2.numel()) + numel_per_thread, num_blocks, config = _kernel_launch_config( + input_0.element_size(), max_numel, symm_mem_hdl.world_size, **kwargs, + ) + _multimem_all_gather_3_kernel[(num_blocks, 1, 1)]( + input_0.data_ptr(), input_1.data_ptr(), input_2.data_ptr(), + symm_mem_hdl.multicast_ptr, + symm_mem_hdl.signal_pad_ptrs_dev, + numel_0=output_0.numel(), byte_offset_0=byte_offset_0, + numel_1=output_1.numel(), byte_offset_1=byte_offset_1, + numel_2=output_2.numel(), byte_offset_2=byte_offset_2, + BLOCK_SIZE=config["BLOCK_SIZE"], + NUMEL_PER_THREAD=numel_per_thread, + RANK=symm_mem_hdl.rank, + WORLD_SIZE=symm_mem_hdl.world_size, + num_warps=config["num_warps"], + ) - num_threads = triton.cdiv(input_tensor.numel() // numel_per_thread, symm_mem_hdl.world_size) - num_blocks = min(triton.cdiv(num_threads, config["BLOCK_SIZE"]), config["max_num_blocks"]) +def multimem_reduce_scatter( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + symm_mem_hdl: _SymmetricMemory, + **kwargs, +) -> torch.Tensor: + """ + Multicast reduce-scatter for a single tensor. + Input tensor must be a symmetric memory buffer. + Output tensor can be a regular torch tensor. + """ + assert HAVE_TRITON, "Triton is required for multimem reduce-scatter." + assert input_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." + assert output_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." 
+ + numel_per_thread, num_blocks, config = _kernel_launch_config( + output_tensor.element_size(), input_tensor.numel(), symm_mem_hdl.world_size, **kwargs, + ) _multimem_reduce_scatter_kernel[(num_blocks, 1, 1)]( output_tensor.data_ptr(), symm_mem_hdl.multicast_ptr, diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index 087cbe7e152..b571a357fad 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -140,8 +140,9 @@ # Memory buffers to avoid dynamic memory allocation _GLOBAL_MEMORY_BUFFER = None -# Global symmetric memory buffer for inference -_GLOBAL_SYMMETRIC_MEMORY_BUFFER = None +# Global symmetric memory buffers for inference +_GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = None +_GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = None # List of all process groups # Used for updating the timeout for all process groups @@ -2016,14 +2017,20 @@ def _set_global_memory_buffer(): def _set_global_symmetric_memory_buffer(): """Initialize global buffer.""" - global _GLOBAL_SYMMETRIC_MEMORY_BUFFER - assert _GLOBAL_SYMMETRIC_MEMORY_BUFFER is None, "global memory buffer is already initialized" + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP, _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP + assert _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP is None, "global symmetric memory buffer for TP is already initialized" + assert _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP is None, "global symmetric memory buffer for EP is already initialized" - _GLOBAL_SYMMETRIC_MEMORY_BUFFER = GlobalSymmetricMemoryBuffer( + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = GlobalSymmetricMemoryBuffer( size_in_mb=256, # todo: set from an argument? process_group=get_tensor_model_parallel_group(), ) + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = GlobalSymmetricMemoryBuffer( + size_in_mb=256, # todo: set from an argument? 
+ process_group=get_expert_model_parallel_group(), + ) + def get_global_memory_buffer(): """Return the global GlobalMemoryBuffer object""" @@ -2031,12 +2038,19 @@ def get_global_memory_buffer(): return _GLOBAL_MEMORY_BUFFER -def get_global_symmetric_memory_buffer(): +def get_global_symmetric_memory_buffer_tp(): + """Return the global GlobalSymmetricMemoryBuffer object""" + assert ( + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP is not None + ), "global symmetric memory buffer is not initialized" + return _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP + +def get_global_symmetric_memory_buffer_ep(): """Return the global GlobalSymmetricMemoryBuffer object""" assert ( - _GLOBAL_SYMMETRIC_MEMORY_BUFFER is not None + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP is not None ), "global symmetric memory buffer is not initialized" - return _GLOBAL_SYMMETRIC_MEMORY_BUFFER + return _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP def destroy_global_memory_buffer(): @@ -2047,8 +2061,9 @@ def destroy_global_memory_buffer(): def destroy_global_symmetric_memory_buffer(): """Sets the global symmetric memory buffer to None""" - global _GLOBAL_SYMMETRIC_MEMORY_BUFFER - _GLOBAL_SYMMETRIC_MEMORY_BUFFER = None + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP, _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = None + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = None def get_all_ranks(): @@ -2129,8 +2144,11 @@ def destroy_model_parallel(): global _GLOBAL_MEMORY_BUFFER _GLOBAL_MEMORY_BUFFER = None - global _GLOBAL_SYMMETRIC_MEMORY_BUFFER - _GLOBAL_SYMMETRIC_MEMORY_BUFFER = None + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = None + + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = None global _DATA_PARALLEL_GROUP_GLOO if ( diff --git a/megatron/core/tensor_parallel/inference_layers.py b/megatron/core/tensor_parallel/inference_layers.py index 0addc64a65f..fcf882b6818 100644 --- a/megatron/core/tensor_parallel/inference_layers.py +++ 
b/megatron/core/tensor_parallel/inference_layers.py @@ -16,7 +16,7 @@ from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor from megatron.core.inference.quantization.utils import mm_mxfp8 from megatron.core.model_parallel_config import ModelParallelConfig -from megatron.core.parallel_state import get_global_symmetric_memory_buffer +from megatron.core.parallel_state import get_global_symmetric_memory_buffer_tp from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import get_tensor_model_parallel_group_if_none @@ -120,7 +120,7 @@ def _maybe_allocate_symmetric_buffer(self, x: torch.Tensor): """ symm_mem_buffer_dims = list(x.size()) symm_mem_buffer_dims[0] *= self.tp_size - symm_mem_buffer = get_global_symmetric_memory_buffer().maybe_get_tensor( + symm_mem_buffer = get_global_symmetric_memory_buffer_tp().maybe_get_tensor( symm_mem_buffer_dims, dtype=x.dtype ) return symm_mem_buffer @@ -245,7 +245,7 @@ def _matmul_reduce_scatter(self, x, residual=None): # Remove batch dimension for FlashInfer mxfp8 del symm_mem_buffer_dims[1] symm_mem_buffer_dims[-1] = self.weight.size(0) - symm_mem_buffer = get_global_symmetric_memory_buffer().maybe_get_tensor( + symm_mem_buffer = get_global_symmetric_memory_buffer_tp().maybe_get_tensor( symm_mem_buffer_dims, dtype=x.dtype ) has_enough_symmetric_memory = symm_mem_buffer["handle"] is not None diff --git a/megatron/core/utils.py b/megatron/core/utils.py index c0533bd1fab..ed31b77ba04 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -728,6 +728,44 @@ def _allocate(self, numel, dtype) -> torch.Tensor: required_bytes = numel * torch.tensor([], dtype=dtype).element_size() return self.symm_buffer[0:required_bytes].view(dtype).view(numel) + def maybe_get_tensors(self, tensor_specs, alignment=16): + """ + Pack multiple tensors contiguously in the symmetric buffer with alignment. 
+ + Each tensor's starting offset is aligned to `alignment` bytes (default 16 + for 128-bit multimem access). + + Args: + tensor_specs: list of (numel, dtype) tuples. + alignment: byte alignment for each tensor's start offset (default 16). + + Returns: + {"handle": None, "tensors": None} if unavailable or insufficient space. + {"handle": symm_mem_hdl, "tensors": [(raw_byte_view, byte_offset), ...]} + on success, where raw_byte_view is a uint8 slice of the buffer. + """ + _NONE_RESULT = {"handle": None, "tensors": None} + if self.symm_mem_hdl is None: + return _NONE_RESULT + + # Compute aligned byte sizes and running offsets + slices = [] + current_offset = 0 + for numel, dtype in tensor_specs: + nbytes = numel * torch.tensor([], dtype=dtype).element_size() + aligned_nbytes = ((nbytes + alignment - 1) // alignment) * alignment + slices.append((current_offset, nbytes)) + current_offset += aligned_nbytes + + if not self._can_allocate(current_offset, torch.uint8): + return _NONE_RESULT + + tensors = [] + for offset, nbytes in slices: + tensors.append((self.symm_buffer[offset : offset + nbytes], offset)) + + return {"handle": self.symm_mem_hdl, "tensors": tensors} + def maybe_get_tensor(self, tensor_shape, dtype): """ Returns (potentially) a sub-tensor from the self.symm_buffer for the given shape. 
From 7aff1164eafd2ede3de273ea99ccbf9eee218502 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Thu, 26 Feb 2026 00:23:48 -0800 Subject: [PATCH 62/92] bring back NVLS collectives --- .../communication/torch_symm_triton/collectives.py | 8 ++++++++ .../core/transformer/moe/token_dispatcher_inference.py | 8 ++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/megatron/core/inference/communication/torch_symm_triton/collectives.py b/megatron/core/inference/communication/torch_symm_triton/collectives.py index eb48dae7d0f..c54934f306c 100644 --- a/megatron/core/inference/communication/torch_symm_triton/collectives.py +++ b/megatron/core/inference/communication/torch_symm_triton/collectives.py @@ -202,6 +202,8 @@ def multimem_all_gather( Input tensor can be a regular torch tensor. """ assert HAVE_TRITON, "Triton is required for multimem all-gather." + for x in [input_tensor, output_tensor]: + assert x.element_size() * x.numel() % 16 == 0, "Tensor size must be a multiple of 16 bytes for NVLS." numel_per_thread, num_blocks, config = _kernel_launch_config( input_tensor.element_size(), output_tensor.numel(), symm_mem_hdl.world_size, **kwargs, @@ -237,7 +239,11 @@ def multimem_all_gather_fused( """ assert HAVE_TRITON, "Triton is required for multimem all-gather." + for x in [input_0, input_1, input_2, output_0, output_1, output_2]: + assert x.element_size() * x.numel() % 16 == 0, "Tensor size must be a multiple of 16 bytes for NVLS." + max_numel = max(output_0.numel(), output_1.numel(), output_2.numel()) + numel_per_thread, num_blocks, config = _kernel_launch_config( input_0.element_size(), max_numel, symm_mem_hdl.world_size, **kwargs, ) @@ -270,6 +276,8 @@ def multimem_reduce_scatter( assert HAVE_TRITON, "Triton is required for multimem reduce-scatter." assert input_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." assert output_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." 
+ for x in [input_tensor, output_tensor]: + assert x.element_size() * x.numel() % 16 == 0, "Tensor size must be a multiple of 16 bytes for NVLS." numel_per_thread, num_blocks, config = _kernel_launch_config( output_tensor.element_size(), input_tensor.numel(), symm_mem_hdl.world_size, **kwargs, diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index aace1827be8..c0a130a8496 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -73,9 +73,9 @@ def _check_nvls_eligibility(self, x: torch.Tensor) -> bool: Check if we can use NVLS (latency-optimized) collectives. Requirements: BF16 dtype, Hopper+ GPU (SM >= 9). """ - is_bf16 = x.dtype == torch.bfloat16 is_hopper_or_newer = torch.cuda.get_device_properties(x.device).major >= 9 - return is_bf16 and is_hopper_or_newer + num_bytes = x.element_size() * x.numel() + return is_hopper_or_newer and num_bytes % 16 == 0 def _maybe_allocate_ag_buffers( self, @@ -141,9 +141,9 @@ def token_dispatch(self, hidden_states, probs): """ if self.ep_size == 1: return hidden_states, probs - + # Check NVLS eligibility - nvls_eligible = self._check_nvls_eligibility(hidden_states) + nvls_eligible = self._check_nvls_eligibility(hidden_states) and self._check_nvls_eligibility(probs) and self._check_nvls_eligibility(self.routing_map) ag_buffers = None if nvls_eligible: From 2f8cf3e7afd80dcf1eb9273a86e20fa28bba77c7 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Thu, 26 Feb 2026 00:50:12 -0800 Subject: [PATCH 63/92] add kill-switch for nvls --- .../transformer/moe/token_dispatcher_inference.py | 15 +++++++++------ megatron/core/transformer/transformer_config.py | 3 +++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 
c0a130a8496..7c2f15a0ab7 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -10,6 +10,7 @@ on Hopper+ GPUs with BF16, with automatic fallback to NCCL via superclass methods. """ +from megatron.core.tensor_parallel.mappings import reduce_scatter_to_sequence_parallel_region import torch from typing import List, Optional @@ -19,7 +20,7 @@ ) from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.tensor_parallel import gather_from_sequence_parallel_region +from megatron.core.tensor_parallel import gather_from_sequence_parallel_region, reduce_scatter_to_sequence_parallel_region from megatron.core.parallel_state import get_global_symmetric_memory_buffer_ep from megatron.core.inference.communication.torch_symm_triton import ( multimem_all_gather_fused, @@ -65,8 +66,7 @@ def __init__( ) self.topk = config.moe_router_topk - # Cache for NVLS eligibility - self._nvls_eligible = None + self.triton_nvls_kernels_allowed = not self.config.inference_disable_triton_nvls_kernels def _check_nvls_eligibility(self, x: torch.Tensor) -> bool: """ @@ -75,7 +75,7 @@ def _check_nvls_eligibility(self, x: torch.Tensor) -> bool: """ is_hopper_or_newer = torch.cuda.get_device_properties(x.device).major >= 9 num_bytes = x.element_size() * x.numel() - return is_hopper_or_newer and num_bytes % 16 == 0 + return self.triton_nvls_kernels_allowed and is_hopper_or_newer and num_bytes % 16 == 0 def _maybe_allocate_ag_buffers( self, @@ -240,6 +240,9 @@ def token_combine(self, hidden_states): multimem_reduce_scatter(output, rs_buffer["tensor"], rs_buffer["handle"]) return output else: - # Fallback to NCCL via superclass - return super().token_combine(hidden_states) + # Fallback to NCCL + hidden_states = reduce_scatter_to_sequence_parallel_region( + hidden_states, group=self.tp_ep_group + ) + return hidden_states diff --git a/megatron/core/transformer/transformer_config.py 
b/megatron/core/transformer/transformer_config.py index f1bd5758202..c9ee868f53b 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -858,6 +858,9 @@ class TransformerConfig(ModelParallelConfig): inference_fuse_tp_communication: bool = False """ If true, uses a fused reduce-scatter-residual-norm-allgather kernel during inference. """ + inference_disable_triton_nvls_kernels: bool = False + """ If true, disables the use of Triton NVLS kernels during inference. """ + mrope_section: Optional[List[int]] = None """ Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. """ From 64bc241c7afbb6e390473abe3768acba7a549f37 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Thu, 26 Feb 2026 13:08:17 -0800 Subject: [PATCH 64/92] cleanup nvls --- .../torch_symm_triton/__init__.py | 1 + .../torch_symm_triton/collectives.py | 23 +++++++---- .../communication/torch_symm_triton/utils.py | 23 +++++++++++ .../moe/token_dispatcher_inference.py | 40 ++++++++----------- 4 files changed, 55 insertions(+), 32 deletions(-) diff --git a/megatron/core/inference/communication/torch_symm_triton/__init__.py b/megatron/core/inference/communication/torch_symm_triton/__init__.py index 586e913541e..9654e8bc67c 100644 --- a/megatron/core/inference/communication/torch_symm_triton/__init__.py +++ b/megatron/core/inference/communication/torch_symm_triton/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
from .collectives import multimem_all_gather, multimem_all_gather_fused, multimem_reduce_scatter +from .utils import are_tensors_nvls_eligible, is_device_nvls_capable from .fused_collectives import fused_multimem_rs_add_norm_ag diff --git a/megatron/core/inference/communication/torch_symm_triton/collectives.py b/megatron/core/inference/communication/torch_symm_triton/collectives.py index c54934f306c..3ba8598467f 100644 --- a/megatron/core/inference/communication/torch_symm_triton/collectives.py +++ b/megatron/core/inference/communication/torch_symm_triton/collectives.py @@ -23,7 +23,7 @@ from .barrier import symm_mem_sync from .multimem_asm import ld_128, st_128 -from .utils import get_flat_tid, sync_threads +from .utils import get_flat_tid, are_tensors_nvls_eligible, sync_threads # ── Triton kernels ───────────────────────────────────────────────────────── @@ -202,8 +202,10 @@ def multimem_all_gather( Input tensor can be a regular torch tensor. """ assert HAVE_TRITON, "Triton is required for multimem all-gather." - for x in [input_tensor, output_tensor]: - assert x.element_size() * x.numel() % 16 == 0, "Tensor size must be a multiple of 16 bytes for NVLS." + assert are_tensors_nvls_eligible(input_tensor), "Input tensor must be 16-byte divisible on Hopper+ for NVLS." + assert output_tensor.numel() % input_tensor.numel() == 0 and \ + output_tensor.numel() // input_tensor.numel() == symm_mem_hdl.world_size, \ + "Output numel must be exactly world_size * input numel for all-gather." numel_per_thread, num_blocks, config = _kernel_launch_config( input_tensor.element_size(), output_tensor.numel(), symm_mem_hdl.world_size, **kwargs, @@ -238,9 +240,12 @@ def multimem_all_gather_fused( All tensors must share the same symmetric memory handle. """ assert HAVE_TRITON, "Triton is required for multimem all-gather." 
- - for x in [input_0, input_1, input_2, output_0, output_1, output_2]: - assert x.element_size() * x.numel() % 16 == 0, "Tensor size must be a multiple of 16 bytes for NVLS." + assert are_tensors_nvls_eligible(input_0, input_1, input_2), \ + "All input tensors must be 16-byte divisible on Hopper+ for NVLS." + for inp, out in [(input_0, output_0), (input_1, output_1), (input_2, output_2)]: + assert out.numel() % inp.numel() == 0 and \ + out.numel() // inp.numel() == symm_mem_hdl.world_size, \ + "Output numel must be exactly world_size * input numel for all-gather." max_numel = max(output_0.numel(), output_1.numel(), output_2.numel()) @@ -276,8 +281,10 @@ def multimem_reduce_scatter( assert HAVE_TRITON, "Triton is required for multimem reduce-scatter." assert input_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." assert output_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." - for x in [input_tensor, output_tensor]: - assert x.element_size() * x.numel() % 16 == 0, "Tensor size must be a multiple of 16 bytes for NVLS." + assert are_tensors_nvls_eligible(output_tensor), "Output tensor must be 16-byte divisible on Hopper+ for NVLS." + assert input_tensor.numel() % output_tensor.numel() == 0 and \ + input_tensor.numel() // output_tensor.numel() == symm_mem_hdl.world_size, \ + "Input numel must be exactly world_size * output numel for reduce-scatter." 
numel_per_thread, num_blocks, config = _kernel_launch_config( output_tensor.element_size(), input_tensor.numel(), symm_mem_hdl.world_size, **kwargs, diff --git a/megatron/core/inference/communication/torch_symm_triton/utils.py b/megatron/core/inference/communication/torch_symm_triton/utils.py index 785481dfba6..6e1eee74da4 100644 --- a/megatron/core/inference/communication/torch_symm_triton/utils.py +++ b/megatron/core/inference/communication/torch_symm_triton/utils.py @@ -2,6 +2,7 @@ # Adapted from: https://github.com/meta-pytorch/kraken.git +import torch from unittest.mock import MagicMock from megatron.core.utils import null_decorator @@ -15,6 +16,28 @@ triton.jit = null_decorator + +def is_device_nvls_capable(device: torch.device) -> bool: + """Check if the device supports NVLS (multicast) collectives. Requires CUDA Hopper+ (SM >= 9).""" + return device.type == "cuda" and torch.cuda.get_device_properties(device).major >= 9 + + +def are_tensors_nvls_eligible(*tensors: torch.Tensor) -> bool: + """Check if tensors are eligible for NVLS (multicast) collectives. + + Requirements: + - Hopper+ GPU (SM >= 9) + - All tensor byte sizes are divisible by 16 (128-bit), since NVLS + kernels process data in 128-bit chunks. + """ + if not tensors: + return False + return is_device_nvls_capable(tensors[0].device) and all( + t.element_size() * t.numel() % 16 == 0 for t in tensors + ) + + + @triton.jit def get_tid(): """ diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 7c2f15a0ab7..479b265c6eb 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -10,7 +10,6 @@ on Hopper+ GPUs with BF16, with automatic fallback to NCCL via superclass methods. 
""" -from megatron.core.tensor_parallel.mappings import reduce_scatter_to_sequence_parallel_region import torch from typing import List, Optional @@ -23,12 +22,11 @@ from megatron.core.tensor_parallel import gather_from_sequence_parallel_region, reduce_scatter_to_sequence_parallel_region from megatron.core.parallel_state import get_global_symmetric_memory_buffer_ep from megatron.core.inference.communication.torch_symm_triton import ( + are_tensors_nvls_eligible, multimem_all_gather_fused, multimem_reduce_scatter, ) -import logging - class InferenceAllGatherTokenDispatcher(MoEAllGatherTokenDispatcher): """ Inference-optimized AllGather token dispatcher. @@ -68,15 +66,6 @@ def __init__( self.triton_nvls_kernels_allowed = not self.config.inference_disable_triton_nvls_kernels - def _check_nvls_eligibility(self, x: torch.Tensor) -> bool: - """ - Check if we can use NVLS (latency-optimized) collectives. - Requirements: BF16 dtype, Hopper+ GPU (SM >= 9). - """ - is_hopper_or_newer = torch.cuda.get_device_properties(x.device).major >= 9 - num_bytes = x.element_size() * x.numel() - return self.triton_nvls_kernels_allowed and is_hopper_or_newer and num_bytes % 16 == 0 - def _maybe_allocate_ag_buffers( self, routing_map: torch.Tensor, @@ -142,13 +131,15 @@ def token_dispatch(self, hidden_states, probs): if self.ep_size == 1: return hidden_states, probs - # Check NVLS eligibility - nvls_eligible = self._check_nvls_eligibility(hidden_states) and self._check_nvls_eligibility(probs) and self._check_nvls_eligibility(self.routing_map) + # 1. Check inputs only: if inputs are 16-byte divisible, outputs (world_size * input) are too. + nvls_eligible = self.triton_nvls_kernels_allowed and are_tensors_nvls_eligible(hidden_states, probs, self.routing_map) ag_buffers = None if nvls_eligible: + # 2. Now attempt to allocate symmetric memory buffers for all-gather outputs. If allocation fails, fallback to NCCL. 
ag_buffers = self._maybe_allocate_ag_buffers(self.routing_map, probs, hidden_states) + # 3. Can use NVLS if eligible and buffers allocated successfully (handle is not None) can_use_nvls = nvls_eligible and ag_buffers["handle"] is not None if can_use_nvls: @@ -164,7 +155,7 @@ def token_dispatch(self, hidden_states, probs): # Fused NVLS all-gather: single kernel launch + single barrier for all 3 tensors multimem_all_gather_fused( - ag_buffers["routing_map"].view(torch.bfloat16), + ag_buffers["routing_map"].view(torch.bfloat16), # .view does not change the underlying data self.routing_map.view(torch.bfloat16), ag_buffers["routing_map_offset"], ag_buffers["probs"].view(torch.bfloat16), @@ -216,8 +207,16 @@ def token_combine(self, hidden_states): if self.ep_size == 1: return hidden_states - # Check NVLS eligibility and try to allocate symmetric memory - nvls_eligible = self._check_nvls_eligibility(hidden_states) + # Compute output shape first — check NVLS eligibility on the output, + # since if the smaller output is 16-byte divisible, the input is too. + output_shape = list(hidden_states.size()) + output_shape[0] = hidden_states.size(0) // self.ep_size + output = torch.empty( + output_shape, dtype=hidden_states.dtype, device=hidden_states.device + ) + + # Check output only: if output is 16-byte divisible, input (world_size * output) is too. 
+ nvls_eligible = self.triton_nvls_kernels_allowed and are_tensors_nvls_eligible(output) rs_buffer = None if nvls_eligible: @@ -229,13 +228,6 @@ def token_combine(self, hidden_states): # Copy input to symmetric memory for reduce-scatter rs_buffer["tensor"].copy_(hidden_states) - # Allocate output tensor - output_shape = list(hidden_states.size()) - output_shape[0] = hidden_states.size(0) // self.ep_size - output = torch.empty( - output_shape, dtype=hidden_states.dtype, device=hidden_states.device - ) - # Use latency-optimized NVLS reduce-scatter multimem_reduce_scatter(output, rs_buffer["tensor"], rs_buffer["handle"]) return output From 81054b9a7a534251b60e9ff3972ca67a77ef105a Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Thu, 26 Feb 2026 13:56:56 -0800 Subject: [PATCH 65/92] minor changes --- .../core/tensor_parallel/inference_layers.py | 43 +++++++++++-------- .../core/transformer/transformer_config.py | 6 +++ 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/megatron/core/tensor_parallel/inference_layers.py b/megatron/core/tensor_parallel/inference_layers.py index fcf882b6818..cb367e43bd6 100644 --- a/megatron/core/tensor_parallel/inference_layers.py +++ b/megatron/core/tensor_parallel/inference_layers.py @@ -9,6 +9,7 @@ TERowParallelLinear, ) from megatron.core.inference.communication.torch_symm_triton import ( + are_tensors_nvls_eligible, fused_multimem_rs_add_norm_ag, multimem_all_gather, multimem_reduce_scatter, @@ -108,6 +109,8 @@ def __init__( config.sequence_parallel ), "--transformer-impl=inference_optimized requires --sequence-parallel" + self.triton_nvls_kernels_allowed = not config.inference_disable_triton_nvls_kernels + # Boolean to be toggled externally for skipping norm and all-gather. # This is used when enabling fused reduce-scatter + add + rms-norm + all-gather # in tensor parallelism. 
In this case, the preceeding RowParallelLinear layer @@ -133,16 +136,14 @@ def _all_gather(self, x: torch.Tensor, symm_mem_buffer: dict) -> None: if self.tp_size == 1: return x - # 1. check if bf16 - is_bf16 = x.dtype == torch.bfloat16 - # 2. check if hopper or newer - is_hopper_or_newer = torch.cuda.get_device_properties(x.device).major >= 9 - # 3. check if symmetric memory buffer is available - has_enough_symmetric_memory = symm_mem_buffer["handle"] is not None - can_use_custom_nvls_collectives = ( - is_bf16 and is_hopper_or_newer and has_enough_symmetric_memory + # Check input only: if input is 16-byte divisible, the output + # (world_size * input) is too. + can_use_nvls = ( + self.triton_nvls_kernels_allowed + and are_tensors_nvls_eligible(x) + and symm_mem_buffer["handle"] is not None ) - if can_use_custom_nvls_collectives: + if can_use_nvls: # do multimem all gather multimem_all_gather(symm_mem_buffer["tensor"], x, symm_mem_buffer["handle"]) return symm_mem_buffer["tensor"] @@ -221,6 +222,10 @@ def __init__( config.sequence_parallel ), "--transformer-impl=inference_optimized requires --sequence-parallel" + self.triton_nvls_kernels_allowed = not getattr( + config, 'inference_disable_triton_nvls_kernels', False + ) + # Placeholder for next layer norm weights for fused # reduce-scatter + add + rms-norm + all-gather self.next_layer_norm_weights = None @@ -233,13 +238,7 @@ def _matmul_reduce_scatter(self, x, residual=None): and perform an NVLS multicast reduce-scatter. If that is not possible, it will revert to torch.dist (NCCL) reduce-scatter. """ - # 1. check if bf16 - is_bf16 = x.dtype == torch.bfloat16 - # 2. check if mxfp8 use_mxfp8 = self.config.fp8_recipe == "mxfp8" - # 3. check if hopper or newer - is_hopper_or_newer = torch.cuda.get_device_properties(x.device).major >= 9 - # 4. 
attempt to ask for symmetric memory symm_mem_buffer_dims = list(x.size()) if use_mxfp8: # Remove batch dimension for FlashInfer mxfp8 @@ -248,12 +247,18 @@ def _matmul_reduce_scatter(self, x, residual=None): symm_mem_buffer = get_global_symmetric_memory_buffer_tp().maybe_get_tensor( symm_mem_buffer_dims, dtype=x.dtype ) - has_enough_symmetric_memory = symm_mem_buffer["handle"] is not None - can_use_custom_nvls_collectives = ( - is_bf16 and is_hopper_or_newer and has_enough_symmetric_memory + + # RS requires bf16 (hardware multimem reduce is bf16-only). + # Check the matmul output shape: if it is NVLS-eligible, the RS output + # (world_size times smaller on dim 0) is too. + can_use_nvls = ( + self.triton_nvls_kernels_allowed + and x.dtype == torch.bfloat16 + and are_tensors_nvls_eligible(x) + and symm_mem_buffer["handle"] is not None ) - if can_use_custom_nvls_collectives: + if can_use_nvls: # Write output of matmul directly onto the symmetric memory buffer x = _apply_linear(x, self.weight, self.config, out=symm_mem_buffer["tensor"]) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index c9ee868f53b..b29abefd4ec 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -2052,6 +2052,12 @@ def __post_init__(self): "for inference_optimized transformer implementation." ) + if self.inference_disable_triton_nvls_kernels: + assert self.transformer_impl == "inference_optimized", ( + "inference_disable_triton_nvls_kernels is only supported " + "for inference_optimized transformer implementation." 
+ ) + if self.batch_invariant_mode: assert ( self.attention_backend == AttnBackend.flash From 6f5bf1294c792b16cf12c8245e8634e21b8e0f36 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Thu, 26 Feb 2026 14:06:10 -0800 Subject: [PATCH 66/92] only do set is inference cg iteration from the engine --- .../core/inference/engines/dynamic_engine.py | 16 +++++++++++++++- .../text_generation_controller.py | 5 +---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index e96320499df..c728e07bae6 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -41,7 +41,7 @@ from megatron.core.inference.text_generation_controllers.text_generation_controller import ( TextGenerationController, ) -from megatron.core.inference.utils import Counter, await_process_call +from megatron.core.inference.utils import Counter, await_process_call, set_is_inference_cuda_graphed_iteration_for_ep_inference from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.cuda_graphs import delete_cuda_graphs from megatron.core.transformer.enums import CudaGraphScope @@ -291,6 +291,16 @@ def create_cuda_graphs(self, reset_context: bool = True): for graph in context.cuda_graph_batch_dimensions_list: logging.info(graph) + # Enable inference dispatcher for EP during graph capture + model_config = controller.inference_wrapped_model.model.config + is_inference_optimized_ep = ( + model_config.transformer_impl == "inference_optimized" + and model_config.expert_model_parallel_size > 1 + ) + if is_inference_optimized_ep: + unwrapped_model = controller.inference_wrapped_model.model + set_is_inference_cuda_graphed_iteration_for_ep_inference(unwrapped_model, True) + tbar = enumerate(context.cuda_graph_batch_dimensions_list) if HAVE_TQDM: tbar = tqdm(tbar, 
total=len(context.cuda_graph_batch_dimensions_list)) @@ -318,6 +328,10 @@ def create_cuda_graphs(self, reset_context: bool = True): context.reset() + # Disable inference dispatcher after graph capture + if is_inference_optimized_ep: + set_is_inference_cuda_graphed_iteration_for_ep_inference(unwrapped_model, False) + # Memory usage. time_end = time.time() mem_stats_end = torch.cuda.memory_stats() diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 6e36dac1f2a..cc75a7af856 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -25,7 +25,7 @@ AbstractModelInferenceWrapper, ) from megatron.core.inference.sampling_params import SamplingParams -from megatron.core.inference.utils import get_attention_mask, set_decode_expert_padding, set_is_inference_cuda_graphed_iteration_for_ep_inference +from megatron.core.inference.utils import get_attention_mask, set_decode_expert_padding from megatron.core.models.multimodal.llava_model import LLaVAModel from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer.enums import CudaGraphScope @@ -543,9 +543,6 @@ def _dynamic_step_context_init( else: set_decode_expert_padding(unwrapped_model, False) - if is_inference_optimized and model_config.expert_model_parallel_size > 1: - set_is_inference_cuda_graphed_iteration_for_ep_inference(unwrapped_model, context.using_cuda_graph_this_step()) - # initialize symmetric memory if needed if model_config.transformer_impl == "inference_optimized": context.maybe_initialize_symmetric_memory() From e533f431b3b1a59639b94707f0037d9d228377ec Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Thu, 26 Feb 2026 14:13:39 -0800 Subject: [PATCH 67/92] more cleanup --- 
megatron/core/transformer/moe/moe_layer.py | 40 ++++++++-------------- 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 42f06ef9850..2df7cb9d618 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -300,24 +300,25 @@ def _setup_inference_mode(self, pg_collection): ) def set_is_inference_cuda_graphed_iteration(self, set_to: bool): - """Toggle CUDA-graphed iteration mode on this layer, its router, and its experts.""" + """Toggle CUDA-graphed iteration mode on this layer, its router, and its experts. + + When enabled, swaps in the inference-optimized token dispatcher and disables + shared expert overlap. When disabled, restores the standard dispatcher. + """ self.is_inference_cuda_graphed_iteration = set_to if hasattr(self.router, "set_is_inference_cuda_graphed_iteration"): self.router.set_is_inference_cuda_graphed_iteration(set_to) if hasattr(self.experts, "set_is_inference_cuda_graphed_iteration"): self.experts.set_is_inference_cuda_graphed_iteration(set_to) - def _activate_inference_token_dispatcher(self): - """Swap in the inference-optimized token dispatcher.""" - self._saved_token_dispatcher = self.token_dispatcher - self.token_dispatcher = self._inference_token_dispatcher - self._saved_shared_expert_overlap = self.shared_expert_overlap - self.shared_expert_overlap = False - - def _deactivate_inference_token_dispatcher(self): - """Restore the standard token dispatcher.""" - self.token_dispatcher = self._saved_token_dispatcher - self.shared_expert_overlap = self._saved_shared_expert_overlap + if set_to and self._inference_token_dispatcher is not None: + self._saved_token_dispatcher = self.token_dispatcher + self.token_dispatcher = self._inference_token_dispatcher + self._saved_shared_expert_overlap = self.shared_expert_overlap + self.shared_expert_overlap = False + elif not set_to and hasattr(self, 
"_saved_token_dispatcher"): + self.token_dispatcher = self._saved_token_dispatcher + self.shared_expert_overlap = self._saved_shared_expert_overlap @maybe_skip_or_early_return_by_cudagraph("route") @@ -463,16 +464,6 @@ def forward( if padding_mask is not None: padding_mask = padding_mask.transpose(0, 1).bool() - # Swap in inference-optimized dispatcher for CUDA-graphed inference iterations - _use_inference_dispatcher = ( - self.config.transformer_impl == "inference_optimized" - and not self.training - and self.is_inference_cuda_graphed_iteration - and self._inference_token_dispatcher is not None - ) - if _use_inference_dispatcher: - self._activate_inference_token_dispatcher() - # MoE forward: route -> dispatch -> compute -> combine def custom_forward(hidden_states, intermediate_tensors=None, padding_mask=None): try: @@ -517,7 +508,7 @@ def custom_forward(hidden_states, intermediate_tensors=None, padding_mask=None): return output, mlp_bias - if self.moe_layer_recompute and not _use_inference_dispatcher: + if self.moe_layer_recompute and self.training: if self.config.fp8 or self.config.fp4: outputs = te_checkpoint( custom_forward, @@ -535,9 +526,6 @@ def custom_forward(hidden_states, intermediate_tensors=None, padding_mask=None): else: outputs = custom_forward(hidden_states, intermediate_tensors, padding_mask) - if _use_inference_dispatcher: - self._deactivate_inference_token_dispatcher() - return outputs def backward_dw(self, routed_experts: bool = True, shared_experts: bool = False): From 7f4bc3291a023891821ff498484355e8ced4d1dc Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Thu, 26 Feb 2026 16:04:07 -0800 Subject: [PATCH 68/92] resolve flashinfer activations, small bugfix, and no use flashinfer for gated activations --- .../core/inference/batch_dimensions_utils.py | 4 +-- .../inference/contexts/dynamic_context.py | 7 +++- megatron/core/transformer/moe/experts.py | 32 ++++++++++++++++--- .../core/transformer/transformer_config.py | 16 ++++++++++ 
.../inference/test_batch_dimension_utils.py | 2 +- 5 files changed, 53 insertions(+), 8 deletions(-) diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py index 1303f61c9d2..05a3ba6d397 100644 --- a/megatron/core/inference/batch_dimensions_utils.py +++ b/megatron/core/inference/batch_dimensions_utils.py @@ -474,7 +474,7 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int def match_graph_config( real_batch_dim: InferenceBatchDimensions, cuda_graph_batch_dimensions_list: List[InferenceBatchDimensions], - cuda_graph_mixed_prefill_count: int, + smallest_non_decode_cuda_graph_size: int, strict: bool = False, decode_only_cuda_graphs: bool = False, explicit_chunked_prefill: bool = False, @@ -509,7 +509,7 @@ def match_graph_config( decode_only_cuda_graphs=decode_only_cuda_graphs, explicit_chunked_prefill=explicit_chunked_prefill, ep_group=ep_group, - cuda_graph_mixed_prefill_count=cuda_graph_mixed_prefill_count, + cuda_graph_mixed_prefill_count=smallest_non_decode_cuda_graph_size, ) if adjusted_batch_dim is None: diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index e9e033c9eee..68afb52635f 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -1340,14 +1340,19 @@ def initialize_attention_state( best_graph = CUDAGraphBatchDimensionBuilder.match_graph_config( batch_dimensions, self.cuda_graph_batch_dimensions_list, + smallest_non_decode_cuda_graph_size=min( + self.cuda_graph_mixed_prefill_count, self.max_requests + ), strict=self.is_hybrid_model, decode_only_cuda_graphs=(not self.use_cuda_graphs_for_non_decode_steps), explicit_chunked_prefill=self.is_chunked_prefill_enabled() and self.is_hybrid_model, ep_group=self.expert_model_parallel_group, - cuda_graph_mixed_prefill_count=self.cuda_graph_mixed_prefill_count, ) 
self._using_cuda_graph_this_step = best_graph is not None + if construct_graph_dimensions is not None: + assert self._using_cuda_graph_this_step + if is_expert_parallel_dummy_cuda_graph_step and not self.using_cuda_graph_this_step(): # If we are here, this means that CUDAGraphBatchDimensionBuilder.match_graph_config # could not find a compatible cuda graph for the dummy forward step. diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 461bd8e01ea..cba1bb51f4d 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -952,6 +952,29 @@ def __init__( self.is_inference_cuda_graphed_iteration = False + # FlashInfer's cutlass_fused_moe expects gated weights in [gate, activation] + # order, but TE stores them as [activation, gate]. Until FlashInfer supports + # TE's weight ordering, the FlashInfer path is only available for non-gated + # activations (e.g. squared_relu). + self._flashinfer_available = HAVE_FLASHINFER and not config.gated_linear_unit + if self._flashinfer_available: + self._flashinfer_activation_type = self._resolve_flashinfer_activation_type() + + def _resolve_flashinfer_activation_type(self): + """Map megatron activation config to FlashInfer ActivationType.""" + func = self.config.activation_func + if func == F.silu: + return ActivationType.Silu + elif func == F.gelu: + return ActivationType.Gelu + elif func == F.relu: + return ActivationType.Relu + elif func == squared_relu: + return ActivationType.Relu2 + raise ValueError( + f"No FlashInfer ActivationType mapping for activation_func={func}" + ) + def set_is_inference_cuda_graphed_iteration(self, set_to: bool): """Toggle CUDA-graphed iteration mode.""" self.is_inference_cuda_graphed_iteration = set_to @@ -1008,15 +1031,16 @@ def _build_concatenated_weights(self): def _flashinfer_forward(self, hidden_states, routing_map, probs): """FlashInfer fused MoE kernel for CUDA-graphed inference iterations.""" assert 
HAVE_FLASHINFER, "flashinfer-python is required for FlashInfer forward path." + assert probs.dtype == torch.float32, "FlashInfer forward path requires fp32 probabilities." output = fused_moe.cutlass_fused_moe( hidden_states, - routing_map.to(torch.int), - probs.float(), + routing_map.int(), + probs, self._fc1_weight, self._fc2_weight, hidden_states.dtype, quant_scales=None, - activation_type=ActivationType.Relu2, + activation_type=self._flashinfer_activation_type, ep_size=self.ep_group.size(), ep_rank=self.ep_group.rank(), )[0] @@ -1133,7 +1157,7 @@ def forward( if self.training: return super().forward(permuted_local_hidden_states, tokens_per_expert, permuted_probs) - elif self.is_inference_cuda_graphed_iteration: + elif self.is_inference_cuda_graphed_iteration and self._flashinfer_available: return self._flashinfer_forward( permuted_local_hidden_states, tokens_per_expert, permuted_probs ) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index b29abefd4ec..c850985a03a 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1103,6 +1103,22 @@ def __post_init__(self): raise ValueError( "Inference-optimized MoE layers do not support padded routing map for quantization." ) + if self.num_moe_experts is not None and self.moe_router_dtype != "fp32": + raise ValueError( + "Inference-optimized MoE requires --moe-router-dtype=fp32 " + "to avoid costly dtype conversions during decode." + ) + if ( + self.num_moe_experts is not None + and self.gated_linear_unit + and self.cuda_graph_impl != "none" + ): + raise ValueError( + "Inference-optimized MoE does not yet support CUDA graphs with gated " + "linear units (SwiGLU/GeGLU) due to differences in weight layouts " + "between the FlashInfer kernel and mcore. Either disable CUDA graphs " + "(--cuda-graph-impl=none) or use a non-gated activation (e.g. squared_relu)." 
+ ) if self.num_moe_experts is not None and self.num_moe_experts <= 0: raise ValueError("num_moe_experts must be non-negative.") diff --git a/tests/unit_tests/inference/test_batch_dimension_utils.py b/tests/unit_tests/inference/test_batch_dimension_utils.py index d67c390068a..5bcb29a0f24 100644 --- a/tests/unit_tests/inference/test_batch_dimension_utils.py +++ b/tests/unit_tests/inference/test_batch_dimension_utils.py @@ -50,7 +50,7 @@ def _match( decode_only_cuda_graphs=decode_only, explicit_chunked_prefill=explicit_chunked_prefill, ep_group=ep_group, - cuda_graph_mixed_prefill_count=MIXED_PREFILL_COUNT, + smallest_non_decode_cuda_graph_size=MIXED_PREFILL_COUNT, ) From 76d7c834b3a3c7db028fd8d7a32ddef60c19170f Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Thu, 26 Feb 2026 16:22:21 -0800 Subject: [PATCH 69/92] kill switch for torch grouped gemm --- megatron/core/transformer/moe/experts.py | 13 ++++++++++++- megatron/core/transformer/transformer_config.py | 9 +++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index cba1bb51f4d..62d64de1e84 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -43,6 +43,7 @@ apply_swiglu_sharded_factory, ) from megatron.core.transformer.module import MegatronModule +from megatron.core.utils import is_torch_min_version from megatron.core.transformer.moe import grouped_gemm_util as gg from megatron.core.transformer.moe.moe_utils import ( ProcessGroupCollection, @@ -952,6 +953,13 @@ def __init__( self.is_inference_cuda_graphed_iteration = False + # torch._grouped_mm requires PyTorch >= 2.10 + self._torch_grouped_mm_available = ( + is_torch_min_version("2.10") + and hasattr(torch, '_grouped_mm') + and not config.inference_disable_torch_grouped_mm + ) + # FlashInfer's cutlass_fused_moe expects gated weights in [gate, activation] # order, but TE stores them as [activation, gate]. 
Until FlashInfer supports # TE's weight ordering, the FlashInfer path is only available for non-gated @@ -1162,8 +1170,11 @@ def forward( permuted_local_hidden_states, tokens_per_expert, permuted_probs ) - else: + elif self._torch_grouped_mm_available: return self._torch_grouped_mm_forward(permuted_local_hidden_states, tokens_per_expert, permuted_probs) + + else: + return super().forward(permuted_local_hidden_states, tokens_per_expert, permuted_probs) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index c850985a03a..bfc0bb87767 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -861,6 +861,9 @@ class TransformerConfig(ModelParallelConfig): inference_disable_triton_nvls_kernels: bool = False """ If true, disables the use of Triton NVLS kernels during inference. """ + inference_disable_torch_grouped_mm: bool = False + """ If true, disables torch._grouped_mm in InferenceGroupedMLP, falling back to TE GroupedGEMM. """ + mrope_section: Optional[List[int]] = None """ Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. """ @@ -2074,6 +2077,12 @@ def __post_init__(self): "for inference_optimized transformer implementation." ) + if self.inference_disable_torch_grouped_mm: + assert self.transformer_impl == "inference_optimized", ( + "inference_disable_torch_grouped_mm is only supported " + "for inference_optimized transformer implementation." 
+ ) + if self.batch_invariant_mode: assert ( self.attention_backend == AttnBackend.flash From 5e77a87de631a5698d3fb0120dcdd98cb43f61b5 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Thu, 26 Feb 2026 16:40:57 -0800 Subject: [PATCH 70/92] change name of dispatcher --- megatron/core/transformer/moe/moe_layer.py | 6 ++--- .../moe/token_dispatcher_inference.py | 22 ++++++++++--------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 2df7cb9d618..598e2387a13 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -30,7 +30,7 @@ from megatron.core.typed_torch import apply_module from megatron.core.utils import internal_api from megatron.core.transformer.moe.token_dispatcher_inference import ( - InferenceAllGatherTokenDispatcher, + InferenceCUDAGraphTokenDispatcher, ) try: import flashinfer @@ -283,7 +283,7 @@ def _setup_inference_mode(self, pg_collection): """Set up inference-optimized token dispatcher and state. Called from __init__ when config.transformer_impl == "inference_optimized". - Creates an InferenceAllGatherTokenDispatcher alongside the standard dispatcher, + Creates an InferenceCUDAGraphTokenDispatcher alongside the standard dispatcher, which is swapped in during CUDA-graphed forward passes. 
""" @@ -292,7 +292,7 @@ def _setup_inference_mode(self, pg_collection): f"got '{self.config.moe_token_dispatcher_type}'" ) self.is_inference_cuda_graphed_iteration = False - self._inference_token_dispatcher = InferenceAllGatherTokenDispatcher( + self._inference_token_dispatcher = InferenceCUDAGraphTokenDispatcher( self.num_local_experts, self.local_expert_indices, config=self.config, diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 479b265c6eb..8fd153c4677 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -1,10 +1,11 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. """ -Inference-optimized AlltoAll Token Dispatcher with GPU-resident metadata. +CUDA-graph-compatible token dispatcher for inference. -This implementation keeps tokens_per_expert GPU-resident to enable use of -torch._grouped_mm without host synchronization. +This dispatcher is only used during CUDA-graphed inference iterations. It replaces +AlltoAll with AllGather/ReduceScatter for token exchange, keeping all metadata +GPU-resident to avoid host synchronizations that would break CUDA graph capture. Supports latency-optimized NVLS collectives (multimem all-gather/reduce-scatter) on Hopper+ GPUs with BF16, with automatic fallback to NCCL via superclass methods. @@ -27,17 +28,18 @@ multimem_reduce_scatter, ) -class InferenceAllGatherTokenDispatcher(MoEAllGatherTokenDispatcher): +class InferenceCUDAGraphTokenDispatcher(MoEAllGatherTokenDispatcher): """ - Inference-optimized AllGather token dispatcher. + CUDA-graph-compatible AllGather token dispatcher for inference. - This dispatcher uses AllGather instead of AlltoAll for token exchange, - which can be simpler and more efficient for certain configurations. + Only used during CUDA-graphed inference iterations. 
Swapped in by + MoELayer.set_is_inference_cuda_graphed_iteration() before graph capture + and swapped out after. Key features: - - Simpler communication pattern (AllGather vs AlltoAll) - - GPU-resident metadata for CUDA graph compatibility - - Assumes tp_size == 1 (no tensor parallelism within experts) + - AllGather/ReduceScatter instead of AlltoAll for CUDA graph compatibility + - GPU-resident metadata (no host synchronization) + - NVLS collectives on Hopper+ with automatic NCCL fallback """ def __init__( From d35b58f6042de882f1cf52490a4c9e33ce670271 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Thu, 26 Feb 2026 16:45:05 -0800 Subject: [PATCH 71/92] remove bias act function duplication --- megatron/core/transformer/moe/experts.py | 181 ++++++++--------------- 1 file changed, 62 insertions(+), 119 deletions(-) diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 62d64de1e84..c9ee9438a9e 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -729,6 +729,64 @@ def _apply_bias(intermediate_parallel, bias_parallel, tokens_per_expert, permute .to(intermediate_parallel.dtype) ) + def bias_act_func(self, intermediate_parallel, bias_parallel, permuted_probs): + if self.config.use_te_activation_func: + if bias_parallel is not None: + intermediate_parallel = intermediate_parallel + bias_parallel + intermediate_parallel = self.activation_func(intermediate_parallel) + if permuted_probs is not None: + original_dtype = intermediate_parallel.dtype + intermediate_parallel = intermediate_parallel * permuted_probs + intermediate_parallel = intermediate_parallel.to(original_dtype) + elif self.config.bias_activation_fusion: + if self.activation_func == F.silu and self.config.gated_linear_unit: + # dtype is handled inside the fused kernel + intermediate_parallel = weighted_bias_swiglu_impl( + intermediate_parallel, + bias_parallel, + permuted_probs, + 
self.config.activation_func_fp8_input_store, + ) + elif self.activation_func == quick_gelu and self.config.gated_linear_unit: + intermediate_parallel = weighted_bias_quick_geglu_impl( + intermediate_parallel, + bias_parallel, + permuted_probs, + self.config.activation_func_fp8_input_store, + self.config.glu_linear_offset, + self.config.activation_func_clamp_value, + ) + else: + raise ValueError( + "Only support fusion of swiglu and quick_gelu in TEGroupedMLP." + ) + elif ( + self.activation_func == squared_relu and self.config.use_fused_weighted_squared_relu + ): + assert bias_parallel is None + intermediate_parallel = weighted_squared_relu_impl( + intermediate_parallel, permuted_probs + ) + else: + if self.config.gated_linear_unit: + + def glu(x): + x_glu, x_linear = torch.chunk(x, 2, dim=-1) + if (val := self.config.activation_func_clamp_value) is not None: + x_glu = x_glu.clamp(min=None, max=val) + x_linear = x_linear.clamp(min=-val, max=val) + return self.config.activation_func(x_glu) * ( + x_linear + self.config.glu_linear_offset + ) + + intermediate_parallel = glu(intermediate_parallel) + else: + intermediate_parallel = self.activation_func(intermediate_parallel) + original_dtype = intermediate_parallel.dtype + intermediate_parallel = intermediate_parallel * permuted_probs + intermediate_parallel = intermediate_parallel.to(original_dtype) + return intermediate_parallel + def forward( self, permuted_local_hidden_states: torch.Tensor, @@ -781,74 +839,17 @@ def forward( forced_released_tensors=[permuted_local_hidden_states], ) - def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs): - if self.config.use_te_activation_func: - if bias_parallel is not None: - intermediate_parallel = intermediate_parallel + bias_parallel - intermediate_parallel = self.activation_func(intermediate_parallel) - if permuted_probs is not None: - original_dtype = intermediate_parallel.dtype - intermediate_parallel = intermediate_parallel * permuted_probs - 
intermediate_parallel = intermediate_parallel.to(original_dtype) - elif self.config.bias_activation_fusion: - if self.activation_func == F.silu and self.config.gated_linear_unit: - # dtype is handled inside the fused kernel - intermediate_parallel = weighted_bias_swiglu_impl( - intermediate_parallel, - bias_parallel, - permuted_probs, - self.config.activation_func_fp8_input_store, - ) - elif self.activation_func == quick_gelu and self.config.gated_linear_unit: - intermediate_parallel = weighted_bias_quick_geglu_impl( - intermediate_parallel, - bias_parallel, - permuted_probs, - self.config.activation_func_fp8_input_store, - self.config.glu_linear_offset, - self.config.activation_func_clamp_value, - ) - else: - raise ValueError( - "Only support fusion of swiglu and quick_gelu in TEGroupedMLP." - ) - elif ( - self.activation_func == squared_relu and self.config.use_fused_weighted_squared_relu - ): - assert bias_parallel is None - intermediate_parallel = weighted_squared_relu_impl( - intermediate_parallel, permuted_probs - ) - else: - if self.config.gated_linear_unit: - - def glu(x): - x_glu, x_linear = torch.chunk(x, 2, dim=-1) - if (val := self.config.activation_func_clamp_value) is not None: - x_glu = x_glu.clamp(min=None, max=val) - x_linear = x_linear.clamp(min=-val, max=val) - return self.config.activation_func(x_glu) * ( - x_linear + self.config.glu_linear_offset - ) - - intermediate_parallel = glu(intermediate_parallel) - else: - intermediate_parallel = self.activation_func(intermediate_parallel) - original_dtype = intermediate_parallel.dtype - intermediate_parallel = intermediate_parallel * permuted_probs - intermediate_parallel = intermediate_parallel.to(original_dtype) - return intermediate_parallel + if self.activation_recompute: self.activation_checkpoint = tensor_parallel.CheckpointWithoutOutput() with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output: bias_act_output = self.activation_checkpoint.checkpoint( - bias_act_func, 
fc1_output, bias_parallel, permuted_probs + self.bias_act_func, fc1_output, bias_parallel, permuted_probs ) else: with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output: - bias_act_output = bias_act_func(fc1_output, bias_parallel, permuted_probs) - + bias_act_output = self.bias_act_func(fc1_output, bias_parallel, permuted_probs) output, output_bias = apply_module(self.linear_fc2)(bias_act_output, tokens_per_expert) if self.activation_recompute: self.activation_checkpoint.discard_output_and_register_recompute(output) @@ -1069,64 +1070,6 @@ def _torch_grouped_mm_forward(self, permuted_local_hidden_states, tokens_per_exp permuted_local_hidden_states = permuted_local_hidden_states.to(original_dtype) permuted_probs = torch.ones_like(permuted_probs) - def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs): - if self.config.use_te_activation_func: - if bias_parallel is not None: - intermediate_parallel = intermediate_parallel + bias_parallel - intermediate_parallel = self.activation_func(intermediate_parallel) - if permuted_probs is not None: - original_dtype = intermediate_parallel.dtype - intermediate_parallel = intermediate_parallel * permuted_probs - intermediate_parallel = intermediate_parallel.to(original_dtype) - elif self.config.bias_activation_fusion: - if self.activation_func == F.silu and self.config.gated_linear_unit: - # dtype is handled inside the fused kernel - intermediate_parallel = weighted_bias_swiglu_impl( - intermediate_parallel, - bias_parallel, - permuted_probs, - self.config.activation_func_fp8_input_store, - ) - elif self.activation_func == quick_gelu and self.config.gated_linear_unit: - intermediate_parallel = weighted_bias_quick_geglu_impl( - intermediate_parallel, - bias_parallel, - permuted_probs, - self.config.activation_func_fp8_input_store, - self.config.glu_linear_offset, - self.config.activation_func_clamp_value, - ) - else: - raise ValueError( - "Only support fusion of swiglu and quick_gelu in 
TEGroupedMLP." - ) - elif ( - self.activation_func == squared_relu and self.config.use_fused_weighted_squared_relu - ): - assert bias_parallel is None - intermediate_parallel = weighted_squared_relu_impl( - intermediate_parallel, permuted_probs - ) - else: - if self.config.gated_linear_unit: - - def glu(x): - x_glu, x_linear = torch.chunk(x, 2, dim=-1) - if (val := self.config.activation_func_clamp_value) is not None: - x_glu = x_glu.clamp(min=None, max=val) - x_linear = x_linear.clamp(min=-val, max=val) - return self.config.activation_func(x_glu) * ( - x_linear + self.config.glu_linear_offset - ) - - intermediate_parallel = glu(intermediate_parallel) - else: - intermediate_parallel = self.activation_func(intermediate_parallel) - original_dtype = intermediate_parallel.dtype - intermediate_parallel = intermediate_parallel * permuted_probs - intermediate_parallel = intermediate_parallel.to(original_dtype) - return intermediate_parallel - if permuted_local_hidden_states.nelement() != 0: # Use pre-concatenated weights (built during init/load) @@ -1141,7 +1084,7 @@ def glu(x): # Activation with routing probabilities # intermediate_parallel = self._activation_func_with_probs(fc1_output, permuted_probs) - bias_act_output = bias_act_func(fc1_output, None, permuted_probs) + bias_act_output = self.bias_act_func(fc1_output, None, permuted_probs) # FC2: [total_tokens, ffn_hidden] @ [num_experts, hidden, ffn_hidden] -> [total_tokens, hidden] fc2_output = torch._grouped_mm(bias_act_output, self._fc2_weight.transpose(1, 2), offs=offs) From d22a5e17d40295794539d8364dc86626f78a49a6 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 27 Feb 2026 10:55:31 -0800 Subject: [PATCH 72/92] change the name of cuda graph mixed prefill count to cuda graph mixed prefill request count --- .../core/inference/batch_dimensions_utils.py | 20 +++++++++---------- .../inference/contexts/dynamic_context.py | 11 +++++----- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git 
a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py index 05a3ba6d397..13b6e41f6c6 100644 --- a/megatron/core/inference/batch_dimensions_utils.py +++ b/megatron/core/inference/batch_dimensions_utils.py @@ -133,7 +133,7 @@ def adjust_batch_dims_for_expert_parallelism( strict: bool, decode_only_cuda_graphs: bool, explicit_chunked_prefill: bool, - cuda_graph_mixed_prefill_count: int, + smallest_non_decode_cuda_graph_size: int, ep_group: Optional[torch.distributed.ProcessGroup] = None, ) -> Optional["InferenceBatchDimensions"]: """Adjusted cuda graph batch dimensions for expert parallelism. @@ -203,7 +203,7 @@ def adjust_batch_dims_for_expert_parallelism( # graph while prefill ranks match a coarser mixed graph, which would # produce inconsistent token counts across EP ranks. if is_any_ep_rank_in_non_decode and not strict: - adjusted_token_count = max(adjusted_token_count, cuda_graph_mixed_prefill_count) + adjusted_token_count = max(adjusted_token_count, smallest_non_decode_cuda_graph_size) adjusted_batch_dim = InferenceBatchDimensions( token_count=adjusted_token_count, @@ -303,7 +303,7 @@ def generate_cuda_graph_batch_dimensions_list( tp_size: int, num_cuda_graphs: Optional[int], cuda_graph_max_tokens: int, - cuda_graph_mixed_prefill_count: Optional[int], + cuda_graph_mixed_prefill_request_count: Optional[int], max_requests: int, max_tokens: int, max_sequence_length: int, @@ -339,7 +339,7 @@ def generate_cuda_graph_batch_dimensions_list( tp_size: Tensor parallel size num_cuda_graphs: Number of CUDA graphs to generate cuda_graph_max_tokens: Maximum tokens for CUDA graphs - cuda_graph_mixed_prefill_count: Number of mixed prefill requests for CUDA graphs + cuda_graph_mixed_prefill_request_count: Number of mixed prefill requests for CUDA graphs max_requests: Maximum number of requests max_tokens: Maximum total tokens max_sequence_length: Maximum sequence length @@ -409,8 +409,8 @@ def add_if_valid(token_count: int, 
prefill_req_count: int, decode_req_count: int if num_cuda_graphs is None: cuda_graph_batch_dimensions_list = [] elif ( - not cuda_graph_mixed_prefill_count - or cuda_graph_mixed_prefill_count <= 0 + not cuda_graph_mixed_prefill_request_count + or cuda_graph_mixed_prefill_request_count <= 0 or not use_cuda_graphs_for_non_decode_steps ): # decode only # Use decode-specific token counts for decode-only graphs @@ -426,14 +426,14 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int for size in cuda_graph_prefill_token_counts: add_if_valid( token_count=size, - prefill_req_count=min(cuda_graph_mixed_prefill_count, max_requests), + prefill_req_count=min(cuda_graph_mixed_prefill_request_count, max_requests), decode_req_count=min(size, max_requests) - - min(cuda_graph_mixed_prefill_count, max_requests), + - min(cuda_graph_mixed_prefill_request_count, max_requests), ) # We need to ensure the prefill requests are shorter than the max sequence length, # considering the one decode token is used for prefill request construction prefill_only_minimal_num = max( - cuda_graph_mixed_prefill_count, + cuda_graph_mixed_prefill_request_count, math.ceil(size / max(1, max_sequence_length - 1)), ) if prefill_only_minimal_num < max_requests: @@ -509,7 +509,7 @@ def match_graph_config( decode_only_cuda_graphs=decode_only_cuda_graphs, explicit_chunked_prefill=explicit_chunked_prefill, ep_group=ep_group, - cuda_graph_mixed_prefill_count=smallest_non_decode_cuda_graph_size, + smallest_non_decode_cuda_graph_size=smallest_non_decode_cuda_graph_size, ) if adjusted_batch_dim is None: diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 68afb52635f..83d7015f572 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -535,7 +535,7 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC tp_size=tp_size, 
num_cuda_graphs=inference_config.num_cuda_graphs, cuda_graph_max_tokens=self.max_requests, - cuda_graph_mixed_prefill_count=inference_config.cuda_graph_mixed_prefill_count, + cuda_graph_mixed_prefill_request_count=inference_config.cuda_graph_mixed_prefill_count, max_requests=self.max_requests, max_tokens=self.max_tokens, max_sequence_length=self.max_sequence_length, @@ -543,7 +543,10 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC ) ) - self.cuda_graph_mixed_prefill_count = inference_config.cuda_graph_mixed_prefill_count + self.smallest_non_decode_cuda_graph_size = min( + inference_config.cuda_graph_mixed_prefill_count, self.max_requests + ), + self._using_cuda_graph_this_step = False # Deal with chunked prefill self.enable_chunked_prefill = inference_config.enable_chunked_prefill @@ -1340,9 +1343,7 @@ def initialize_attention_state( best_graph = CUDAGraphBatchDimensionBuilder.match_graph_config( batch_dimensions, self.cuda_graph_batch_dimensions_list, - smallest_non_decode_cuda_graph_size=min( - self.cuda_graph_mixed_prefill_count, self.max_requests - ), + smallest_non_decode_cuda_graph_size=self.smallest_non_decode_cuda_graph_size, strict=self.is_hybrid_model, decode_only_cuda_graphs=(not self.use_cuda_graphs_for_non_decode_steps), explicit_chunked_prefill=self.is_chunked_prefill_enabled() and self.is_hybrid_model, From 1670f7e837a6eee8f57bb8c184827e3f90fafd84 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 27 Feb 2026 11:06:40 -0800 Subject: [PATCH 73/92] cleanup asserts, disable fused tp kernel for moe --- .../core/transformer/transformer_config.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index bfc0bb87767..b6f3b3468fe 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1090,14 +1090,12 @@ def 
__post_init__(self): if self.expert_model_parallel_size > 1 and self.num_moe_experts is None: raise ValueError("num_moe_experts must be non None to use expert-parallel.") - if self.transformer_impl == "inference_optimized" and ( - self.expert_tensor_parallel_size > 1 - ): - raise ValueError( - "Inference-optimized MoE layers does not support expert tensor parallelism." - ) - - if self.transformer_impl == "inference_optimized": + + if self.transformer_impl == "inference_optimized" and self.num_moe_experts is not None: + if self.expert_tensor_parallel_size > 1: + raise ValueError( + "Inference-optimized MoE layers does not support expert tensor parallelism." + ) if self.moe_expert_capacity_factor is not None: raise ValueError( "Inference-optimized MoE layers only support dropless MoE " @@ -1106,14 +1104,13 @@ def __post_init__(self): raise ValueError( "Inference-optimized MoE layers do not support padded routing map for quantization." ) - if self.num_moe_experts is not None and self.moe_router_dtype != "fp32": + if self.moe_router_dtype != "fp32": raise ValueError( "Inference-optimized MoE requires --moe-router-dtype=fp32 " "to avoid costly dtype conversions during decode." ) if ( - self.num_moe_experts is not None - and self.gated_linear_unit + self.gated_linear_unit and self.cuda_graph_impl != "none" ): raise ValueError( @@ -2070,6 +2067,7 @@ def __post_init__(self): "inference_fuse_tp_communication is only supported " "for inference_optimized transformer implementation." ) + assert self.num_moe_experts is None, "--inference-fuse-tp-communication is not supported for MoE models." 
if self.inference_disable_triton_nvls_kernels: assert self.transformer_impl == "inference_optimized", ( From cc3e18fbdfdd858db29bfea3313d18fc80e414be Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 27 Feb 2026 11:09:28 -0800 Subject: [PATCH 74/92] format --- .../torch_symm_triton/__init__.py | 2 +- .../torch_symm_triton/collectives.py | 153 ++++++++++++------ .../communication/torch_symm_triton/utils.py | 5 +- .../inference/contexts/dynamic_context.py | 10 +- .../core/inference/engines/dynamic_engine.py | 6 +- .../text_generation_controller.py | 2 +- megatron/core/inference/utils.py | 9 +- megatron/core/models/backends.py | 17 +- megatron/core/models/gpt/gpt_layer_specs.py | 2 +- megatron/core/models/gpt/moe_module_specs.py | 14 +- .../core/models/mamba/mamba_layer_specs.py | 5 +- megatron/core/parallel_state.py | 11 +- megatron/core/transformer/moe/experts.py | 55 +++---- megatron/core/transformer/moe/moe_layer.py | 18 ++- megatron/core/transformer/moe/router.py | 20 ++- .../moe/token_dispatcher_inference.py | 79 +++++---- .../core/transformer/transformer_config.py | 14 +- megatron/core/utils.py | 2 +- 18 files changed, 253 insertions(+), 171 deletions(-) diff --git a/megatron/core/inference/communication/torch_symm_triton/__init__.py b/megatron/core/inference/communication/torch_symm_triton/__init__.py index 9654e8bc67c..967dc8329f1 100644 --- a/megatron/core/inference/communication/torch_symm_triton/__init__.py +++ b/megatron/core/inference/communication/torch_symm_triton/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
from .collectives import multimem_all_gather, multimem_all_gather_fused, multimem_reduce_scatter -from .utils import are_tensors_nvls_eligible, is_device_nvls_capable from .fused_collectives import fused_multimem_rs_add_norm_ag +from .utils import are_tensors_nvls_eligible, is_device_nvls_capable diff --git a/megatron/core/inference/communication/torch_symm_triton/collectives.py b/megatron/core/inference/communication/torch_symm_triton/collectives.py index 3ba8598467f..1289fd54d60 100644 --- a/megatron/core/inference/communication/torch_symm_triton/collectives.py +++ b/megatron/core/inference/communication/torch_symm_triton/collectives.py @@ -23,12 +23,15 @@ from .barrier import symm_mem_sync from .multimem_asm import ld_128, st_128 -from .utils import get_flat_tid, are_tensors_nvls_eligible, sync_threads +from .utils import are_tensors_nvls_eligible, get_flat_tid, sync_threads # ── Triton kernels ───────────────────────────────────────────────────────── + @triton.jit -def _ag_phase(local_ptr, multicast_ptr, byte_offset, numel, BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE): +def _ag_phase( + local_ptr, multicast_ptr, byte_offset, numel, BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE +): """ Core all-gather phase: load from local memory, multicast-store to symmetric buffer. This is the building block for both single-tensor and fused multi-tensor all-gathers. 
@@ -81,21 +84,33 @@ def _multimem_all_gather_kernel( WORLD_SIZE: tl.constexpr, ): """Single-tensor multicast all-gather kernel.""" - _ag_phase(local_ptr, multicast_ptr, byte_offset, numel, - BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) + _ag_phase( + local_ptr, multicast_ptr, byte_offset, numel, BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE + ) sync_threads() - symm_mem_sync(signal_pad_ptrs, None, RANK, WORLD_SIZE, - hasPreviousMemAccess=True, hasSubsequentMemAccess=True) + symm_mem_sync( + signal_pad_ptrs, + None, + RANK, + WORLD_SIZE, + hasPreviousMemAccess=True, + hasSubsequentMemAccess=True, + ) @triton.jit def _multimem_all_gather_3_kernel( - local_ptr_0, local_ptr_1, local_ptr_2, + local_ptr_0, + local_ptr_1, + local_ptr_2, multicast_ptr, signal_pad_ptrs, - numel_0, byte_offset_0, - numel_1, byte_offset_1, - numel_2, byte_offset_2, + numel_0, + byte_offset_0, + numel_1, + byte_offset_1, + numel_2, + byte_offset_2, BLOCK_SIZE: tl.constexpr, NUMEL_PER_THREAD: tl.constexpr, RANK: tl.constexpr, @@ -106,15 +121,46 @@ def _multimem_all_gather_3_kernel( then synchronizes once, eliminating 2 kernel launches and 2 barriers compared to three separate multimem_all_gather calls. 
""" - _ag_phase(local_ptr_0, multicast_ptr, byte_offset_0, numel_0, - BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) - _ag_phase(local_ptr_1, multicast_ptr, byte_offset_1, numel_1, - BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) - _ag_phase(local_ptr_2, multicast_ptr, byte_offset_2, numel_2, - BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE) + _ag_phase( + local_ptr_0, + multicast_ptr, + byte_offset_0, + numel_0, + BLOCK_SIZE, + NUMEL_PER_THREAD, + RANK, + WORLD_SIZE, + ) + _ag_phase( + local_ptr_1, + multicast_ptr, + byte_offset_1, + numel_1, + BLOCK_SIZE, + NUMEL_PER_THREAD, + RANK, + WORLD_SIZE, + ) + _ag_phase( + local_ptr_2, + multicast_ptr, + byte_offset_2, + numel_2, + BLOCK_SIZE, + NUMEL_PER_THREAD, + RANK, + WORLD_SIZE, + ) sync_threads() - symm_mem_sync(signal_pad_ptrs, None, RANK, WORLD_SIZE, - hasPreviousMemAccess=True, hasSubsequentMemAccess=True) + symm_mem_sync( + signal_pad_ptrs, + None, + RANK, + WORLD_SIZE, + hasPreviousMemAccess=True, + hasSubsequentMemAccess=True, + ) + @triton.jit def _multimem_reduce_scatter_kernel( @@ -162,13 +208,10 @@ def _multimem_reduce_scatter_kernel( block_start += tl.num_programs(axis=0) * BLOCK_SIZE + # ── Python wrappers ───────────────────────────────────────────────────────── -_DEFAULT_KERNEL_CONFIG = { - "max_num_blocks": 128, - "num_warps": 32, - "BLOCK_SIZE": 1024, -} +_DEFAULT_KERNEL_CONFIG = {"max_num_blocks": 128, "num_warps": 32, "BLOCK_SIZE": 1024} def _kernel_launch_config(element_size: int, max_numel: int, world_size: int, **kwargs): @@ -202,13 +245,16 @@ def multimem_all_gather( Input tensor can be a regular torch tensor. """ assert HAVE_TRITON, "Triton is required for multimem all-gather." - assert are_tensors_nvls_eligible(input_tensor), "Input tensor must be 16-byte divisible on Hopper+ for NVLS." 
- assert output_tensor.numel() % input_tensor.numel() == 0 and \ - output_tensor.numel() // input_tensor.numel() == symm_mem_hdl.world_size, \ - "Output numel must be exactly world_size * input numel for all-gather." + assert are_tensors_nvls_eligible( + input_tensor + ), "Input tensor must be 16-byte divisible on Hopper+ for NVLS." + assert ( + output_tensor.numel() % input_tensor.numel() == 0 + and output_tensor.numel() // input_tensor.numel() == symm_mem_hdl.world_size + ), "Output numel must be exactly world_size * input numel for all-gather." numel_per_thread, num_blocks, config = _kernel_launch_config( - input_tensor.element_size(), output_tensor.numel(), symm_mem_hdl.world_size, **kwargs, + input_tensor.element_size(), output_tensor.numel(), symm_mem_hdl.world_size, **kwargs ) _multimem_all_gather_kernel[(num_blocks, 1, 1)]( input_tensor.data_ptr(), @@ -227,9 +273,15 @@ def multimem_all_gather( def multimem_all_gather_fused( - output_0: torch.Tensor, input_0: torch.Tensor, byte_offset_0: int, - output_1: torch.Tensor, input_1: torch.Tensor, byte_offset_1: int, - output_2: torch.Tensor, input_2: torch.Tensor, byte_offset_2: int, + output_0: torch.Tensor, + input_0: torch.Tensor, + byte_offset_0: int, + output_1: torch.Tensor, + input_1: torch.Tensor, + byte_offset_1: int, + output_2: torch.Tensor, + input_2: torch.Tensor, + byte_offset_2: int, symm_mem_hdl: _SymmetricMemory, **kwargs, ) -> None: @@ -240,25 +292,31 @@ def multimem_all_gather_fused( All tensors must share the same symmetric memory handle. """ assert HAVE_TRITON, "Triton is required for multimem all-gather." - assert are_tensors_nvls_eligible(input_0, input_1, input_2), \ - "All input tensors must be 16-byte divisible on Hopper+ for NVLS." + assert are_tensors_nvls_eligible( + input_0, input_1, input_2 + ), "All input tensors must be 16-byte divisible on Hopper+ for NVLS." 
for inp, out in [(input_0, output_0), (input_1, output_1), (input_2, output_2)]: - assert out.numel() % inp.numel() == 0 and \ - out.numel() // inp.numel() == symm_mem_hdl.world_size, \ - "Output numel must be exactly world_size * input numel for all-gather." + assert ( + out.numel() % inp.numel() == 0 and out.numel() // inp.numel() == symm_mem_hdl.world_size + ), "Output numel must be exactly world_size * input numel for all-gather." max_numel = max(output_0.numel(), output_1.numel(), output_2.numel()) numel_per_thread, num_blocks, config = _kernel_launch_config( - input_0.element_size(), max_numel, symm_mem_hdl.world_size, **kwargs, + input_0.element_size(), max_numel, symm_mem_hdl.world_size, **kwargs ) _multimem_all_gather_3_kernel[(num_blocks, 1, 1)]( - input_0.data_ptr(), input_1.data_ptr(), input_2.data_ptr(), + input_0.data_ptr(), + input_1.data_ptr(), + input_2.data_ptr(), symm_mem_hdl.multicast_ptr, symm_mem_hdl.signal_pad_ptrs_dev, - numel_0=output_0.numel(), byte_offset_0=byte_offset_0, - numel_1=output_1.numel(), byte_offset_1=byte_offset_1, - numel_2=output_2.numel(), byte_offset_2=byte_offset_2, + numel_0=output_0.numel(), + byte_offset_0=byte_offset_0, + numel_1=output_1.numel(), + byte_offset_1=byte_offset_1, + numel_2=output_2.numel(), + byte_offset_2=byte_offset_2, BLOCK_SIZE=config["BLOCK_SIZE"], NUMEL_PER_THREAD=numel_per_thread, RANK=symm_mem_hdl.rank, @@ -281,13 +339,16 @@ def multimem_reduce_scatter( assert HAVE_TRITON, "Triton is required for multimem reduce-scatter." assert input_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." assert output_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." - assert are_tensors_nvls_eligible(output_tensor), "Output tensor must be 16-byte divisible on Hopper+ for NVLS." 
- assert input_tensor.numel() % output_tensor.numel() == 0 and \ - input_tensor.numel() // output_tensor.numel() == symm_mem_hdl.world_size, \ - "Input numel must be exactly world_size * output numel for reduce-scatter." + assert are_tensors_nvls_eligible( + output_tensor + ), "Output tensor must be 16-byte divisible on Hopper+ for NVLS." + assert ( + input_tensor.numel() % output_tensor.numel() == 0 + and input_tensor.numel() // output_tensor.numel() == symm_mem_hdl.world_size + ), "Input numel must be exactly world_size * output numel for reduce-scatter." numel_per_thread, num_blocks, config = _kernel_launch_config( - output_tensor.element_size(), input_tensor.numel(), symm_mem_hdl.world_size, **kwargs, + output_tensor.element_size(), input_tensor.numel(), symm_mem_hdl.world_size, **kwargs ) _multimem_reduce_scatter_kernel[(num_blocks, 1, 1)]( output_tensor.data_ptr(), diff --git a/megatron/core/inference/communication/torch_symm_triton/utils.py b/megatron/core/inference/communication/torch_symm_triton/utils.py index 6e1eee74da4..5ace510f2b8 100644 --- a/megatron/core/inference/communication/torch_symm_triton/utils.py +++ b/megatron/core/inference/communication/torch_symm_triton/utils.py @@ -2,9 +2,10 @@ # Adapted from: https://github.com/meta-pytorch/kraken.git -import torch from unittest.mock import MagicMock +import torch + from megatron.core.utils import null_decorator try: @@ -16,7 +17,6 @@ triton.jit = null_decorator - def is_device_nvls_capable(device: torch.device) -> bool: """Check if the device supports NVLS (multicast) collectives. 
Requires CUDA Hopper+ (SM >= 9).""" return device.type == "cuda" and torch.cuda.get_device_properties(device).major >= 9 @@ -37,7 +37,6 @@ def are_tensors_nvls_eligible(*tensors: torch.Tensor) -> bool: ) - @triton.jit def get_tid(): """ diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 83d7015f572..5f81b79eaf9 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -543,10 +543,10 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC ) ) - self.smallest_non_decode_cuda_graph_size = min( - inference_config.cuda_graph_mixed_prefill_count, self.max_requests - ), - + self.smallest_non_decode_cuda_graph_size = ( + min(inference_config.cuda_graph_mixed_prefill_count, self.max_requests), + ) + self._using_cuda_graph_this_step = False # Deal with chunked prefill self.enable_chunked_prefill = inference_config.enable_chunked_prefill @@ -1343,7 +1343,7 @@ def initialize_attention_state( best_graph = CUDAGraphBatchDimensionBuilder.match_graph_config( batch_dimensions, self.cuda_graph_batch_dimensions_list, - smallest_non_decode_cuda_graph_size=self.smallest_non_decode_cuda_graph_size, + smallest_non_decode_cuda_graph_size=self.smallest_non_decode_cuda_graph_size, strict=self.is_hybrid_model, decode_only_cuda_graphs=(not self.use_cuda_graphs_for_non_decode_steps), explicit_chunked_prefill=self.is_chunked_prefill_enabled() and self.is_hybrid_model, diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index c728e07bae6..3f6f9e86e44 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -41,7 +41,11 @@ from megatron.core.inference.text_generation_controllers.text_generation_controller import ( TextGenerationController, ) -from megatron.core.inference.utils import Counter, 
await_process_call, set_is_inference_cuda_graphed_iteration_for_ep_inference +from megatron.core.inference.utils import ( + Counter, + await_process_call, + set_is_inference_cuda_graphed_iteration_for_ep_inference, +) from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.cuda_graphs import delete_cuda_graphs from megatron.core.transformer.enums import CudaGraphScope diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index cc75a7af856..7e50f58e3e6 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -530,7 +530,7 @@ def _dynamic_step_context_init( moe_pad_experts_for_cuda_graph_inference = ( self.model_config.moe_pad_experts_for_cuda_graph_inference ) - is_inference_optimized = self.model_config.transformer_impl == "inference_optimized" + is_inference_optimized = self.model_config.transformer_impl == "inference_optimized" if is_inference_optimized: assert not moe_pad_experts_for_cuda_graph_inference, ( "moe_pad_experts_for_cuda_graph_inference cannot be True when " diff --git a/megatron/core/inference/utils.py b/megatron/core/inference/utils.py index 42ac9577868..5810dda0b26 100644 --- a/megatron/core/inference/utils.py +++ b/megatron/core/inference/utils.py @@ -131,11 +131,12 @@ def set_decode_expert_padding(model, set_to: bool = False, capacity_factor: int router.config.moe_expert_capacity_factor = capacity_factor router.config.moe_pad_expert_input_to_capacity = bool(set_to) + def set_is_inference_cuda_graphed_iteration_for_ep_inference(model, set_to: bool): """ - Toggle CUDA graph compatibility for expert parallel inference. + Toggle CUDA graph compatibility for expert parallel inference. 
This sets a boolean flag in all MoELayers to indicate whether - the current iteration is being captured/executed in a CUDA graph. + the current iteration is being captured/executed in a CUDA graph. This allows the dispatcher to adjust its behavior for CUDA graph compatibility, Args: - set_to: Enable (True) or disable (False) CUDA graph compatibility. @@ -143,10 +144,11 @@ def set_is_inference_cuda_graphed_iteration_for_ep_inference(model, set_to: bool global moe_layer_cache if moe_layer_cache is None: _init_moe_expert_cache(model) - + for moe_layer in moe_layer_cache: moe_layer.set_is_inference_cuda_graphed_iteration(set_to) + def tensor_swap(x, src_idxs, dst_idxs): """ Swap x[src_idxs] and x[dst_idxs] @@ -228,4 +230,3 @@ def shutdown(self): else: asyncio_QueueShutDown = asyncio.QueueShutDown asyncio_Queue = asyncio.Queue - diff --git a/megatron/core/models/backends.py b/megatron/core/models/backends.py index c4e984005b8..ebb979772f0 100644 --- a/megatron/core/models/backends.py +++ b/megatron/core/models/backends.py @@ -3,19 +3,24 @@ import warnings from abc import abstractmethod -from typing import Optional, Protocol, cast, Tuple +from typing import Optional, Protocol, Tuple, cast -from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear -from megatron.core.transformer.dot_product_attention import DotProductAttention from megatron.core.extensions.transformer_engine import ( TEColumnParallelGroupedLinear, - TERowParallelGroupedLinear + TERowParallelGroupedLinear, ) -from megatron.core.utils import is_te_min_version +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.dot_product_attention import DotProductAttention from megatron.core.transformer.mlp import MLPSubmodules, TEActivationFunctionBuilder -from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLPSubmodules, InferenceGroupedMLP +from megatron.core.transformer.moe.experts 
import ( + GroupedMLP, + InferenceGroupedMLP, + SequentialMLP, + TEGroupedMLPSubmodules, +) from megatron.core.transformer.torch_norm import LayerNormBuilder, WrappedTorchNorm from megatron.core.typed_torch import not_none +from megatron.core.utils import is_te_min_version try: import apex # pylint: disable=unused-import diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index 9e711a92fda..aae2d5f3e81 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -573,7 +573,7 @@ def get_gpt_decoder_layer_specs( qk_l2_norm=qk_l2_norm, num_experts=config.num_moe_experts, moe_grouped_gemm=config.moe_grouped_gemm, - moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, ) else: layer_norm_impl = LNImpl diff --git a/megatron/core/models/gpt/moe_module_specs.py b/megatron/core/models/gpt/moe_module_specs.py index e076ed6a5bf..4b0d5640b46 100755 --- a/megatron/core/models/gpt/moe_module_specs.py +++ b/megatron/core/models/gpt/moe_module_specs.py @@ -3,12 +3,16 @@ from typing import Optional from megatron.core.extensions.transformer_engine_spec_provider import TESpecProvider -from megatron.core.models.backends import BackendSpecProvider, LocalSpecProvider, InferenceSpecProvider +from megatron.core.models.backends import ( + BackendSpecProvider, + InferenceSpecProvider, + LocalSpecProvider, +) from megatron.core.transformer.mlp import MLPSubmodules from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules +from megatron.core.transformer.moe.router import InferenceTopKRouter from megatron.core.transformer.moe.shared_experts import SharedExpertMLP from megatron.core.transformer.spec_utils import ModuleSpec -from megatron.core.transformer.moe.router import InferenceTopKRouter def get_moe_module_spec( @@ -39,6 +43,7 @@ def get_moe_module_spec( 
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, ) + def get_moe_module_spec_for_backend( backend: BackendSpecProvider, num_experts: Optional[int] = None, @@ -78,7 +83,6 @@ def get_moe_module_spec_for_backend( return moe_module_spec - def get_inference_optimized_moe_spec() -> ModuleSpec: """MoE module spec for inference-optimized transformer impl. @@ -108,9 +112,7 @@ def get_inference_optimized_moe_spec() -> ModuleSpec: return ModuleSpec( module=MoELayer, submodules=MoESubmodules( - router=InferenceTopKRouter, - experts=experts, - shared_experts=shared_experts, + router=InferenceTopKRouter, experts=experts, shared_experts=shared_experts ), metainfo={"fuse_pre_mlp_layernorm": False}, ) diff --git a/megatron/core/models/mamba/mamba_layer_specs.py b/megatron/core/models/mamba/mamba_layer_specs.py index 791b63ad2eb..044eefa730f 100755 --- a/megatron/core/models/mamba/mamba_layer_specs.py +++ b/megatron/core/models/mamba/mamba_layer_specs.py @@ -8,7 +8,10 @@ TERowParallelLinear, ) from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add -from megatron.core.models.gpt.moe_module_specs import get_inference_optimized_moe_spec, get_moe_module_spec +from megatron.core.models.gpt.moe_module_specs import ( + get_inference_optimized_moe_spec, + get_moe_module_spec, +) from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index b571a357fad..2206eb533e3 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -2017,9 +2017,13 @@ def _set_global_memory_buffer(): def _set_global_symmetric_memory_buffer(): """Initialize global buffer.""" - global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP, _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP - assert _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP is None, "global 
symmetric memory buffer for TP is already initialized" - assert _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP is None, "global symmetric memory buffer for EP is already initialized" + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP, _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP + assert ( + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP is None + ), "global symmetric memory buffer for TP is already initialized" + assert ( + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP is None + ), "global symmetric memory buffer for EP is already initialized" _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = GlobalSymmetricMemoryBuffer( size_in_mb=256, # todo: set from an argument? @@ -2045,6 +2049,7 @@ def get_global_symmetric_memory_buffer_tp(): ), "global symmetric memory buffer is not initialized" return _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP + def get_global_symmetric_memory_buffer_ep(): """Return the global GlobalSymmetricMemoryBuffer object""" assert ( diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index c9ee9438a9e..0d0260e2c76 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -43,7 +43,6 @@ apply_swiglu_sharded_factory, ) from megatron.core.transformer.module import MegatronModule -from megatron.core.utils import is_torch_min_version from megatron.core.transformer.moe import grouped_gemm_util as gg from megatron.core.transformer.moe.moe_utils import ( ProcessGroupCollection, @@ -56,6 +55,7 @@ sharded_state_dict_default, ) from megatron.core.typed_torch import apply_module, not_none +from megatron.core.utils import is_torch_min_version try: import transformer_engine as te # pylint: disable=unused-import @@ -757,12 +757,8 @@ def bias_act_func(self, intermediate_parallel, bias_parallel, permuted_probs): self.config.activation_func_clamp_value, ) else: - raise ValueError( - "Only support fusion of swiglu and quick_gelu in TEGroupedMLP." 
- ) - elif ( - self.activation_func == squared_relu and self.config.use_fused_weighted_squared_relu - ): + raise ValueError("Only support fusion of swiglu and quick_gelu in TEGroupedMLP.") + elif self.activation_func == squared_relu and self.config.use_fused_weighted_squared_relu: assert bias_parallel is None intermediate_parallel = weighted_squared_relu_impl( intermediate_parallel, permuted_probs @@ -839,8 +835,6 @@ def forward( forced_released_tensors=[permuted_local_hidden_states], ) - - if self.activation_recompute: self.activation_checkpoint = tensor_parallel.CheckpointWithoutOutput() with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output: @@ -980,9 +974,7 @@ def _resolve_flashinfer_activation_type(self): return ActivationType.Relu elif func == squared_relu: return ActivationType.Relu2 - raise ValueError( - f"No FlashInfer ActivationType mapping for activation_func={func}" - ) + raise ValueError(f"No FlashInfer ActivationType mapping for activation_func={func}") def set_is_inference_cuda_graphed_iteration(self, set_to: bool): """Toggle CUDA-graphed iteration mode.""" @@ -1008,12 +1000,8 @@ def _build_concatenated_weights(self): fc2_shape = self.linear_fc2.weight0.shape # Create big contiguous tensors - _fc1_weight = torch.empty( - self.num_local_experts, *fc1_shape, device=device, dtype=dtype - ) - _fc2_weight = torch.empty( - self.num_local_experts, *fc2_shape, device=device, dtype=dtype - ) + _fc1_weight = torch.empty(self.num_local_experts, *fc1_shape, device=device, dtype=dtype) + _fc2_weight = torch.empty(self.num_local_experts, *fc2_shape, device=device, dtype=dtype) # Copy existing TE weights into big tensors, then replace with views for i in range(self.num_local_experts): @@ -1026,12 +1014,8 @@ def _build_concatenated_weights(self): delattr(self.linear_fc2, f'weight{i}') # Register views as parameters (checkpoint loads will write into big tensor) - self.linear_fc1.register_parameter( - f'weight{i}', 
torch.nn.Parameter(_fc1_weight[i]) - ) - self.linear_fc2.register_parameter( - f'weight{i}', torch.nn.Parameter(_fc2_weight[i]) - ) + self.linear_fc1.register_parameter(f'weight{i}', torch.nn.Parameter(_fc1_weight[i])) + self.linear_fc2.register_parameter(f'weight{i}', torch.nn.Parameter(_fc2_weight[i])) # Register big tensors as non-persistent buffers (for .to() device movement, not saved) self.register_buffer('_fc1_weight', _fc1_weight, persistent=False) @@ -1054,10 +1038,12 @@ def _flashinfer_forward(self, hidden_states, routing_map, probs): ep_rank=self.ep_group.rank(), )[0] return output, None - - def _torch_grouped_mm_forward(self, permuted_local_hidden_states, tokens_per_expert, permuted_probs): + + def _torch_grouped_mm_forward( + self, permuted_local_hidden_states, tokens_per_expert, permuted_probs + ): permuted_probs = permuted_probs.unsqueeze(-1) - #assert tokens_per_expert.is_cuda, "tokens_per_expert must be on GPU" + # assert tokens_per_expert.is_cuda, "tokens_per_expert must be on GPU" if not tokens_per_expert.is_cuda: tokens_per_expert = tokens_per_expert.to('cuda') @@ -1070,7 +1056,6 @@ def _torch_grouped_mm_forward(self, permuted_local_hidden_states, tokens_per_exp permuted_local_hidden_states = permuted_local_hidden_states.to(original_dtype) permuted_probs = torch.ones_like(permuted_probs) - if permuted_local_hidden_states.nelement() != 0: # Use pre-concatenated weights (built during init/load) # _fc1_weight shape: [num_experts, ffn_hidden * (2 if gated else 1), hidden_size] @@ -1080,14 +1065,18 @@ def _torch_grouped_mm_forward(self, permuted_local_hidden_states, tokens_per_exp offs = tokens_per_expert.cumsum(0).to(torch.int32) # FC1: [total_tokens, hidden] @ [num_experts, ffn_hidden, hidden] -> [total_tokens, ffn_hidden] - fc1_output = torch._grouped_mm(permuted_local_hidden_states, self._fc1_weight.transpose(1, 2), offs=offs) + fc1_output = torch._grouped_mm( + permuted_local_hidden_states, self._fc1_weight.transpose(1, 2), offs=offs + ) # 
Activation with routing probabilities # intermediate_parallel = self._activation_func_with_probs(fc1_output, permuted_probs) bias_act_output = self.bias_act_func(fc1_output, None, permuted_probs) # FC2: [total_tokens, ffn_hidden] @ [num_experts, hidden, ffn_hidden] -> [total_tokens, hidden] - fc2_output = torch._grouped_mm(bias_act_output, self._fc2_weight.transpose(1, 2), offs=offs) + fc2_output = torch._grouped_mm( + bias_act_output, self._fc2_weight.transpose(1, 2), offs=offs + ) else: # No tokens allocated - return empty tensor with correct shape fc2_output = permuted_local_hidden_states @@ -1114,12 +1103,12 @@ def forward( ) elif self._torch_grouped_mm_available: - return self._torch_grouped_mm_forward(permuted_local_hidden_states, tokens_per_expert, permuted_probs) + return self._torch_grouped_mm_forward( + permuted_local_hidden_states, tokens_per_expert, permuted_probs + ) else: return super().forward(permuted_local_hidden_states, tokens_per_expert, permuted_probs) - - class SequentialMLP(MegatronModule): diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 598e2387a13..c479b808a54 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -25,15 +25,17 @@ MoEFlexTokenDispatcher, MoETokenDispatcher, ) +from megatron.core.transformer.moe.token_dispatcher_inference import ( + InferenceCUDAGraphTokenDispatcher, +) from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.typed_torch import apply_module from megatron.core.utils import internal_api -from megatron.core.transformer.moe.token_dispatcher_inference import ( - InferenceCUDAGraphTokenDispatcher, -) + try: import flashinfer + HAVE_FLASHINFER = True except ImportError: HAVE_FLASHINFER = False @@ -42,6 +44,7 @@ try: import flashinfer_cubin import flashinfer_jit_cache + 
HAVE_FLASHINFER_CUBIN_AND_JIT_CACHE = True except ImportError: HAVE_FLASHINFER_CUBIN_AND_JIT_CACHE = False @@ -265,7 +268,9 @@ def __init__( # Inference-optimized mode setup if config.transformer_impl == "inference_optimized": - assert HAVE_FLASHINFER, "flashinfer-python is required for inference-optimized MoE implementation." + assert ( + HAVE_FLASHINFER + ), "flashinfer-python is required for inference-optimized MoE implementation." if not HAVE_FLASHINFER_CUBIN_AND_JIT_CACHE: warnings.warn( "flashinfer-cubin and/or flashinfer-jit-cache not found. " @@ -273,12 +278,10 @@ def __init__( ) self._setup_inference_mode(pg_collection) - # Cudagraph tensor store for resuming the forward pass from the end of the cudagraph. self.cudagraph_tensor_store = MoECudaGraphTensorStore() self.fwd_execution_map = ["route", "expert_compute", "postprocess"] - def _setup_inference_mode(self, pg_collection): """Set up inference-optimized token dispatcher and state. @@ -286,7 +289,7 @@ def _setup_inference_mode(self, pg_collection): Creates an InferenceCUDAGraphTokenDispatcher alongside the standard dispatcher, which is swapped in during CUDA-graphed forward passes. """ - + assert self.config.moe_token_dispatcher_type == "alltoall", ( f"Inference-optimized MoE requires 'alltoall' dispatcher, " f"got '{self.config.moe_token_dispatcher_type}'" @@ -320,7 +323,6 @@ def set_is_inference_cuda_graphed_iteration(self, set_to: bool): self.token_dispatcher = self._saved_token_dispatcher self.shared_expert_overlap = self._saved_shared_expert_overlap - @maybe_skip_or_early_return_by_cudagraph("route") def route(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): """Compute token routing for preprocessing. 
diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index d21f895f229..723033a3015 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -24,6 +24,7 @@ from megatron.core.transformer.moe.router_replay import RouterReplay from megatron.core.transformer.transformer_config import TransformerConfig + class Router(ABC, MegatronModule): """Base Router class""" @@ -719,12 +720,15 @@ class InferenceTopKRouter(TopKRouter): method is @torch.compile()'d and returns dense [num_tokens, topk] tensors instead of sparse [num_tokens, num_experts] for CUDA graph compatibility. - Falls back to the parent TopKRouter.forward() for training or + Falls back to the parent TopKRouter.forward() for training or non-CUDA-graphed inference iterations. """ def __init__( - self, config: TransformerConfig, pg_collection: Optional[ProcessGroupCollection] = None, is_mtp_layer: bool = False, + self, + config: TransformerConfig, + pg_collection: Optional[ProcessGroupCollection] = None, + is_mtp_layer: bool = False, ) -> None: """Initialize the specialized inference top-k router. @@ -751,11 +755,11 @@ def set_is_inference_cuda_graphed_iteration(self, set_to: bool): @torch.compile() def _forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): logits = self.gating(input).squeeze(1) # [num_tokens, 1, num_experts] - + # Share the routing logic with the parent class to avoid code duplication. - # However, we pass dense_output=True to return dense [num_tokens, topk] tensors + # However, we pass dense_output=True to return dense [num_tokens, topk] tensors # instead of sparse [num_tokens, num_experts]. 
- + probs, top_indices = topk_routing_with_score_function( logits, self.topk, @@ -783,8 +787,8 @@ def forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = No - probs: Normalized routing probabilities [num_tokens, topk] - top_indices: Selected expert indices [num_tokens, topk] """ - + if self.training or not self.is_inference_cuda_graphed_iteration: return super().forward(input, padding_mask) - - return self._forward(input, padding_mask) \ No newline at end of file + + return self._forward(input, padding_mask) diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 8fd153c4677..be85ac86bde 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -11,22 +11,24 @@ on Hopper+ GPUs with BF16, with automatic fallback to NCCL via superclass methods. """ -import torch from typing import List, Optional -from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.transformer.moe.token_dispatcher import ( - MoEAllGatherTokenDispatcher, -) -from megatron.core.transformer.transformer_config import TransformerConfig +import torch -from megatron.core.tensor_parallel import gather_from_sequence_parallel_region, reduce_scatter_to_sequence_parallel_region -from megatron.core.parallel_state import get_global_symmetric_memory_buffer_ep from megatron.core.inference.communication.torch_symm_triton import ( are_tensors_nvls_eligible, multimem_all_gather_fused, multimem_reduce_scatter, ) +from megatron.core.parallel_state import get_global_symmetric_memory_buffer_ep +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.tensor_parallel import ( + gather_from_sequence_parallel_region, + reduce_scatter_to_sequence_parallel_region, +) +from megatron.core.transformer.moe.token_dispatcher import MoEAllGatherTokenDispatcher +from 
megatron.core.transformer.transformer_config import TransformerConfig + class InferenceCUDAGraphTokenDispatcher(MoEAllGatherTokenDispatcher): """ @@ -69,10 +71,7 @@ def __init__( self.triton_nvls_kernels_allowed = not self.config.inference_disable_triton_nvls_kernels def _maybe_allocate_ag_buffers( - self, - routing_map: torch.Tensor, - probs: torch.Tensor, - hidden_states: torch.Tensor, + self, routing_map: torch.Tensor, probs: torch.Tensor, hidden_states: torch.Tensor ) -> dict: """ Allocate a single symmetric memory buffer for all-gather outputs of @@ -86,9 +85,12 @@ def _maybe_allocate_ag_buffers( """ _NONE = { "handle": None, - "routing_map": None, "routing_map_offset": 0, - "probs": None, "probs_offset": 0, - "hidden_states": None, "hidden_states_offset": 0, + "routing_map": None, + "routing_map_offset": 0, + "probs": None, + "probs_offset": 0, + "hidden_states": None, + "hidden_states_offset": 0, } local_tokens = probs.size(0) @@ -96,11 +98,13 @@ def _maybe_allocate_ag_buffers( topk = probs.size(-1) hidden_dim = hidden_states.size(-1) - result = get_global_symmetric_memory_buffer_ep().maybe_get_tensors([ - (global_tokens * topk, routing_map.dtype), - (global_tokens * topk, probs.dtype), - (global_tokens * hidden_dim, hidden_states.dtype), - ]) + result = get_global_symmetric_memory_buffer_ep().maybe_get_tensors( + [ + (global_tokens * topk, routing_map.dtype), + (global_tokens * topk, probs.dtype), + (global_tokens * hidden_dim, hidden_states.dtype), + ] + ) if result["handle"] is None: return _NONE @@ -108,9 +112,12 @@ def _maybe_allocate_ag_buffers( (rm_buf, rm_off), (p_buf, p_off), (hs_buf, hs_off) = result["tensors"] return { "handle": result["handle"], - "routing_map": rm_buf, "routing_map_offset": rm_off, - "probs": p_buf, "probs_offset": p_off, - "hidden_states": hs_buf, "hidden_states_offset": hs_off, + "routing_map": rm_buf, + "routing_map_offset": rm_off, + "probs": p_buf, + "probs_offset": p_off, + "hidden_states": hs_buf, + 
"hidden_states_offset": hs_off, } def _maybe_allocate_rs_buffer(self, x: torch.Tensor) -> dict: @@ -132,9 +139,11 @@ def token_dispatch(self, hidden_states, probs): """ if self.ep_size == 1: return hidden_states, probs - + # 1. Check inputs only: if inputs are 16-byte divisible, outputs (world_size * input) are too. - nvls_eligible = self.triton_nvls_kernels_allowed and are_tensors_nvls_eligible(hidden_states, probs, self.routing_map) + nvls_eligible = self.triton_nvls_kernels_allowed and are_tensors_nvls_eligible( + hidden_states, probs, self.routing_map + ) ag_buffers = None if nvls_eligible: @@ -157,7 +166,9 @@ def token_dispatch(self, hidden_states, probs): # Fused NVLS all-gather: single kernel launch + single barrier for all 3 tensors multimem_all_gather_fused( - ag_buffers["routing_map"].view(torch.bfloat16), # .view does not change the underlying data + ag_buffers["routing_map"].view( + torch.bfloat16 + ), # .view does not change the underlying data self.routing_map.view(torch.bfloat16), ag_buffers["routing_map_offset"], ag_buffers["probs"].view(torch.bfloat16), @@ -168,9 +179,13 @@ def token_dispatch(self, hidden_states, probs): ag_buffers["hidden_states_offset"], ag_buffers["handle"], ) - self.routing_map = ag_buffers["routing_map"].view(routing_map_dtype).view(global_tokens, topk) + self.routing_map = ( + ag_buffers["routing_map"].view(routing_map_dtype).view(global_tokens, topk) + ) probs = ag_buffers["probs"].view(probs_dtype).view(global_tokens, topk) - hidden_states = ag_buffers["hidden_states"].view(hidden_dtype).view(global_tokens, hidden_dim) + hidden_states = ( + ag_buffers["hidden_states"].view(hidden_dtype).view(global_tokens, hidden_dim) + ) else: # Fallback to NCCL for all tensors with torch.no_grad(): @@ -184,7 +199,6 @@ def token_dispatch(self, hidden_states, probs): return hidden_states, probs - def dispatch_postprocess(self, hidden_states, probs): """Pass-through: returns unpermuted inputs and routing_map for InferenceGroupedMLP.""" 
return hidden_states, self.routing_map, probs @@ -213,9 +227,7 @@ def token_combine(self, hidden_states): # since if the smaller output is 16-byte divisible, the input is too. output_shape = list(hidden_states.size()) output_shape[0] = hidden_states.size(0) // self.ep_size - output = torch.empty( - output_shape, dtype=hidden_states.dtype, device=hidden_states.device - ) + output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device) # Check output only: if output is 16-byte divisible, input (world_size * output) is too. nvls_eligible = self.triton_nvls_kernels_allowed and are_tensors_nvls_eligible(output) @@ -234,9 +246,8 @@ def token_combine(self, hidden_states): multimem_reduce_scatter(output, rs_buffer["tensor"], rs_buffer["handle"]) return output else: - # Fallback to NCCL + # Fallback to NCCL hidden_states = reduce_scatter_to_sequence_parallel_region( hidden_states, group=self.tp_ep_group ) return hidden_states - diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index b6f3b3468fe..4d26b206c94 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1090,16 +1090,13 @@ def __post_init__(self): if self.expert_model_parallel_size > 1 and self.num_moe_experts is None: raise ValueError("num_moe_experts must be non None to use expert-parallel.") - if self.transformer_impl == "inference_optimized" and self.num_moe_experts is not None: if self.expert_tensor_parallel_size > 1: raise ValueError( "Inference-optimized MoE layers does not support expert tensor parallelism." 
) if self.moe_expert_capacity_factor is not None: - raise ValueError( - "Inference-optimized MoE layers only support dropless MoE " - ) + raise ValueError("Inference-optimized MoE layers only support dropless MoE ") if self.moe_router_padding_for_quantization: raise ValueError( "Inference-optimized MoE layers do not support padded routing map for quantization." @@ -1109,10 +1106,7 @@ def __post_init__(self): "Inference-optimized MoE requires --moe-router-dtype=fp32 " "to avoid costly dtype conversions during decode." ) - if ( - self.gated_linear_unit - and self.cuda_graph_impl != "none" - ): + if self.gated_linear_unit and self.cuda_graph_impl != "none": raise ValueError( "Inference-optimized MoE does not yet support CUDA graphs with gated " "linear units (SwiGLU/GeGLU) due to differences in weight layouts " @@ -2067,7 +2061,9 @@ def __post_init__(self): "inference_fuse_tp_communication is only supported " "for inference_optimized transformer implementation." ) - assert self.num_moe_experts is None, "--inference-fuse-tp-communication is not supported for MoE models." + assert ( + self.num_moe_experts is None + ), "--inference-fuse-tp-communication is not supported for MoE models." if self.inference_disable_triton_nvls_kernels: assert self.transformer_impl == "inference_optimized", ( diff --git a/megatron/core/utils.py b/megatron/core/utils.py index ed31b77ba04..e558a1e785f 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -733,7 +733,7 @@ def maybe_get_tensors(self, tensor_specs, alignment=16): Pack multiple tensors contiguously in the symmetric buffer with alignment. Each tensor's starting offset is aligned to `alignment` bytes (default 16 - for 128-bit multimem access). + for 128-bit multimem access). Args: tensor_specs: list of (numel, dtype) tuples. 
From 902dc6963cf13a83c3766615836b7db81b2a7af8 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 27 Feb 2026 11:15:09 -0800 Subject: [PATCH 75/92] fix linting issues --- .../communication/torch_symm_triton/utils.py | 3 ++- megatron/core/transformer/moe/experts.py | 7 +++++-- megatron/core/transformer/moe/moe_layer.py | 9 +++++---- megatron/core/transformer/moe/router.py | 20 ++++++++++++------- .../moe/token_dispatcher_inference.py | 6 ++++-- .../core/transformer/transformer_config.py | 6 ++++-- 6 files changed, 33 insertions(+), 18 deletions(-) diff --git a/megatron/core/inference/communication/torch_symm_triton/utils.py b/megatron/core/inference/communication/torch_symm_triton/utils.py index 5ace510f2b8..3cc6dd8dcc0 100644 --- a/megatron/core/inference/communication/torch_symm_triton/utils.py +++ b/megatron/core/inference/communication/torch_symm_triton/utils.py @@ -18,7 +18,8 @@ def is_device_nvls_capable(device: torch.device) -> bool: - """Check if the device supports NVLS (multicast) collectives. Requires CUDA Hopper+ (SM >= 9).""" + """Check if the device supports NVLS (multicast) collectives. + Requires CUDA Hopper+ (SM >= 9).""" return device.type == "cuda" and torch.cuda.get_device_properties(device).major >= 9 diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 0d0260e2c76..dc25cfbaa7f 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -730,6 +730,9 @@ def _apply_bias(intermediate_parallel, bias_parallel, tokens_per_expert, permute ) def bias_act_func(self, intermediate_parallel, bias_parallel, permuted_probs): + """ + Applies bias and activation function to the output of linear_fc1. 
+ """ if self.config.use_te_activation_func: if bias_parallel is not None: intermediate_parallel = intermediate_parallel + bias_parallel @@ -1070,10 +1073,10 @@ def _torch_grouped_mm_forward( ) # Activation with routing probabilities - # intermediate_parallel = self._activation_func_with_probs(fc1_output, permuted_probs) bias_act_output = self.bias_act_func(fc1_output, None, permuted_probs) - # FC2: [total_tokens, ffn_hidden] @ [num_experts, hidden, ffn_hidden] -> [total_tokens, hidden] + # FC2: [total_tokens, ffn_hidden] @ [num_experts, hidden, ffn_hidden] + # -> [total_tokens, hidden] fc2_output = torch._grouped_mm( bias_act_output, self._fc2_weight.transpose(1, 2), offs=offs ) diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index c479b808a54..8c78bc58e35 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -34,7 +34,7 @@ from megatron.core.utils import internal_api try: - import flashinfer + import flashinfer # pylint: disable=unused-import HAVE_FLASHINFER = True except ImportError: @@ -42,8 +42,8 @@ if HAVE_FLASHINFER: try: - import flashinfer_cubin - import flashinfer_jit_cache + import flashinfer_cubin # pylint: disable=unused-import + import flashinfer_jit_cache # pylint: disable=unused-import HAVE_FLASHINFER_CUBIN_AND_JIT_CACHE = True except ImportError: @@ -274,7 +274,8 @@ def __init__( if not HAVE_FLASHINFER_CUBIN_AND_JIT_CACHE: warnings.warn( "flashinfer-cubin and/or flashinfer-jit-cache not found. " - "The FlashInfer cutlass kernel will be JIT compiled, which may take a long time." + "The FlashInfer cutlass kernel will be JIT compiled," + "which may take a long time." 
) self._setup_inference_mode(pg_collection) diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index 723033a3015..9d663e8c430 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -737,19 +737,25 @@ def __init__( pg_collection (ProcessGroupCollection, optional): Process groups for MoE operations. """ # Enforce constraints before calling super().__init__ - assert ( - config.moe_router_num_groups is None - ), f"InferenceTopKRouter requires moe_router_num_groups=None, got {config.moe_router_num_groups}" - assert config.moe_router_score_function in [ - "sigmoid", - "softmax", - ], f"InferenceTopKRouter requires moe_router_score_function in ['sigmoid', 'softmax'], got '{config.moe_router_score_function}'" + assert config.moe_router_num_groups is None, ( + f"InferenceTopKRouter requires moe_router_num_groups=None, " + f"got {config.moe_router_num_groups}" + ) + assert config.moe_router_score_function in ["sigmoid", "softmax"], ( + f"InferenceTopKRouter requires moe_router_score_function in " + f"['sigmoid', 'softmax'], got '{config.moe_router_score_function}'" + ) super().__init__(config=config, pg_collection=pg_collection) self.is_inference_cuda_graphed_iteration = False def set_is_inference_cuda_graphed_iteration(self, set_to: bool): + """Set whether the current iteration is being CUDA graphed. + + Args: + set_to: If True, the router will use CUDA graph-compatible operations. + """ self.is_inference_cuda_graphed_iteration = set_to @torch.compile() diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index be85ac86bde..0da618a5ba2 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -140,14 +140,16 @@ def token_dispatch(self, hidden_states, probs): if self.ep_size == 1: return hidden_states, probs - # 1. 
Check inputs only: if inputs are 16-byte divisible, outputs (world_size * input) are too. + # 1. Check inputs only: if inputs are 16-byte divisible, + # outputs (world_size * input) are too. nvls_eligible = self.triton_nvls_kernels_allowed and are_tensors_nvls_eligible( hidden_states, probs, self.routing_map ) ag_buffers = None if nvls_eligible: - # 2. Now attempt to allocate symmetric memory buffers for all-gather outputs. If allocation fails, fallback to NCCL. + # 2. Now attempt to allocate symmetric memory buffers for + # all-gather outputs. If allocation fails, fallback to NCCL. ag_buffers = self._maybe_allocate_ag_buffers(self.routing_map, probs, hidden_states) # 3. Can use NVLS if eligible and buffers allocated successfully (handle is not None) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 4d26b206c94..997751a53bc 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -862,7 +862,8 @@ class TransformerConfig(ModelParallelConfig): """ If true, disables the use of Triton NVLS kernels during inference. """ inference_disable_torch_grouped_mm: bool = False - """ If true, disables torch._grouped_mm in InferenceGroupedMLP, falling back to TE GroupedGEMM. """ + """ If true, disables torch._grouped_mm in InferenceGroupedMLP, + falling back to TE GroupedGEMM. """ mrope_section: Optional[List[int]] = None """ Multimodal rope section is for channel dimension of temporal, height and width @@ -1099,7 +1100,8 @@ def __post_init__(self): raise ValueError("Inference-optimized MoE layers only support dropless MoE ") if self.moe_router_padding_for_quantization: raise ValueError( - "Inference-optimized MoE layers do not support padded routing map for quantization." + "Inference-optimized MoE layers do not support padded " + "routing map for quantization." 
) if self.moe_router_dtype != "fp32": raise ValueError( From 2547534b75ba1dc6b7114605ed695fba0ae1e603 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 27 Feb 2026 11:46:11 -0800 Subject: [PATCH 76/92] linting --- .../core/inference/engines/dynamic_engine.py | 7 ++-- megatron/core/inference/utils.py | 27 ++++++++++----- megatron/core/transformer/moe/experts.py | 13 ++++--- megatron/core/transformer/moe/moe_layer.py | 34 +++++++++++++------ megatron/core/transformer/moe/router.py | 12 +++---- .../moe/token_dispatcher_inference.py | 4 +-- 6 files changed, 63 insertions(+), 34 deletions(-) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 3f6f9e86e44..d06ab5c53d2 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -44,7 +44,8 @@ from megatron.core.inference.utils import ( Counter, await_process_call, - set_is_inference_cuda_graphed_iteration_for_ep_inference, + set_inference_cuda_graphed_iteration_for_ep_inference, + unset_inference_cuda_graphed_iteration_for_ep_inference, ) from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.cuda_graphs import delete_cuda_graphs @@ -303,7 +304,7 @@ def create_cuda_graphs(self, reset_context: bool = True): ) if is_inference_optimized_ep: unwrapped_model = controller.inference_wrapped_model.model - set_is_inference_cuda_graphed_iteration_for_ep_inference(unwrapped_model, True) + set_inference_cuda_graphed_iteration_for_ep_inference(unwrapped_model) tbar = enumerate(context.cuda_graph_batch_dimensions_list) if HAVE_TQDM: @@ -334,7 +335,7 @@ def create_cuda_graphs(self, reset_context: bool = True): # Disable inference dispatcher after graph capture if is_inference_optimized_ep: - set_is_inference_cuda_graphed_iteration_for_ep_inference(unwrapped_model, False) + unset_inference_cuda_graphed_iteration_for_ep_inference(unwrapped_model) # Memory 
usage. time_end = time.time() diff --git a/megatron/core/inference/utils.py b/megatron/core/inference/utils.py index 5810dda0b26..770a592e9c8 100644 --- a/megatron/core/inference/utils.py +++ b/megatron/core/inference/utils.py @@ -132,21 +132,32 @@ def set_decode_expert_padding(model, set_to: bool = False, capacity_factor: int router.config.moe_pad_expert_input_to_capacity = bool(set_to) -def set_is_inference_cuda_graphed_iteration_for_ep_inference(model, set_to: bool): +def set_inference_cuda_graphed_iteration_for_ep_inference(model): + """Enable CUDA graph compatibility for expert parallel inference. + + Sets a flag in all MoELayers indicating the current iteration is being + captured/executed in a CUDA graph. This allows the dispatcher to adjust + its behavior for CUDA graph compatibility. """ - Toggle CUDA graph compatibility for expert parallel inference. - This sets a boolean flag in all MoELayers to indicate whether - the current iteration is being captured/executed in a CUDA graph. - This allows the dispatcher to adjust its behavior for CUDA graph compatibility, - Args: - - set_to: Enable (True) or disable (False) CUDA graph compatibility. + global moe_layer_cache + if moe_layer_cache is None: + _init_moe_expert_cache(model) + + for moe_layer in moe_layer_cache: + moe_layer.set_inference_cuda_graphed_iteration() + + +def unset_inference_cuda_graphed_iteration_for_ep_inference(model): + """Disable CUDA graph compatibility for expert parallel inference. + + Clears the flag in all MoELayers, restoring standard dispatcher behavior. 
""" global moe_layer_cache if moe_layer_cache is None: _init_moe_expert_cache(model) for moe_layer in moe_layer_cache: - moe_layer.set_is_inference_cuda_graphed_iteration(set_to) + moe_layer.unset_inference_cuda_graphed_iteration() def tensor_swap(x, src_idxs, dst_idxs): diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index dc25cfbaa7f..ee287e4af11 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -979,9 +979,13 @@ def _resolve_flashinfer_activation_type(self): return ActivationType.Relu2 raise ValueError(f"No FlashInfer ActivationType mapping for activation_func={func}") - def set_is_inference_cuda_graphed_iteration(self, set_to: bool): - """Toggle CUDA-graphed iteration mode.""" - self.is_inference_cuda_graphed_iteration = set_to + def set_inference_cuda_graphed_iteration(self): + """Enable CUDA-graphed iteration mode.""" + self.is_inference_cuda_graphed_iteration = True + + def unset_inference_cuda_graphed_iteration(self): + """Disable CUDA-graphed iteration mode.""" + self.is_inference_cuda_graphed_iteration = False def _build_concatenated_weights(self): """Create big contiguous weight tensors with per-expert views for checkpoint compatibility. 
@@ -1067,7 +1071,8 @@ def _torch_grouped_mm_forward( # offs[i] = end index of expert i's tokens offs = tokens_per_expert.cumsum(0).to(torch.int32) - # FC1: [total_tokens, hidden] @ [num_experts, ffn_hidden, hidden] -> [total_tokens, ffn_hidden] + # FC1: [total_tokens, hidden] @ [num_experts, ffn_hidden, hidden] + # -> [total_tokens, ffn_hidden] fc1_output = torch._grouped_mm( permuted_local_hidden_states, self._fc1_weight.transpose(1, 2), offs=offs ) diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 8c78bc58e35..09c5e007bae 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -303,24 +303,36 @@ def _setup_inference_mode(self, pg_collection): pg_collection=pg_collection, ) - def set_is_inference_cuda_graphed_iteration(self, set_to: bool): - """Toggle CUDA-graphed iteration mode on this layer, its router, and its experts. + def set_inference_cuda_graphed_iteration(self): + """Enable CUDA-graphed iteration mode on this layer, its router, and its experts. - When enabled, swaps in the inference-optimized token dispatcher and disables - shared expert overlap. When disabled, restores the standard dispatcher. + Swaps in the inference-optimized token dispatcher and disables + shared expert overlap. 
""" - self.is_inference_cuda_graphed_iteration = set_to - if hasattr(self.router, "set_is_inference_cuda_graphed_iteration"): - self.router.set_is_inference_cuda_graphed_iteration(set_to) - if hasattr(self.experts, "set_is_inference_cuda_graphed_iteration"): - self.experts.set_is_inference_cuda_graphed_iteration(set_to) + self.is_inference_cuda_graphed_iteration = True + if hasattr(self.router, "set_inference_cuda_graphed_iteration"): + self.router.set_inference_cuda_graphed_iteration() + if hasattr(self.experts, "set_inference_cuda_graphed_iteration"): + self.experts.set_inference_cuda_graphed_iteration() - if set_to and self._inference_token_dispatcher is not None: + if self._inference_token_dispatcher is not None: self._saved_token_dispatcher = self.token_dispatcher self.token_dispatcher = self._inference_token_dispatcher self._saved_shared_expert_overlap = self.shared_expert_overlap self.shared_expert_overlap = False - elif not set_to and hasattr(self, "_saved_token_dispatcher"): + + def unset_inference_cuda_graphed_iteration(self): + """Disable CUDA-graphed iteration mode on this layer, its router, and its experts. + + Restores the standard token dispatcher and shared expert overlap setting. 
+ """ + self.is_inference_cuda_graphed_iteration = False + if hasattr(self.router, "unset_inference_cuda_graphed_iteration"): + self.router.unset_inference_cuda_graphed_iteration() + if hasattr(self.experts, "unset_inference_cuda_graphed_iteration"): + self.experts.unset_inference_cuda_graphed_iteration() + + if hasattr(self, "_saved_token_dispatcher"): self.token_dispatcher = self._saved_token_dispatcher self.shared_expert_overlap = self._saved_shared_expert_overlap diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index 9d663e8c430..a5a24b149ec 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -750,13 +750,13 @@ def __init__( self.is_inference_cuda_graphed_iteration = False - def set_is_inference_cuda_graphed_iteration(self, set_to: bool): - """Set whether the current iteration is being CUDA graphed. + def set_inference_cuda_graphed_iteration(self): + """Enable CUDA graph-compatible operations for the router.""" + self.is_inference_cuda_graphed_iteration = True - Args: - set_to: If True, the router will use CUDA graph-compatible operations. - """ - self.is_inference_cuda_graphed_iteration = set_to + def unset_inference_cuda_graphed_iteration(self): + """Disable CUDA graph-compatible operations for the router.""" + self.is_inference_cuda_graphed_iteration = False @torch.compile() def _forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 0da618a5ba2..31f03e30e65 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -35,8 +35,8 @@ class InferenceCUDAGraphTokenDispatcher(MoEAllGatherTokenDispatcher): CUDA-graph-compatible AllGather token dispatcher for inference. Only used during CUDA-graphed inference iterations. 
Swapped in by - MoELayer.set_is_inference_cuda_graphed_iteration() before graph capture - and swapped out after. + MoELayer.set_inference_cuda_graphed_iteration() before graph capture + and swapped out by MoELayer.unset_inference_cuda_graphed_iteration() after. Key features: - AllGather/ReduceScatter instead of AlltoAll for CUDA graph compatibility From 165d6d4534ef79ca3cc65ca0321694754bd0f257 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 27 Feb 2026 11:50:55 -0800 Subject: [PATCH 77/92] refactor --- megatron/core/transformer/moe/experts.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index ee287e4af11..ee3f6cddd2c 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -1071,8 +1071,6 @@ def _torch_grouped_mm_forward( # offs[i] = end index of expert i's tokens offs = tokens_per_expert.cumsum(0).to(torch.int32) - # FC1: [total_tokens, hidden] @ [num_experts, ffn_hidden, hidden] - # -> [total_tokens, ffn_hidden] fc1_output = torch._grouped_mm( permuted_local_hidden_states, self._fc1_weight.transpose(1, 2), offs=offs ) @@ -1080,8 +1078,6 @@ def _torch_grouped_mm_forward( # Activation with routing probabilities bias_act_output = self.bias_act_func(fc1_output, None, permuted_probs) - # FC2: [total_tokens, ffn_hidden] @ [num_experts, hidden, ffn_hidden] - # -> [total_tokens, hidden] fc2_output = torch._grouped_mm( bias_act_output, self._fc2_weight.transpose(1, 2), offs=offs ) From 14c8a397f60357c20bf966f621c8e237de2f63b8 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 27 Feb 2026 12:28:45 -0800 Subject: [PATCH 78/92] fix unit test failures --- tests/unit_tests/inference/test_batch_dimension_utils.py | 2 +- tests/unit_tests/models/test_mamba_moe_model.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit_tests/inference/test_batch_dimension_utils.py 
b/tests/unit_tests/inference/test_batch_dimension_utils.py index 5bcb29a0f24..6d0ed756fc6 100644 --- a/tests/unit_tests/inference/test_batch_dimension_utils.py +++ b/tests/unit_tests/inference/test_batch_dimension_utils.py @@ -31,7 +31,7 @@ def _generate_graphs(num_cuda_graphs, use_non_decode=True): tp_size=TP_SIZE, num_cuda_graphs=num_cuda_graphs, cuda_graph_max_tokens=MAX_REQUESTS, - cuda_graph_mixed_prefill_count=MIXED_PREFILL_COUNT, + cuda_graph_mixed_prefill_request_count=min(MIXED_PREFILL_COUNT, MAX_REQUESTS), max_requests=MAX_REQUESTS, max_tokens=MAX_TOKENS, max_sequence_length=MAX_SEQ_LEN, diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py index 93d66be0669..98c6ac63e0e 100644 --- a/tests/unit_tests/models/test_mamba_moe_model.py +++ b/tests/unit_tests/models/test_mamba_moe_model.py @@ -282,6 +282,8 @@ "offload_modules": [], "hybrid_context_parallel": False, "max_seqlen_per_dp_cp_rank": None, + "inference_disable_torch_grouped_mm": False, + "inference_disable_triton_nvls_kernels": False, } # Fields to ignore entirely (ephemeral, environment-specific, very large). SKIP_FIELDS = set() From 2fc86b20aeb4b5a12c616fadf58a30a58fc4c643 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 27 Feb 2026 16:26:00 -0800 Subject: [PATCH 79/92] unit test for inference top-k router --- .../inference/test_moe_inference.py | 222 ++++++++++++++++++ 1 file changed, 222 insertions(+) create mode 100644 tests/unit_tests/inference/test_moe_inference.py diff --git a/tests/unit_tests/inference/test_moe_inference.py b/tests/unit_tests/inference/test_moe_inference.py new file mode 100644 index 00000000000..45190eb914a --- /dev/null +++ b/tests/unit_tests/inference/test_moe_inference.py @@ -0,0 +1,222 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Unit tests for inference-optimized MoE components. 
+ +Config is modeled after nanov3 (Nemotron-6 3B Hybrid MoE) with smaller +dimensions for fast unit test execution: +- squared_relu activation (not swiglu/gated) +- sigmoid router score function with expert bias +- topk=6, topk_scaling_factor=2.5 +- shared experts +""" + +import pytest +import torch + +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import is_te_min_version, is_torch_min_version +from megatron.training.initialize import _set_random_seed +from megatron.core.activations import squared_relu +from tests.unit_tests.test_utilities import Utils + +# Reusable skip decorators +requires_te = pytest.mark.skipif( + not is_te_min_version("1.7.0.dev0"), + reason="Requires transformer-engine >= 1.7.0", +) +requires_torch_grouped_mm = pytest.mark.skipif( + not is_torch_min_version("2.10") or not hasattr(torch, '_grouped_mm'), + reason="Requires PyTorch >= 2.10 with torch._grouped_mm", +) + +# ────────────────────────────────────────────────────────────────────── +# NanoV3-like config (scaled down from 2688→128 hidden, 128→8 experts) +# ────────────────────────────────────────────────────────────────────── + +NANOV3_BASE = dict( + num_layers=1, + hidden_size=128, + ffn_hidden_size=128, + num_attention_heads=4, + num_query_groups=2, + num_moe_experts=8, + moe_ffn_hidden_size=128, + moe_router_topk=6, + moe_router_score_function="sigmoid", + moe_router_enable_expert_bias=True, + moe_router_topk_scaling_factor=2.5, + moe_shared_expert_intermediate_size=256, + moe_router_dtype='fp32', + moe_shared_expert_overlap=False, + moe_grouped_gemm=True, + moe_token_dispatcher_type="alltoall", + moe_aux_loss_coeff=0.01, + activation_func=squared_relu, + normalization="RMSNorm", + add_bias_linear=False, + bf16=True, + params_dtype=torch.bfloat16 +) + + +def _make_base_config(**overrides): + """Create a TransformerConfig with nanov3-like defaults.""" + params = {**NANOV3_BASE, **overrides} + return TransformerConfig(**params) + + 
+# ────────────────────────────────────────────────────────────────────── +# InferenceTopKRouter +# ────────────────────────────────────────────────────────────────────── + + +class TestInferenceTopKRouter: + + @classmethod + def setup_class(cls): + Utils.initialize_model_parallel(1, 1) + _set_random_seed(seed_=123, data_parallel_random_init=False) + + @classmethod + def teardown_class(cls): + Utils.destroy_model_parallel() + + def _make_router(self, **config_overrides): + from megatron.core.transformer.moe.moe_utils import get_default_pg_collection + from megatron.core.transformer.moe.router import InferenceTopKRouter + + config = _make_base_config(**config_overrides) + return InferenceTopKRouter(config=config, pg_collection=get_default_pg_collection()).cuda().to(torch.bfloat16) + + def test_init_rejects_num_groups(self): + """InferenceTopKRouter requires moe_router_num_groups=None.""" + with pytest.raises(AssertionError, match="moe_router_num_groups"): + self._make_router(moe_router_num_groups=2) + + @pytest.mark.parametrize("score_fn", ["none", "invalid"]) + def test_init_rejects_unsupported_score_function(self, score_fn): + """InferenceTopKRouter requires sigmoid or softmax score function.""" + with pytest.raises(AssertionError, match="moe_router_score_function"): + self._make_router(moe_router_score_function=score_fn, + moe_router_enable_expert_bias=False) + + @pytest.mark.parametrize("score_fn", ["sigmoid", "softmax"]) + def test_init_accepts_valid_score_function(self, score_fn): + """InferenceTopKRouter accepts sigmoid and softmax.""" + # Expert bias only valid with sigmoid; disable it for softmax + enable_bias = score_fn == "sigmoid" + router = self._make_router( + moe_router_score_function=score_fn, + moe_router_enable_expert_bias=enable_bias, + ) + assert router is not None + + def test_set_unset_inference_mode(self): + """Toggle is_inference_cuda_graphed_iteration flag.""" + router = self._make_router() + assert not 
router.is_inference_cuda_graphed_iteration + + router.set_inference_cuda_graphed_iteration() + assert router.is_inference_cuda_graphed_iteration + + router.unset_inference_cuda_graphed_iteration() + assert not router.is_inference_cuda_graphed_iteration + + def test_training_mode_forward_returns_sparse(self): + """In training mode, forward delegates to parent and returns sparse tensors.""" + router = self._make_router() + router.train() + num_tokens = 16 + num_experts = NANOV3_BASE["num_moe_experts"] + + input_tensor = torch.randn( + num_tokens, NANOV3_BASE["hidden_size"], device="cuda", dtype=torch.bfloat16 + ) + probs, routing_map = router(input_tensor) + + # Parent TopKRouter returns [num_tokens, num_experts] sparse routing_map + assert routing_map.shape == (num_tokens, num_experts) + + def test_inference_vs_training_selects_same_experts(self): + """Inference and training modes should select the same top-k experts.""" + router = self._make_router() + num_tokens = 16 + topk = NANOV3_BASE["moe_router_topk"] + + input_tensor = torch.randn( + num_tokens, NANOV3_BASE["hidden_size"], device="cuda", dtype=torch.bfloat16 + ) + + # Training mode: get routing_map (sparse) and extract top expert indices + router.train() + _, routing_map = router(input_tensor.clone()) + # routing_map is [num_tokens, num_experts] bool + training_experts = set() + for i in range(num_tokens): + experts_for_token = routing_map[i].nonzero(as_tuple=True)[0] + for e in experts_for_token: + training_experts.add((i, e.item())) + + # Inference mode: get top_indices (dense) + router.eval() + router.set_inference_cuda_graphed_iteration() + _, top_indices = router(input_tensor.clone()) + + inference_experts = set() + for i in range(num_tokens): + for k in range(topk): + inference_experts.add((i, top_indices[i, k].item())) + + # Same expert selections + assert training_experts == inference_experts + + def test_cuda_graph_capture_and_replay(self): + """Router forward can be captured in a CUDA graph and 
replayed.""" + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + router = self._make_router() + router.eval() + router.set_inference_cuda_graphed_iteration() + + num_tokens = 16 + hidden_size = NANOV3_BASE["hidden_size"] + + # Static input buffer for CUDA graph (seeded for reproducibility) + static_input = torch.randn( + num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16 + ) + + # Warmup (required before CUDA graph capture) + with torch.no_grad(): + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + router(static_input) + torch.cuda.current_stream().wait_stream(s) + + # Capture + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + static_probs, static_indices = router(static_input) + + router.unset_inference_cuda_graphed_iteration() + + # Fill indices with -1, replay, check that the graph overwrote them + static_indices.fill_(-1) + static_input.copy_(torch.randn_like(static_input)) + graph.replay() + assert (static_indices != -1).all(), "Graph replay should overwrite all expert indices" + + expected_indices = [ + [2, 6, 4, 5, 3, 7], [4, 1, 3, 2, 6, 0], [4, 1, 3, 7, 5, 2], + [6, 0, 7, 5, 2, 4], [0, 7, 5, 1, 4, 2], [5, 6, 0, 7, 1, 4], + [6, 2, 0, 7, 4, 1], [0, 2, 1, 7, 4, 5], [0, 7, 5, 3, 1, 6], + [1, 4, 7, 3, 0, 6], [6, 7, 0, 2, 3, 1], [3, 0, 7, 6, 4, 2], + [6, 7, 0, 4, 1, 3], [1, 3, 6, 5, 0, 2], [6, 1, 0, 7, 3, 2], + [1, 5, 0, 4, 3, 7], + ] + assert static_indices.tolist() == expected_indices, ( + f"Expert indices mismatch:\n{static_indices.tolist()}\n!=\n{expected_indices}" + ) + From 04672e43422ca2347315a62b5be29d6d1c4e0a34 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 27 Feb 2026 16:33:21 -0800 Subject: [PATCH 80/92] minor changes --- .../inference/test_moe_inference.py | 60 ++++++++++++------- 1 file changed, 39 insertions(+), 21 deletions(-) diff --git a/tests/unit_tests/inference/test_moe_inference.py b/tests/unit_tests/inference/test_moe_inference.py index 
45190eb914a..a94b5046166 100644 --- a/tests/unit_tests/inference/test_moe_inference.py +++ b/tests/unit_tests/inference/test_moe_inference.py @@ -13,16 +13,15 @@ import pytest import torch +from megatron.core.activations import squared_relu from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import is_te_min_version, is_torch_min_version from megatron.training.initialize import _set_random_seed -from megatron.core.activations import squared_relu from tests.unit_tests.test_utilities import Utils # Reusable skip decorators requires_te = pytest.mark.skipif( - not is_te_min_version("1.7.0.dev0"), - reason="Requires transformer-engine >= 1.7.0", + not is_te_min_version("1.7.0.dev0"), reason="Requires transformer-engine >= 1.7.0" ) requires_torch_grouped_mm = pytest.mark.skipif( not is_torch_min_version("2.10") or not hasattr(torch, '_grouped_mm'), @@ -55,7 +54,7 @@ normalization="RMSNorm", add_bias_linear=False, bf16=True, - params_dtype=torch.bfloat16 + params_dtype=torch.bfloat16, ) @@ -70,6 +69,7 @@ def _make_base_config(**overrides): # ────────────────────────────────────────────────────────────────────── +@pytest.mark.internal class TestInferenceTopKRouter: @classmethod @@ -86,19 +86,31 @@ def _make_router(self, **config_overrides): from megatron.core.transformer.moe.router import InferenceTopKRouter config = _make_base_config(**config_overrides) - return InferenceTopKRouter(config=config, pg_collection=get_default_pg_collection()).cuda().to(torch.bfloat16) + return ( + InferenceTopKRouter(config=config, pg_collection=get_default_pg_collection()) + .cuda() + .to(torch.bfloat16) + ) def test_init_rejects_num_groups(self): """InferenceTopKRouter requires moe_router_num_groups=None.""" with pytest.raises(AssertionError, match="moe_router_num_groups"): self._make_router(moe_router_num_groups=2) + def test_config_rejects_non_fp32_router_dtype(self): + """inference_optimized config requires moe_router_dtype='fp32'.""" + 
with pytest.raises(ValueError, match="moe-router-dtype"): + _make_base_config( + transformer_impl="inference_optimized", add_qkv_bias=False, moe_router_dtype=None + ) + @pytest.mark.parametrize("score_fn", ["none", "invalid"]) def test_init_rejects_unsupported_score_function(self, score_fn): """InferenceTopKRouter requires sigmoid or softmax score function.""" with pytest.raises(AssertionError, match="moe_router_score_function"): - self._make_router(moe_router_score_function=score_fn, - moe_router_enable_expert_bias=False) + self._make_router( + moe_router_score_function=score_fn, moe_router_enable_expert_bias=False + ) @pytest.mark.parametrize("score_fn", ["sigmoid", "softmax"]) def test_init_accepts_valid_score_function(self, score_fn): @@ -106,8 +118,7 @@ def test_init_accepts_valid_score_function(self, score_fn): # Expert bias only valid with sigmoid; disable it for softmax enable_bias = score_fn == "sigmoid" router = self._make_router( - moe_router_score_function=score_fn, - moe_router_enable_expert_bias=enable_bias, + moe_router_score_function=score_fn, moe_router_enable_expert_bias=enable_bias ) assert router is not None @@ -183,9 +194,7 @@ def test_cuda_graph_capture_and_replay(self): hidden_size = NANOV3_BASE["hidden_size"] # Static input buffer for CUDA graph (seeded for reproducibility) - static_input = torch.randn( - num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16 - ) + static_input = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16) # Warmup (required before CUDA graph capture) with torch.no_grad(): @@ -209,14 +218,23 @@ def test_cuda_graph_capture_and_replay(self): assert (static_indices != -1).all(), "Graph replay should overwrite all expert indices" expected_indices = [ - [2, 6, 4, 5, 3, 7], [4, 1, 3, 2, 6, 0], [4, 1, 3, 7, 5, 2], - [6, 0, 7, 5, 2, 4], [0, 7, 5, 1, 4, 2], [5, 6, 0, 7, 1, 4], - [6, 2, 0, 7, 4, 1], [0, 2, 1, 7, 4, 5], [0, 7, 5, 3, 1, 6], - [1, 4, 7, 3, 0, 6], [6, 7, 0, 2, 3, 1], [3, 0, 7, 6, 4, 2], 
- [6, 7, 0, 4, 1, 3], [1, 3, 6, 5, 0, 2], [6, 1, 0, 7, 3, 2], + [2, 6, 4, 5, 3, 7], + [4, 1, 3, 2, 6, 0], + [4, 1, 3, 7, 5, 2], + [6, 0, 7, 5, 2, 4], + [0, 7, 5, 1, 4, 2], + [5, 6, 0, 7, 1, 4], + [6, 2, 0, 7, 4, 1], + [0, 2, 1, 7, 4, 5], + [0, 7, 5, 3, 1, 6], + [1, 4, 7, 3, 0, 6], + [6, 7, 0, 2, 3, 1], + [3, 0, 7, 6, 4, 2], + [6, 7, 0, 4, 1, 3], + [1, 3, 6, 5, 0, 2], + [6, 1, 0, 7, 3, 2], [1, 5, 0, 4, 3, 7], ] - assert static_indices.tolist() == expected_indices, ( - f"Expert indices mismatch:\n{static_indices.tolist()}\n!=\n{expected_indices}" - ) - + assert ( + static_indices.tolist() == expected_indices + ), f"Expert indices mismatch:\n{static_indices.tolist()}\n!=\n{expected_indices}" From 0f9f7a83354c34071c6f3f31771f019c2b0fcaaa Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 27 Feb 2026 17:15:58 -0800 Subject: [PATCH 81/92] add warmup to router unit test --- megatron/core/transformer/moe/router.py | 4 +- .../inference/test_moe_inference.py | 126 +++++++++++++++++- 2 files changed, 126 insertions(+), 4 deletions(-) diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index a5a24b149ec..d14e090994b 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -749,6 +749,7 @@ def __init__( super().__init__(config=config, pg_collection=pg_collection) self.is_inference_cuda_graphed_iteration = False + self.topk_routing_with_score_function = torch.compile(topk_routing_with_score_function) def set_inference_cuda_graphed_iteration(self): """Enable CUDA graph-compatible operations for the router.""" @@ -758,7 +759,6 @@ def unset_inference_cuda_graphed_iteration(self): """Disable CUDA graph-compatible operations for the router.""" self.is_inference_cuda_graphed_iteration = False - @torch.compile() def _forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): logits = self.gating(input).squeeze(1) # [num_tokens, 1, num_experts] @@ -766,7 +766,7 @@ def 
_forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = N # However, we pass dense_output=True to return dense [num_tokens, topk] tensors # instead of sparse [num_tokens, num_experts]. - probs, top_indices = topk_routing_with_score_function( + probs, top_indices = self.topk_routing_with_score_function( logits, self.topk, use_pre_softmax=self.config.moe_router_pre_softmax, diff --git a/tests/unit_tests/inference/test_moe_inference.py b/tests/unit_tests/inference/test_moe_inference.py index a94b5046166..9fde54d8c0f 100644 --- a/tests/unit_tests/inference/test_moe_inference.py +++ b/tests/unit_tests/inference/test_moe_inference.py @@ -55,6 +55,7 @@ add_bias_linear=False, bf16=True, params_dtype=torch.bfloat16, + transformer_impl="inference_optimized" ) @@ -182,7 +183,9 @@ def test_inference_vs_training_selects_same_experts(self): assert training_experts == inference_experts def test_cuda_graph_capture_and_replay(self): - """Router forward can be captured in a CUDA graph and replayed.""" + """Router forward can be captured in a CUDA graph and replayed. + Also checks for determinism by fixing the random seed and comparing against expected expert indices. 
+ """ torch.manual_seed(42) torch.cuda.manual_seed(42) @@ -197,11 +200,13 @@ def test_cuda_graph_capture_and_replay(self): static_input = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16) # Warmup (required before CUDA graph capture) + # 3 warmup iterations on a side stream with torch.no_grad(): s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): - router(static_input) + for _ in range(3): + router(static_input) torch.cuda.current_stream().wait_stream(s) # Capture @@ -238,3 +243,120 @@ def test_cuda_graph_capture_and_replay(self): assert ( static_indices.tolist() == expected_indices ), f"Expert indices mismatch:\n{static_indices.tolist()}\n!=\n{expected_indices}" + + +# ────────────────────────────────────────────────────────────────────── +# InferenceCUDAGraphTokenDispatcher +# ────────────────────────────────────────────────────────────────────── + + +@pytest.mark.internal +class TestInferenceCUDAGraphTokenDispatcher: + + @classmethod + def setup_class(cls): + from megatron.core.parallel_state import _set_global_symmetric_memory_buffer + + Utils.initialize_model_parallel( + 1, 1, expert_model_parallel_size=Utils.world_size + ) + _set_random_seed(seed_=123, data_parallel_random_init=False) + _set_global_symmetric_memory_buffer() + + @classmethod + def teardown_class(cls): + from megatron.core.parallel_state import destroy_global_symmetric_memory_buffer + + destroy_global_symmetric_memory_buffer() + Utils.destroy_model_parallel() + + def _make_dispatcher(self, **config_overrides): + from megatron.core.transformer.moe.moe_utils import get_default_pg_collection + from megatron.core.transformer.moe.token_dispatcher_inference import ( + InferenceCUDAGraphTokenDispatcher, + ) + + config_overrides.setdefault( + "expert_model_parallel_size", Utils.world_size + ) + config = _make_base_config(**config_overrides) + num_local_experts = config.num_moe_experts // Utils.world_size + ep_rank = 
torch.distributed.get_rank() if Utils.world_size > 1 else 0 + local_expert_indices = [ + ep_rank * num_local_experts + i for i in range(num_local_experts) + ] + + return InferenceCUDAGraphTokenDispatcher( + num_local_experts=num_local_experts, + local_expert_indices=local_expert_indices, + config=config, + pg_collection=get_default_pg_collection(), + ) + + def test_init(self): + """Dispatcher can be constructed with nanov3-like config and EP=world_size.""" + dispatcher = self._make_dispatcher() + assert dispatcher.topk == NANOV3_BASE["moe_router_topk"] + assert dispatcher.ep_size == Utils.world_size + + def test_symmetric_memory_buffer_initialized(self): + """EP symmetric memory buffer is accessible after _set_global_symmetric_memory_buffer.""" + from megatron.core.parallel_state import get_global_symmetric_memory_buffer_ep + + buf = get_global_symmetric_memory_buffer_ep() + assert buf is not None + + @pytest.mark.parametrize("num_local_tokens", [2, 16, 128]) + def test_cuda_graph_dispatch_combine(self, num_local_tokens): + """Dispatch+combine can be captured in a CUDA graph and replayed. + Verifies shapes after AllGather expansion and ReduceScatter contraction, + and the round-trip property: combine(dispatch(x)) == x * ep_size. + All tensor byte sizes are 128-bit aligned for NVLS eligibility. 
+ """ + dispatcher = self._make_dispatcher() + ep_size = dispatcher.ep_size + hidden_size = NANOV3_BASE["hidden_size"] + topk = NANOV3_BASE["moe_router_topk"] + num_experts = NANOV3_BASE["num_moe_experts"] + + # Static buffers for CUDA graph + static_hidden = torch.randn( + num_local_tokens, hidden_size, device="cuda", dtype=torch.bfloat16 + ) + static_probs = torch.rand( + num_local_tokens, topk, device="cuda", dtype=torch.float32 + ) + static_routing_map = torch.randint( + 0, num_experts, (num_local_tokens, topk), device="cuda" + ) + + # 3 warmup iterations on a side stream + with torch.no_grad(): + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + dispatcher.routing_map = static_routing_map + d_hidden, d_probs = dispatcher.token_dispatch(static_hidden, static_probs) + dispatcher.token_combine(d_hidden.clone()) + torch.cuda.current_stream().wait_stream(s) + + # Capture + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + dispatcher.routing_map = static_routing_map + graph_hidden, graph_probs = dispatcher.token_dispatch(static_hidden, static_probs) + graph_combined = dispatcher.token_combine(graph_hidden.clone()) + + # Verify shapes: dispatch expands by ep_size, combine shrinks back + assert graph_hidden.shape == (num_local_tokens * ep_size, hidden_size) + assert graph_probs.shape == (num_local_tokens * ep_size, topk) + assert graph_combined.shape == (num_local_tokens, hidden_size) + + # Replay with new data and verify round-trip + static_hidden.copy_(torch.randn_like(static_hidden)) + graph.replay() + + expected = (static_hidden * ep_size).to(torch.bfloat16) + torch.testing.assert_close(graph_combined, expected, atol=1e-3, rtol=1e-3) + From 84a1134122c917e263a4303ddbc97447ae64c6d7 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Fri, 27 Feb 2026 17:59:07 -0800 Subject: [PATCH 82/92] format --- megatron/core/transformer/moe/router.py | 2 +- .../inference/test_moe_inference.py 
| 208 ++++++++++-------- 2 files changed, 115 insertions(+), 95 deletions(-) diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index d14e090994b..af9d23308b9 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -749,7 +749,7 @@ def __init__( super().__init__(config=config, pg_collection=pg_collection) self.is_inference_cuda_graphed_iteration = False - self.topk_routing_with_score_function = torch.compile(topk_routing_with_score_function) + self.topk_routing_with_score_function = torch.compile(topk_routing_with_score_function) def set_inference_cuda_graphed_iteration(self): """Enable CUDA graph-compatible operations for the router.""" diff --git a/tests/unit_tests/inference/test_moe_inference.py b/tests/unit_tests/inference/test_moe_inference.py index 9fde54d8c0f..4d515db2d30 100644 --- a/tests/unit_tests/inference/test_moe_inference.py +++ b/tests/unit_tests/inference/test_moe_inference.py @@ -14,6 +14,7 @@ import torch from megatron.core.activations import squared_relu +from megatron.core.inference.communication.torch_symm_triton import are_tensors_nvls_eligible from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import is_te_min_version, is_torch_min_version from megatron.training.initialize import _set_random_seed @@ -55,7 +56,7 @@ add_bias_linear=False, bf16=True, params_dtype=torch.bfloat16, - transformer_impl="inference_optimized" + transformer_impl="inference_optimized", ) @@ -182,68 +183,6 @@ def test_inference_vs_training_selects_same_experts(self): # Same expert selections assert training_experts == inference_experts - def test_cuda_graph_capture_and_replay(self): - """Router forward can be captured in a CUDA graph and replayed. - Also checks for determinism by fixing the random seed and comparing against expected expert indices. 
- """ - torch.manual_seed(42) - torch.cuda.manual_seed(42) - - router = self._make_router() - router.eval() - router.set_inference_cuda_graphed_iteration() - - num_tokens = 16 - hidden_size = NANOV3_BASE["hidden_size"] - - # Static input buffer for CUDA graph (seeded for reproducibility) - static_input = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16) - - # Warmup (required before CUDA graph capture) - # 3 warmup iterations on a side stream - with torch.no_grad(): - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - for _ in range(3): - router(static_input) - torch.cuda.current_stream().wait_stream(s) - - # Capture - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph): - static_probs, static_indices = router(static_input) - - router.unset_inference_cuda_graphed_iteration() - - # Fill indices with -1, replay, check that the graph overwrote them - static_indices.fill_(-1) - static_input.copy_(torch.randn_like(static_input)) - graph.replay() - assert (static_indices != -1).all(), "Graph replay should overwrite all expert indices" - - expected_indices = [ - [2, 6, 4, 5, 3, 7], - [4, 1, 3, 2, 6, 0], - [4, 1, 3, 7, 5, 2], - [6, 0, 7, 5, 2, 4], - [0, 7, 5, 1, 4, 2], - [5, 6, 0, 7, 1, 4], - [6, 2, 0, 7, 4, 1], - [0, 2, 1, 7, 4, 5], - [0, 7, 5, 3, 1, 6], - [1, 4, 7, 3, 0, 6], - [6, 7, 0, 2, 3, 1], - [3, 0, 7, 6, 4, 2], - [6, 7, 0, 4, 1, 3], - [1, 3, 6, 5, 0, 2], - [6, 1, 0, 7, 3, 2], - [1, 5, 0, 4, 3, 7], - ] - assert ( - static_indices.tolist() == expected_indices - ), f"Expert indices mismatch:\n{static_indices.tolist()}\n!=\n{expected_indices}" - # ────────────────────────────────────────────────────────────────────── # InferenceCUDAGraphTokenDispatcher @@ -257,9 +196,7 @@ class TestInferenceCUDAGraphTokenDispatcher: def setup_class(cls): from megatron.core.parallel_state import _set_global_symmetric_memory_buffer - Utils.initialize_model_parallel( - 1, 1, 
expert_model_parallel_size=Utils.world_size - ) + Utils.initialize_model_parallel(1, 1, expert_model_parallel_size=Utils.world_size) _set_random_seed(seed_=123, data_parallel_random_init=False) _set_global_symmetric_memory_buffer() @@ -276,15 +213,11 @@ def _make_dispatcher(self, **config_overrides): InferenceCUDAGraphTokenDispatcher, ) - config_overrides.setdefault( - "expert_model_parallel_size", Utils.world_size - ) + config_overrides.setdefault("expert_model_parallel_size", Utils.world_size) config = _make_base_config(**config_overrides) num_local_experts = config.num_moe_experts // Utils.world_size ep_rank = torch.distributed.get_rank() if Utils.world_size > 1 else 0 - local_expert_indices = [ - ep_rank * num_local_experts + i for i in range(num_local_experts) - ] + local_expert_indices = [ep_rank * num_local_experts + i for i in range(num_local_experts)] return InferenceCUDAGraphTokenDispatcher( num_local_experts=num_local_experts, @@ -306,29 +239,103 @@ def test_symmetric_memory_buffer_initialized(self): buf = get_global_symmetric_memory_buffer_ep() assert buf is not None - @pytest.mark.parametrize("num_local_tokens", [2, 16, 128]) - def test_cuda_graph_dispatch_combine(self, num_local_tokens): + @pytest.mark.parametrize("seed", [42, 123, 7]) + @pytest.mark.parametrize( + "num_local_tokens", + [ + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 40, + 48, + 56, + 64, + 72, + 80, + 88, + 96, + 104, + 112, + 120, + 128, + 136, + 144, + 152, + 160, + 168, + 176, + 184, + 192, + 200, + 208, + 216, + 224, + 232, + 240, + 248, + 256, + 272, + 288, + 304, + 320, + 336, + 352, + 368, + 384, + 400, + 416, + 432, + 448, + 464, + 480, + 496, + 512, + ], + ) + def test_cuda_graph_dispatch_combine(self, num_local_tokens, seed): """Dispatch+combine can be captured in a CUDA graph and replayed. - Verifies shapes after AllGather expansion and ReduceScatter contraction, - and the round-trip property: combine(dispatch(x)) == x * ep_size. 
+ Creates global buffers, shards per rank, and verifies: + - NVLS AllGather output matches the full global buffer + - NVLS ReduceScatter output matches fp32-accumulated reference All tensor byte sizes are 128-bit aligned for NVLS eligibility. """ + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + dispatcher = self._make_dispatcher() ep_size = dispatcher.ep_size hidden_size = NANOV3_BASE["hidden_size"] topk = NANOV3_BASE["moe_router_topk"] num_experts = NANOV3_BASE["num_moe_experts"] + rank = torch.distributed.get_rank() if ep_size > 1 else 0 + num_global_tokens = num_local_tokens * ep_size - # Static buffers for CUDA graph - static_hidden = torch.randn( - num_local_tokens, hidden_size, device="cuda", dtype=torch.bfloat16 - ) - static_probs = torch.rand( - num_local_tokens, topk, device="cuda", dtype=torch.float32 - ) - static_routing_map = torch.randint( - 0, num_experts, (num_local_tokens, topk), device="cuda" + # Create global buffers on rank 0 and broadcast to all ranks + global_hidden = torch.randn( + num_global_tokens, hidden_size, device="cuda", dtype=torch.bfloat16 ) + global_probs = torch.randn(num_global_tokens, topk, device="cuda", dtype=torch.float32) + global_routing_map = torch.randint(0, num_experts, (num_global_tokens, topk), device="cuda") + torch.distributed.broadcast(global_hidden, src=0) + torch.distributed.broadcast(global_probs, src=0) + torch.distributed.broadcast(global_routing_map, src=0) + + # Each rank grabs their own shard + start = rank * num_local_tokens + end = start + num_local_tokens + static_hidden = global_hidden[start:end].contiguous() + static_probs = global_probs[start:end].contiguous() + static_routing_map = global_routing_map[start:end].contiguous() + + if not are_tensors_nvls_eligible(static_hidden, static_probs, static_routing_map): + pytest.skip( + "Tensors are not NVLS-eligible (need SM>=9 and each tensor's memory to be a multiple of 16 bytes)" + ) # 3 warmup iterations on a side stream with torch.no_grad():
@@ -338,6 +345,9 @@ def test_cuda_graph_dispatch_combine(self, num_local_tokens): for _ in range(3): dispatcher.routing_map = static_routing_map d_hidden, d_probs = dispatcher.token_dispatch(static_hidden, static_probs) + d_hidden = d_hidden.clone() + d_probs = d_probs.clone() + dispatcher.routing_map = dispatcher.routing_map.clone() dispatcher.token_combine(d_hidden.clone()) torch.cuda.current_stream().wait_stream(s) @@ -345,18 +355,28 @@ def test_cuda_graph_dispatch_combine(self, num_local_tokens): graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): dispatcher.routing_map = static_routing_map - graph_hidden, graph_probs = dispatcher.token_dispatch(static_hidden, static_probs) - graph_combined = dispatcher.token_combine(graph_hidden.clone()) + d_hidden, d_probs = dispatcher.token_dispatch(static_hidden, static_probs) + graph_hidden = d_hidden.clone() + graph_probs = d_probs.clone() + graph_routing_map = dispatcher.routing_map.clone() + graph_combined = dispatcher.token_combine(d_hidden.clone()) # Verify shapes: dispatch expands by ep_size, combine shrinks back - assert graph_hidden.shape == (num_local_tokens * ep_size, hidden_size) - assert graph_probs.shape == (num_local_tokens * ep_size, topk) + assert graph_hidden.shape == (num_global_tokens, hidden_size) + assert graph_probs.shape == (num_global_tokens, topk) assert graph_combined.shape == (num_local_tokens, hidden_size) - # Replay with new data and verify round-trip - static_hidden.copy_(torch.randn_like(static_hidden)) + # Replay graph.replay() - expected = (static_hidden * ep_size).to(torch.bfloat16) - torch.testing.assert_close(graph_combined, expected, atol=1e-3, rtol=1e-3) - + # Verify AllGather: all gathered tensors should match global buffers + torch.testing.assert_close(graph_hidden, global_hidden, atol=0, rtol=0) + torch.testing.assert_close(graph_probs, global_probs, atol=0, rtol=0) + torch.testing.assert_close(graph_routing_map, global_routing_map, atol=0, rtol=0) + + # Verify 
ReduceScatter: all ranks have the same all-gathered data, so + # rank r gets ep_size * chunk_r. Compute reference in fp32 then downcast. + # Exact match (atol=0, rtol=0) is possible because the NVLS triton kernels + # accumulate in fp32 before writing bf16 output. + expected_combined = (global_hidden[start:end].float() * ep_size).bfloat16() + torch.testing.assert_close(graph_combined, expected_combined, atol=0, rtol=0) From f132d597c964a161448d7bc28f678be045daf7dd Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Sun, 1 Mar 2026 22:14:10 -0800 Subject: [PATCH 83/92] add error message to assert --- megatron/core/transformer/moe/experts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index ee3f6cddd2c..ef01de77bb3 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -762,7 +762,7 @@ def bias_act_func(self, intermediate_parallel, bias_parallel, permuted_probs): else: raise ValueError("Only support fusion of swiglu and quick_gelu in TEGroupedMLP.") elif self.activation_func == squared_relu and self.config.use_fused_weighted_squared_relu: - assert bias_parallel is None + assert bias_parallel is None, "Bias is not supported with fused weighted squared relu." 
intermediate_parallel = weighted_squared_relu_impl( intermediate_parallel, permuted_probs ) From 264cbb2b005c79a9e838e826ccdcd105d38b35ac Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Sun, 1 Mar 2026 22:22:30 -0800 Subject: [PATCH 84/92] address feedback --- .../communication/torch_symm_triton/collectives.py | 8 -------- megatron/core/models/mamba/mamba_layer_specs.py | 2 +- megatron/core/transformer/moe/experts.py | 9 +++++---- 3 files changed, 6 insertions(+), 13 deletions(-) diff --git a/megatron/core/inference/communication/torch_symm_triton/collectives.py b/megatron/core/inference/communication/torch_symm_triton/collectives.py index 1289fd54d60..48f475ed7d4 100644 --- a/megatron/core/inference/communication/torch_symm_triton/collectives.py +++ b/megatron/core/inference/communication/torch_symm_triton/collectives.py @@ -25,9 +25,6 @@ from .multimem_asm import ld_128, st_128 from .utils import are_tensors_nvls_eligible, get_flat_tid, sync_threads -# ── Triton kernels ───────────────────────────────────────────────────────── - - @triton.jit def _ag_phase( local_ptr, multicast_ptr, byte_offset, numel, BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE @@ -39,11 +36,6 @@ def _ag_phase( Each thread handles 128-bit (NUMEL_PER_THREAD elements) at a time. byte_offset locates the tensor within the multicast buffer. - NOTE: When numel is not divisible by (NUMEL_PER_THREAD * WORLD_SIZE), the kernel - rounds up via cdiv and may read/write up to 15 bytes past the logical tensor end. - This is safe because PyTorch's CUDA caching allocator guarantees a minimum block - size of 512 bytes (kMinBlockSize in CUDACachingAllocator.cpp), so small tensors - always have sufficient backing memory. 
""" pid = tl.program_id(axis=0) tid = get_flat_tid() diff --git a/megatron/core/models/mamba/mamba_layer_specs.py b/megatron/core/models/mamba/mamba_layer_specs.py index 044eefa730f..957f20847fc 100755 --- a/megatron/core/models/mamba/mamba_layer_specs.py +++ b/megatron/core/models/mamba/mamba_layer_specs.py @@ -179,7 +179,7 @@ ), ), moe_layer=ModuleSpec( - # Use inference-optimized MoE layer for better CUDA graph support + # Use inference-optimized MoE layer for end-to-end CUDA graph support module=TransformerLayer, submodules=TransformerLayerSubmodules( pre_mlp_layernorm=TENorm, mlp=moe_inference, mlp_bda=get_bias_dropout_add diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index ef01de77bb3..10027de7d7e 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -944,9 +944,11 @@ def __init__( pg_collection=pg_collection, ) - # Concatenate TE's per-expert weights into single tensors for torch._grouped_mm - # TE GroupedLinear stores weights as weight0, weight1, ..., weight{n-1} - # torch._grouped_mm expects shape [num_experts, out_features, in_features] + # TE's GroupedLinear stores per-expert weights as separate parameters + # (weight0, weight1, ..., weight{n-1}). We stack them into contiguous tensors + # of shape [num_experts, out_features, in_features] for torch._grouped_mm and + # FlashInfer's cutlass_fused_moe. Per-expert views are registered so that + # load_state_dict still writes into the contiguous buffers. 
self._build_concatenated_weights() self.is_inference_cuda_graphed_iteration = False @@ -1050,7 +1052,6 @@ def _torch_grouped_mm_forward( self, permuted_local_hidden_states, tokens_per_expert, permuted_probs ): permuted_probs = permuted_probs.unsqueeze(-1) - # assert tokens_per_expert.is_cuda, "tokens_per_expert must be on GPU" if not tokens_per_expert.is_cuda: tokens_per_expert = tokens_per_expert.to('cuda') From df9de352fd8c5215b2e7aeaddf1e78fb987ebcda Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Sun, 1 Mar 2026 22:31:34 -0800 Subject: [PATCH 85/92] use decorator for torch compile --- megatron/core/transformer/moe/router.py | 28 ++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index af9d23308b9..a8cbd9839a5 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -749,7 +749,6 @@ def __init__( super().__init__(config=config, pg_collection=pg_collection) self.is_inference_cuda_graphed_iteration = False - self.topk_routing_with_score_function = torch.compile(topk_routing_with_score_function) def set_inference_cuda_graphed_iteration(self): """Enable CUDA graph-compatible operations for the router.""" @@ -759,14 +758,29 @@ def unset_inference_cuda_graphed_iteration(self): """Disable CUDA graph-compatible operations for the router.""" self.is_inference_cuda_graphed_iteration = False - def _forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): - logits = self.gating(input).squeeze(1) # [num_tokens, 1, num_experts] + @staticmethod + @torch.compile + def _compiled_topk_routing(logits, topk, use_pre_softmax, num_groups, + group_topk, scaling_factor, score_function, + expert_bias, fused, router_replay, dense_output): + return topk_routing_with_score_function( + logits, + topk, + use_pre_softmax=use_pre_softmax, + num_groups=num_groups, + group_topk=group_topk, + 
scaling_factor=scaling_factor, + score_function=score_function, + expert_bias=expert_bias, + fused=fused, + router_replay=router_replay, + dense_output=dense_output, + ) - # Share the routing logic with the parent class to avoid code duplication. - # However, we pass dense_output=True to return dense [num_tokens, topk] tensors - # instead of sparse [num_tokens, num_experts]. + def _forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): + logits = self.gating(input).squeeze(1) # [num_tokens, num_experts] - probs, top_indices = self.topk_routing_with_score_function( + probs, top_indices = self._compiled_topk_routing( logits, self.topk, use_pre_softmax=self.config.moe_router_pre_softmax, From ca75b2b7dc5c0f70cac9da5ade22ee1719f5eb70 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Sun, 1 Mar 2026 22:46:49 -0800 Subject: [PATCH 86/92] bugfix --- megatron/core/inference/batch_dimensions_utils.py | 4 ++-- megatron/core/inference/contexts/dynamic_context.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py index 13b6e41f6c6..a7b325ca2ba 100644 --- a/megatron/core/inference/batch_dimensions_utils.py +++ b/megatron/core/inference/batch_dimensions_utils.py @@ -176,9 +176,9 @@ def adjust_batch_dims_for_expert_parallelism( is_any_ep_rank_in_non_decode = sync_tensor[1].item() == 1 # We force eager mode for scenarios where some ranks will run with CUDA graphs - # while others will not. Without this check, the all-to-all communication in the + # while others will not. Without this check, communication in the # expert routing layer would pad up to the maximum capacity only for the ranks that - # are using CUDA graphs in this step, leading to a NCCL hang. + # are using CUDA graphs in this step, leading to a hang. # This can happen in the following cases: # 1. 
If we only allow decode CUDA graphs but some ranks are running non-decode batches # 2. Some ranks are running explicit chunked prefill requests diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index e311e4b0600..fbe2eb76bf0 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -547,7 +547,7 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC ) self.smallest_non_decode_cuda_graph_size = ( - min(inference_config.cuda_graph_mixed_prefill_count, self.max_requests), + min(inference_config.cuda_graph_mixed_prefill_count, self.max_requests) ) self._using_cuda_graph_this_step = False From a61aea523196f7e34220da3eee4bf5898f7df2e6 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Sun, 1 Mar 2026 23:33:40 -0800 Subject: [PATCH 87/92] lint --- .../torch_symm_triton/collectives.py | 1 + .../inference/contexts/dynamic_context.py | 4 +-- megatron/core/transformer/moe/experts.py | 34 ++++++++++++++----- megatron/core/transformer/moe/moe_layer.py | 10 +++++- megatron/core/transformer/moe/router.py | 16 +++++++-- .../moe/token_dispatcher_inference.py | 8 +++-- 6 files changed, 57 insertions(+), 16 deletions(-) diff --git a/megatron/core/inference/communication/torch_symm_triton/collectives.py b/megatron/core/inference/communication/torch_symm_triton/collectives.py index 48f475ed7d4..cf2003c8595 100644 --- a/megatron/core/inference/communication/torch_symm_triton/collectives.py +++ b/megatron/core/inference/communication/torch_symm_triton/collectives.py @@ -25,6 +25,7 @@ from .multimem_asm import ld_128, st_128 from .utils import are_tensors_nvls_eligible, get_flat_tid, sync_threads + @triton.jit def _ag_phase( local_ptr, multicast_ptr, byte_offset, numel, BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE diff --git a/megatron/core/inference/contexts/dynamic_context.py 
b/megatron/core/inference/contexts/dynamic_context.py index fbe2eb76bf0..a0f39b13075 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -546,8 +546,8 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC ) ) - self.smallest_non_decode_cuda_graph_size = ( - min(inference_config.cuda_graph_mixed_prefill_count, self.max_requests) + self.smallest_non_decode_cuda_graph_size = min( + inference_config.cuda_graph_mixed_prefill_count, self.max_requests ) self._using_cuda_graph_this_step = False diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 10027de7d7e..312d113ab36 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -964,12 +964,14 @@ def __init__( # order, but TE stores them as [activation, gate]. Until FlashInfer supports # TE's weight ordering, the FlashInfer path is only available for non-gated # activations (e.g. squared_relu). - self._flashinfer_available = HAVE_FLASHINFER and not config.gated_linear_unit - if self._flashinfer_available: + if HAVE_FLASHINFER: self._flashinfer_activation_type = self._resolve_flashinfer_activation_type() def _resolve_flashinfer_activation_type(self): """Map megatron activation config to FlashInfer ActivationType.""" + assert ( + HAVE_FLASHINFER + ), "flashinfer-python is required to resolve FlashInfer activation type." 
func = self.config.activation_func if func == F.silu: return ActivationType.Silu @@ -1091,20 +1093,36 @@ def _torch_grouped_mm_forward( def forward( self, permuted_local_hidden_states: torch.Tensor, - tokens_per_expert: torch.Tensor, + tokens_per_expert: Optional[torch.Tensor], permuted_probs: torch.Tensor, + routing_map: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Forward pass with three modes: - - Training: delegates to parent TEGroupedMLP - - Inference + CUDA graphed: FlashInfer cutlass_fused_moe - - Inference + eager: torch._grouped_mm with GPU-resident offsets + + - Training: delegates to parent TEGroupedMLP. + - Inference + CUDA graphed: FlashInfer cutlass_fused_moe. tokens_per_expert + is not used in this path; the FlashInfer kernel operates directly on + routing_map. + - Inference + eager: torch._grouped_mm with GPU-resident cumsum offsets. + + Args: + permuted_local_hidden_states: [num_tokens, hidden_size] input hidden states. + tokens_per_expert: [num_experts] number of tokens routed to each expert. + None when using the CUDA-graphed FlashInfer path. + permuted_probs: [num_tokens, topk] routing probabilities. + routing_map: [num_tokens, topk] token-to-expert assignment indices. + Required for the FlashInfer CUDA-graphed path, None otherwise. """ if self.training: return super().forward(permuted_local_hidden_states, tokens_per_expert, permuted_probs) - elif self.is_inference_cuda_graphed_iteration and self._flashinfer_available: + elif self.is_inference_cuda_graphed_iteration: + assert routing_map is not None, "routing_map is required for FlashInfer forward pass." + assert ( + HAVE_FLASHINFER + ), "FlashInfer is not available; cannot use FlashInfer forward pass." 
return self._flashinfer_forward( - permuted_local_hidden_states, tokens_per_expert, permuted_probs + permuted_local_hidden_states, routing_map, permuted_probs ) elif self._torch_grouped_mm_available: diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 09c5e007bae..03308a1cfdf 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -414,7 +414,15 @@ def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tenso dispatched_input, tokens_per_expert, permuted_probs = ( self.token_dispatcher.dispatch_postprocess(hidden_states, probs) ) - expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert, permuted_probs) + if self.is_inference_cuda_graphed_iteration: + routing_map = self.token_dispatcher.routing_map + expert_output, mlp_bias = self.experts( + dispatched_input, tokens_per_expert, permuted_probs, routing_map=routing_map + ) + else: + expert_output, mlp_bias = self.experts( + dispatched_input, tokens_per_expert, permuted_probs + ) assert mlp_bias is None, f"mlp_bias is not supported for {type(self.token_dispatcher)}" output = self.token_dispatcher.combine_preprocess(expert_output) diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index a8cbd9839a5..f90cb700607 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -760,9 +760,19 @@ def unset_inference_cuda_graphed_iteration(self): @staticmethod @torch.compile - def _compiled_topk_routing(logits, topk, use_pre_softmax, num_groups, - group_topk, scaling_factor, score_function, - expert_bias, fused, router_replay, dense_output): + def _compiled_topk_routing( + logits, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + expert_bias, + fused, + router_replay, + dense_output, + ): return topk_routing_with_score_function( logits, topk, diff --git 
a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 31f03e30e65..55a751230bb 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -202,8 +202,12 @@ def token_dispatch(self, hidden_states, probs): return hidden_states, probs def dispatch_postprocess(self, hidden_states, probs): - """Pass-through: returns unpermuted inputs and routing_map for InferenceGroupedMLP.""" - return hidden_states, self.routing_map, probs + """Pass-through: returns unpermuted inputs directly. + + tokens_per_expert is not computed by this inference dispatcher. Instead, + the FlashInfer fused MoE kernel operates directly on the routing map. + """ + return hidden_states, None, probs def combine_preprocess(self, expert_output): """Pass-through: InferenceGroupedMLP already produces unpermuted output.""" From 4fd23ce0c7beaae277d5467cf75e8a6d412fda3f Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Sun, 1 Mar 2026 23:41:12 -0800 Subject: [PATCH 88/92] format and guard properly --- megatron/core/transformer/moe/moe_layer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 03308a1cfdf..8277486b03b 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -414,7 +414,10 @@ def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tenso dispatched_input, tokens_per_expert, permuted_probs = ( self.token_dispatcher.dispatch_postprocess(hidden_states, probs) ) - if self.is_inference_cuda_graphed_iteration: + if ( + hasattr(self, "_inference_token_dispatcher") + and self.is_inference_cuda_graphed_iteration + ): routing_map = self.token_dispatcher.routing_map expert_output, mlp_bias = self.experts( dispatched_input, tokens_per_expert, permuted_probs, 
routing_map=routing_map From b1530a6abc951e9aad67f143112c4f7a55b81fe4 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 2 Mar 2026 12:16:03 -0800 Subject: [PATCH 89/92] fix comments --- megatron/core/transformer/moe/experts.py | 7 +- megatron/core/transformer/moe/moe_utils.py | 14 +- megatron/core/transformer/moe/router.py | 2 +- .../moe/token_dispatcher_inference.py | 123 ++++++++++++++---- 4 files changed, 107 insertions(+), 39 deletions(-) diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 312d113ab36..0fc954db4af 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -960,10 +960,6 @@ def __init__( and not config.inference_disable_torch_grouped_mm ) - # FlashInfer's cutlass_fused_moe expects gated weights in [gate, activation] - # order, but TE stores them as [activation, gate]. Until FlashInfer supports - # TE's weight ordering, the FlashInfer path is only available for non-gated - # activations (e.g. squared_relu). if HAVE_FLASHINFER: self._flashinfer_activation_type = self._resolve_flashinfer_activation_type() @@ -1000,8 +996,7 @@ def _build_concatenated_weights(self): This allows: - load_state_dict to load into weight{i} views -> writes into big tensor - - forward() to use big tensor directly with torch._grouped_mm - - No post-load hooks needed + - forward() to use big tensor directly with torch._grouped_mm or FlashInfer """ # Get device/dtype from existing TE weights device = self.linear_fc1.weight0.device diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index 849e6bec4c7..51c8b51134f 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -691,11 +691,17 @@ def topk_routing_with_score_function( Returns: Tuple[torch.Tensor, torch.Tensor]: When dense_output=False (default): - - routing_probs (torch.Tensor): Shape [num_tokens, num_experts]. 
- - routing_map (torch.Tensor): Shape [num_tokens, num_experts]. + - routing_probs (torch.Tensor): Shape [num_tokens, num_experts]. Sparse tensor + containing the normalized routing probability for each token-expert pair. Non-zero + entries correspond to the top-k selected experts per token. + - routing_map (torch.Tensor): Shape [num_tokens, num_experts]. Boolean mask where + True indicates the token is routed to that expert (i.e. the expert was in the + token's top-k selection). When dense_output=True: - - probs (torch.Tensor): Shape [num_tokens, topk]. - - top_indices (torch.Tensor): Shape [num_tokens, topk]. + - probs (torch.Tensor): Shape [num_tokens, topk]. The normalized routing + probabilities for each token's top-k selected experts. + - top_indices (torch.Tensor): Shape [num_tokens, topk]. The expert indices + selected for each token. """ assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}." num_tokens, num_experts = logits.shape diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index f90cb700607..45cfb59cd40 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -718,7 +718,7 @@ class InferenceTopKRouter(TopKRouter): A stripped-down version of TopKRouter that skips z-loss, auxiliary load balancing losses, token dropping, and expert bias updates. The _forward() method is @torch.compile()'d and returns dense [num_tokens, topk] tensors - instead of sparse [num_tokens, num_experts] for CUDA graph compatibility. + instead of sparse [num_tokens, num_experts] for compatibility with FlashInfer. Falls back to the parent TopKRouter.forward() for training or non-CUDA-graphed inference iterations. 
diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 55a751230bb..6b851c252c5 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -8,7 +8,7 @@ GPU-resident to avoid host synchronizations that would break CUDA graph capture. Supports latency-optimized NVLS collectives (multimem all-gather/reduce-scatter) -on Hopper+ GPUs with BF16, with automatic fallback to NCCL via superclass methods. +on Hopper+ GPUs with BF16, with automatic fallback to NCCL. """ from typing import List, Optional @@ -52,7 +52,7 @@ def __init__( pg_collection: Optional[ProcessGroupCollection] = None, ) -> None: """ - Initialize the inference AllGather token dispatcher. + Initialize the InferenceCUDAGraphTokenDispatcher. Args: num_local_experts: Number of experts on this rank. @@ -73,15 +73,31 @@ def __init__( def _maybe_allocate_ag_buffers( self, routing_map: torch.Tensor, probs: torch.Tensor, hidden_states: torch.Tensor ) -> dict: - """ - Allocate a single symmetric memory buffer for all-gather outputs of - routing_map, probs and hidden_states. Returns sliced views for each. - - Returns dict with: - - "handle": symmetric memory handle (or None if unavailable) - - "routing_map" / "routing_map_offset": raw byte view and byte offset - - "probs" / "probs_offset": raw byte view and byte offset - - "hidden_states" / "hidden_states_offset": raw byte view and byte offset + """Allocate a single symmetric memory output buffer for fused all-gather. + + Creates one contiguous symmetric memory buffer sized for the gathered + (global) routing_map, probs, and hidden_states, then returns sliced views + into it. This allows a single fused NVLS all-gather kernel to write all + three outputs in one launch. + + Args: + routing_map (torch.Tensor): Local routing map, shape [local_tokens, topk]. 
+ Boolean or integer tensor mapping each token to its selected experts. + probs (torch.Tensor): Local routing probabilities, shape [local_tokens, topk]. + Normalized weights for each token's selected experts. + hidden_states (torch.Tensor): Local hidden states, shape [local_tokens, hidden_dim]. + + Returns: + dict: A dictionary with the following keys: + - "handle": Symmetric memory handle for NVLS ops, or None if + symmetric memory is unavailable. + - "routing_map": Raw byte view for the gathered routing map output. + - "routing_map_offset": Byte offset of routing_map within the buffer. + - "probs": Raw byte view for the gathered probs output. + - "probs_offset": Byte offset of probs within the buffer. + - "hidden_states": Raw byte view for the gathered hidden states output. + - "hidden_states_offset": Byte offset of hidden_states within the buffer. + When allocation fails, all tensor views are None and offsets are 0. """ _NONE = { "handle": None, @@ -121,9 +137,18 @@ def _maybe_allocate_ag_buffers( } def _maybe_allocate_rs_buffer(self, x: torch.Tensor) -> dict: - """ - Allocate symmetric memory buffer for reduce-scatter input. - Input shape matches x (the unpermuted hidden states). + """Allocate a symmetric memory buffer for reduce-scatter input. + + The buffer has the same shape and dtype as x so that x can be copied + into it before the NVLS reduce-scatter kernel. + + Args: + x (torch.Tensor): The global hidden states to be reduce-scattered, + shape [global_tokens, hidden_dim]. + + Returns: + dict: A dictionary with keys "handle" (symmetric memory handle, or + None if unavailable) and "tensor" (the allocated buffer, or None). """ symm_mem_buffer = get_global_symmetric_memory_buffer_ep().maybe_get_tensor( list(x.size()), dtype=x.dtype @@ -131,11 +156,26 @@ def _maybe_allocate_rs_buffer(self, x: torch.Tensor) -> dict: return symm_mem_buffer def token_dispatch(self, hidden_states, probs): - """ - Gathers tokens from all EP ranks using AllGather. 
+ """Gathers tokens from all EP ranks using AllGather. + + Performs all-gather on routing_map (stored in self.routing_map), probs, + and hidden_states so that every rank holds the full global view. + Uses latency-optimized fused NVLS multimem_all_gather on Hopper+ GPUs + with BF16 when symmetric memory is available. Falls back to NCCL otherwise. - Uses latency-optimized NVLS multimem_all_gather for routing_map, probs and hidden_states - on Hopper+ GPUs with BF16. Falls back to NCCL otherwise. + Args: + hidden_states (torch.Tensor): Local hidden states, + shape [local_tokens, hidden_dim]. + probs (torch.Tensor): Local routing probabilities, + shape [local_tokens, topk]. Normalized weights for each token's + selected experts. + + Returns: + tuple: (hidden_states, probs) gathered across all EP ranks. + - hidden_states (torch.Tensor): Shape [global_tokens, hidden_dim]. + - probs (torch.Tensor): Shape [global_tokens, topk]. + Also updates self.routing_map in-place to the gathered + shape [global_tokens, topk]. """ if self.ep_size == 1: return hidden_states, probs @@ -202,29 +242,56 @@ def token_dispatch(self, hidden_states, probs): return hidden_states, probs def dispatch_postprocess(self, hidden_states, probs): - """Pass-through: returns unpermuted inputs directly. + """Pass-through: returns inputs directly without permutation. + + Unlike the training dispatcher, this does not permute tokens or compute + tokens_per_expert. The downstream InferenceGroupedMLP (FlashInfer / + CUTLASS fused MoE kernel) operates directly on the routing map stored + in self.routing_map. - tokens_per_expert is not computed by this inference dispatcher. Instead, - the FlashInfer fused MoE kernel operates directly on the routing map. + Args: + hidden_states (torch.Tensor): Gathered hidden states, + shape [global_tokens, hidden_dim]. + probs (torch.Tensor): Gathered routing probabilities, + shape [global_tokens, topk]. 
+ + Returns: + tuple: (hidden_states, tokens_per_expert, probs) where + tokens_per_expert is always None. """ return hidden_states, None, probs def combine_preprocess(self, expert_output): - """Pass-through: InferenceGroupedMLP already produces unpermuted output.""" + """Pass-through: InferenceGroupedMLP already produces unpermuted output. + + No unpermutation is needed because dispatch_postprocess did not permute + the tokens in the first place. + + Args: + expert_output (torch.Tensor): Output from InferenceGroupedMLP, + shape [global_tokens, hidden_dim]. + + Returns: + torch.Tensor: The input tensor unchanged. + """ return expert_output def token_combine(self, hidden_states): - """ - Combines expert outputs using Reduce-Scatter. + """Combines expert outputs across EP ranks using Reduce-Scatter. - Uses latency-optimized NVLS multimem_reduce_scatter on Hopper+ GPUs with BF16 - when symmetric memory is available. Falls back to NCCL via superclass otherwise. + Reduces the global expert output (summing contributions from each rank) + and scatters the result so each rank receives its local token slice. + Uses latency-optimized NVLS multimem_reduce_scatter on Hopper+ GPUs + with BF16 when symmetric memory is available. Falls back to NCCL otherwise. Args: - hidden_states: [global_tokens, hidden_dim] tensor to reduce-scatter + hidden_states (torch.Tensor): Combined expert output after routing + weights have been applied, shape [global_tokens, hidden_dim]. Returns: - [local_tokens, hidden_dim] tensor after reduce-scatter + torch.Tensor: Local slice of the reduced output, + shape [local_tokens, hidden_dim] where + local_tokens = global_tokens // ep_size. 
""" if self.ep_size == 1: return hidden_states From fa6163373857f4648f584563b627095b846b0cc6 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 2 Mar 2026 12:59:46 -0800 Subject: [PATCH 90/92] working histogram kernel --- .../transformer/moe/moe_inference_utils.py | 166 +++++++++++++++ .../moe/test_moe_inference_utils.py | 198 ++++++++++++++++++ 2 files changed, 364 insertions(+) create mode 100644 megatron/core/transformer/moe/moe_inference_utils.py create mode 100644 tests/unit_tests/transformer/moe/test_moe_inference_utils.py diff --git a/megatron/core/transformer/moe/moe_inference_utils.py b/megatron/core/transformer/moe/moe_inference_utils.py new file mode 100644 index 00000000000..0ee0bbd4355 --- /dev/null +++ b/megatron/core/transformer/moe/moe_inference_utils.py @@ -0,0 +1,166 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +""" +Triton kernels for CUDA-graph-compatible MoE token permutation and unpermutation. + +These kernels enable the torch grouped GEMM path to work under CUDA graphs +by keeping all metadata (tokens_per_expert, permutation indices) GPU-resident. 
+""" + +from unittest.mock import MagicMock + +import torch +from packaging import version + +from megatron.core.utils import null_decorator + +try: + import triton + import triton.language as tl + + if version.parse(triton.__version__) < version.parse("3.4.0") and not torch.cuda.is_available(): + HAVE_TRITON = False + else: + HAVE_TRITON = tl.constexpr(version.parse(triton.__version__) >= version.parse("2.0.0")) +except ImportError: + HAVE_TRITON = False + +if not HAVE_TRITON: + triton = MagicMock() + triton.jit = null_decorator + tl = MagicMock() + + +# --------------------------------------------------------------------------- # +# Kernel: Count tokens per local expert +# --------------------------------------------------------------------------- # +@triton.jit +def _count_local_tokens_kernel( + routing_map_ptr, # [num_tokens, topk] - global expert IDs + tokens_per_expert_ptr, # [num_local_experts] output (must be zero-initialized) + total_pairs, + local_expert_start, + num_local_experts: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Count tokens assigned to each local expert, filtering out non-local experts. + + Each program handles BLOCK_SIZE (token, k) pairs from the routing_map. + Pairs whose assigned expert is not on this rank are ignored. For local + experts, atomically increments the corresponding tokens_per_expert counter. 
+ """ + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < total_pairs + + expert_ids = tl.load(routing_map_ptr + offsets, mask=mask, other=-1) + local_ids = expert_ids - local_expert_start + is_local = (local_ids >= 0) & (local_ids < num_local_experts) & mask + + # Scatter atomic add: each element adds 1 to its expert's counter + tl.atomic_add(tokens_per_expert_ptr + local_ids, 1, mask=is_local) + + +# --------------------------------------------------------------------------- # +# Python wrapper +# --------------------------------------------------------------------------- # +def compute_local_tokens_per_expert( + routing_map: torch.Tensor, + local_expert_start: int, + num_local_experts: int, +) -> torch.Tensor: + """Count tokens routed to each local expert, filtering out non-local assignments. + + Scans the routing_map for (token, k) pairs whose assigned expert lives on + this rank (global ID in [local_expert_start, local_expert_start + num_local_experts)). + Pairs routed to experts on other ranks are ignored. + + Args: + routing_map (torch.Tensor): Expert assignments, shape [num_tokens, topk]. + Contains global expert IDs. + local_expert_start (int): First global expert index on this rank. + num_local_experts (int): Number of experts on this rank. + + Returns: + torch.Tensor: tokens_per_expert, shape [num_local_experts], dtype int32. + Count of (token, k) pairs assigned to each local expert. 
+ """ + total_pairs = routing_map.numel() + + tokens_per_expert = torch.zeros( + num_local_experts, dtype=torch.int32, device=routing_map.device + ) + + HIST_BLOCK = 256 + hist_grid = ((total_pairs + HIST_BLOCK - 1) // HIST_BLOCK,) + _count_local_tokens_kernel[hist_grid]( + routing_map, + tokens_per_expert, + total_pairs, + local_expert_start, + num_local_experts, + BLOCK_SIZE=HIST_BLOCK, + ) + + return tokens_per_expert + + +if __name__ == "__main__": + torch.manual_seed(42) + + # --- Config --- + num_tokens = 128 + topk = 8 + num_total_experts = 64 + num_local_experts = 8 + local_expert_start = 16 # this rank owns experts 16..23 + + # --- Build a random routing_map with global expert IDs --- + routing_map = torch.randint( + 0, num_total_experts, (num_tokens, topk), dtype=torch.int32, device="cuda" + ) + + # --- Reference: count with PyTorch --- + flat = routing_map.flatten() + local_mask = (flat >= local_expert_start) & (flat < local_expert_start + num_local_experts) + local_ids_ref = flat[local_mask] - local_expert_start + ref = torch.zeros(num_local_experts, dtype=torch.int32, device="cuda") + ref.scatter_add_(0, local_ids_ref.long(), torch.ones_like(local_ids_ref, dtype=torch.int32)) + + # --- Triton kernel --- + result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + + # --- Compare --- + print(f"Reference: {ref.tolist()}") + print(f"Triton: {result.tolist()}") + assert torch.equal(ref, result), f"MISMATCH!\n ref={ref}\n got={result}" + print("PASSED - histogram matches reference") + + # --- Edge cases --- + # All tokens routed to non-local experts + routing_map_none = torch.zeros( + num_tokens, topk, dtype=torch.int32, device="cuda" + ) # expert 0, not in [16..23] + result_none = compute_local_tokens_per_expert(routing_map_none, local_expert_start, num_local_experts) + assert torch.equal(result_none, torch.zeros(num_local_experts, dtype=torch.int32, device="cuda")) + print("PASSED - no local experts case") + + # All 
tokens routed to a single local expert + routing_map_single = torch.full( + (num_tokens, topk), local_expert_start + 3, dtype=torch.int32, device="cuda" + ) + result_single = compute_local_tokens_per_expert(routing_map_single, local_expert_start, num_local_experts) + expected_single = torch.zeros(num_local_experts, dtype=torch.int32, device="cuda") + expected_single[3] = num_tokens * topk + assert torch.equal(result_single, expected_single) + print("PASSED - single expert case") + + # Small: 1 token, topk=1 + routing_map_tiny = torch.tensor([[local_expert_start]], dtype=torch.int32, device="cuda") + result_tiny = compute_local_tokens_per_expert(routing_map_tiny, local_expert_start, num_local_experts) + expected_tiny = torch.zeros(num_local_experts, dtype=torch.int32, device="cuda") + expected_tiny[0] = 1 + assert torch.equal(result_tiny, expected_tiny) + print("PASSED - single token case") + + print("\nAll tests passed.") diff --git a/tests/unit_tests/transformer/moe/test_moe_inference_utils.py b/tests/unit_tests/transformer/moe/test_moe_inference_utils.py new file mode 100644 index 00000000000..a331e9f868c --- /dev/null +++ b/tests/unit_tests/transformer/moe/test_moe_inference_utils.py @@ -0,0 +1,198 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +""" +Unit tests for megatron.core.transformer.moe.moe_inference_utils. + +Tests the compute_local_tokens_per_expert triton kernel against a PyTorch +reference implementation across various routing configurations. 
+""" + +import pytest +import torch + +from megatron.core.transformer.moe.moe_inference_utils import compute_local_tokens_per_expert + + +def _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts): + """PyTorch reference: count (token, k) pairs routed to each local expert.""" + flat = routing_map.flatten() + local_mask = (flat >= local_expert_start) & (flat < local_expert_start + num_local_experts) + local_ids = flat[local_mask] - local_expert_start + ref = torch.zeros(num_local_experts, dtype=torch.int32, device=routing_map.device) + ref.scatter_add_(0, local_ids.long(), torch.ones_like(local_ids, dtype=torch.int32)) + return ref + + +class TestComputeLocalTokensPerExpert: + """Tests for the _count_local_tokens_kernel triton kernel.""" + + @pytest.mark.internal + def test_random_routing(self): + """Random routing_map should match PyTorch reference.""" + torch.manual_seed(42) + num_tokens, topk = 128, 8 + num_total_experts, num_local_experts = 64, 8 + local_expert_start = 16 + + routing_map = torch.randint( + 0, num_total_experts, (num_tokens, topk), dtype=torch.int32, device="cuda" + ) + + result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + assert torch.equal(result, ref) + + @pytest.mark.internal + def test_no_local_experts(self): + """All tokens routed to non-local experts should give all zeros.""" + num_tokens, topk = 64, 4 + num_local_experts = 8 + local_expert_start = 16 + + # Expert 0 is not in [16..23] + routing_map = torch.zeros(num_tokens, topk, dtype=torch.int32, device="cuda") + + result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + expected = torch.zeros(num_local_experts, dtype=torch.int32, device="cuda") + assert torch.equal(result, expected) + + @pytest.mark.internal + def test_all_to_single_expert(self): + """All tokens routed to one local 
expert.""" + num_tokens, topk = 64, 4 + num_local_experts = 8 + local_expert_start = 16 + target_local_idx = 3 + + routing_map = torch.full( + (num_tokens, topk), + local_expert_start + target_local_idx, + dtype=torch.int32, + device="cuda", + ) + + result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + expected = torch.zeros(num_local_experts, dtype=torch.int32, device="cuda") + expected[target_local_idx] = num_tokens * topk + assert torch.equal(result, expected) + + @pytest.mark.internal + def test_single_token(self): + """Minimal case: 1 token, topk=1.""" + num_local_experts = 8 + local_expert_start = 16 + + routing_map = torch.tensor( + [[local_expert_start]], dtype=torch.int32, device="cuda" + ) + + result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + expected = torch.zeros(num_local_experts, dtype=torch.int32, device="cuda") + expected[0] = 1 + assert torch.equal(result, expected) + + @pytest.mark.internal + def test_uniform_distribution(self): + """Each token routes to all local experts exactly once (topk == num_local_experts).""" + num_tokens = 32 + num_local_experts = 8 + local_expert_start = 0 + + # Each row is [0, 1, 2, ..., 7] — one hit per local expert per token + routing_map = ( + torch.arange(num_local_experts, device="cuda", dtype=torch.int32) + .unsqueeze(0) + .expand(num_tokens, -1) + .contiguous() + ) + + result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + expected = torch.full( + (num_local_experts,), num_tokens, dtype=torch.int32, device="cuda" + ) + assert torch.equal(result, expected) + + @pytest.mark.internal + def test_local_expert_start_at_zero(self): + """Local experts starting at global index 0.""" + torch.manual_seed(123) + num_tokens, topk = 256, 4 + num_total_experts = 32 + num_local_experts = 4 + local_expert_start = 0 + + routing_map = torch.randint( + 0, num_total_experts, (num_tokens, topk), 
dtype=torch.int32, device="cuda" + ) + + result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + assert torch.equal(result, ref) + + @pytest.mark.internal + def test_local_experts_at_end(self): + """Local experts at the tail end of the global expert range.""" + torch.manual_seed(456) + num_tokens, topk = 256, 4 + num_total_experts = 32 + num_local_experts = 4 + local_expert_start = 28 # experts 28..31 + + routing_map = torch.randint( + 0, num_total_experts, (num_tokens, topk), dtype=torch.int32, device="cuda" + ) + + result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + assert torch.equal(result, ref) + + @pytest.mark.internal + def test_large_batch(self): + """Larger batch to exercise multi-block histogram kernel.""" + torch.manual_seed(789) + num_tokens, topk = 2048, 8 + num_total_experts = 128 + num_local_experts = 16 + local_expert_start = 48 + + routing_map = torch.randint( + 0, num_total_experts, (num_tokens, topk), dtype=torch.int32, device="cuda" + ) + + result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + assert torch.equal(result, ref) + + @pytest.mark.internal + def test_topk_one(self): + """topk=1: each token assigned to exactly one expert.""" + torch.manual_seed(101) + num_tokens = 512 + num_total_experts = 64 + num_local_experts = 8 + local_expert_start = 8 + + routing_map = torch.randint( + 0, num_total_experts, (num_tokens, 1), dtype=torch.int32, device="cuda" + ) + + result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, 
num_local_experts) + assert torch.equal(result, ref) + + @pytest.mark.internal + def test_non_power_of_two_tokens(self): + """Non-power-of-2 num_tokens to check masking at block boundaries.""" + torch.manual_seed(202) + num_tokens, topk = 137, 5 + num_total_experts = 32 + num_local_experts = 4 + local_expert_start = 12 + + routing_map = torch.randint( + 0, num_total_experts, (num_tokens, topk), dtype=torch.int32, device="cuda" + ) + + result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + assert torch.equal(result, ref) From 86b438ff48f473b13ea464bcaa0372dadd83699f Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 2 Mar 2026 13:21:47 -0800 Subject: [PATCH 91/92] add exhaustive unit tests --- .../transformer/moe/moe_inference_utils.py | 119 +++++++------- .../moe/test_moe_inference_utils.py | 145 +++++++++++++++++- 2 files changed, 206 insertions(+), 58 deletions(-) diff --git a/megatron/core/transformer/moe/moe_inference_utils.py b/megatron/core/transformer/moe/moe_inference_utils.py index 0ee0bbd4355..1c26c04c19c 100644 --- a/megatron/core/transformer/moe/moe_inference_utils.py +++ b/megatron/core/transformer/moe/moe_inference_utils.py @@ -105,62 +105,67 @@ def compute_local_tokens_per_expert( return tokens_per_expert -if __name__ == "__main__": - torch.manual_seed(42) - - # --- Config --- - num_tokens = 128 - topk = 8 - num_total_experts = 64 - num_local_experts = 8 - local_expert_start = 16 # this rank owns experts 16..23 - - # --- Build a random routing_map with global expert IDs --- - routing_map = torch.randint( - 0, num_total_experts, (num_tokens, topk), dtype=torch.int32, device="cuda" - ) +# --------------------------------------------------------------------------- # +# Kernel: Exclusive prefix sum + atomic counters (single block) +# --------------------------------------------------------------------------- # 
+@triton.jit +def _prefix_sum_kernel( + tokens_per_expert_ptr, # [num_local_experts] input + expert_offsets_ptr, # [num_local_experts] output + atomic_counters_ptr, # [num_local_experts] output (copy of offsets) + num_local_experts, + BLOCK_SIZE: tl.constexpr, # next_power_of_2(num_local_experts) +): + """Compute exclusive prefix sum of tokens_per_expert. + + Runs as a single block. Reads tokens_per_expert, computes exclusive prefix + sum via tl.cumsum, and writes expert_offsets and a copy as atomic_counters + for use by the permute kernel. + """ + expert_range = tl.arange(0, BLOCK_SIZE) + mask = expert_range < num_local_experts + histogram = tl.load(tokens_per_expert_ptr + expert_range, mask=mask, other=0) - # --- Reference: count with PyTorch --- - flat = routing_map.flatten() - local_mask = (flat >= local_expert_start) & (flat < local_expert_start + num_local_experts) - local_ids_ref = flat[local_mask] - local_expert_start - ref = torch.zeros(num_local_experts, dtype=torch.int32, device="cuda") - ref.scatter_add_(0, local_ids_ref.long(), torch.ones_like(local_ids_ref, dtype=torch.int32)) - - # --- Triton kernel --- - result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) - - # --- Compare --- - print(f"Reference: {ref.tolist()}") - print(f"Triton: {result.tolist()}") - assert torch.equal(ref, result), f"MISMATCH!\n ref={ref}\n got={result}" - print("PASSED - histogram matches reference") - - # --- Edge cases --- - # All tokens routed to non-local experts - routing_map_none = torch.zeros( - num_tokens, topk, dtype=torch.int32, device="cuda" - ) # expert 0, not in [16..23] - result_none = compute_local_tokens_per_expert(routing_map_none, local_expert_start, num_local_experts) - assert torch.equal(result_none, torch.zeros(num_local_experts, dtype=torch.int32, device="cuda")) - print("PASSED - no local experts case") - - # All tokens routed to a single local expert - routing_map_single = torch.full( - (num_tokens, topk), 
local_expert_start + 3, dtype=torch.int32, device="cuda" + # Inclusive prefix sum, then shift to exclusive + inclusive = tl.cumsum(histogram, axis=0) + exclusive = inclusive - histogram + + tl.store(expert_offsets_ptr + expert_range, exclusive, mask=mask) + tl.store(atomic_counters_ptr + expert_range, exclusive, mask=mask) + + +# --------------------------------------------------------------------------- # +# Python wrapper +# --------------------------------------------------------------------------- # +def compute_expert_offsets( + tokens_per_expert: torch.Tensor, +) -> tuple: + """Compute exclusive prefix sum of tokens_per_expert and a mutable copy for atomics. + + Args: + tokens_per_expert (torch.Tensor): Token counts per local expert, + shape [num_local_experts], dtype int32. + + Returns: + tuple: (expert_offsets, atomic_counters) where: + - expert_offsets: [num_local_experts] exclusive prefix sum (read-only). + - atomic_counters: [num_local_experts] same values as expert_offsets, + to be mutated by the permute kernel's atomic adds. 
+ """ + num_local_experts = tokens_per_expert.shape[0] + + expert_offsets = torch.empty_like(tokens_per_expert) + atomic_counters = torch.empty_like(tokens_per_expert) + + BLOCK_SIZE = triton.next_power_of_2(num_local_experts) + _prefix_sum_kernel[(1,)]( + tokens_per_expert, + expert_offsets, + atomic_counters, + num_local_experts, + BLOCK_SIZE=BLOCK_SIZE, ) - result_single = compute_local_tokens_per_expert(routing_map_single, local_expert_start, num_local_experts) - expected_single = torch.zeros(num_local_experts, dtype=torch.int32, device="cuda") - expected_single[3] = num_tokens * topk - assert torch.equal(result_single, expected_single) - print("PASSED - single expert case") - - # Small: 1 token, topk=1 - routing_map_tiny = torch.tensor([[local_expert_start]], dtype=torch.int32, device="cuda") - result_tiny = compute_local_tokens_per_expert(routing_map_tiny, local_expert_start, num_local_experts) - expected_tiny = torch.zeros(num_local_experts, dtype=torch.int32, device="cuda") - expected_tiny[0] = 1 - assert torch.equal(result_tiny, expected_tiny) - print("PASSED - single token case") - - print("\nAll tests passed.") + + return expert_offsets, atomic_counters + + diff --git a/tests/unit_tests/transformer/moe/test_moe_inference_utils.py b/tests/unit_tests/transformer/moe/test_moe_inference_utils.py index a331e9f868c..710143a846a 100644 --- a/tests/unit_tests/transformer/moe/test_moe_inference_utils.py +++ b/tests/unit_tests/transformer/moe/test_moe_inference_utils.py @@ -10,7 +10,10 @@ import pytest import torch -from megatron.core.transformer.moe.moe_inference_utils import compute_local_tokens_per_expert +from megatron.core.transformer.moe.moe_inference_utils import ( + compute_expert_offsets, + compute_local_tokens_per_expert, +) def _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts): @@ -196,3 +199,143 @@ def test_non_power_of_two_tokens(self): result = compute_local_tokens_per_expert(routing_map, local_expert_start, 
num_local_experts) ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) assert torch.equal(result, ref) + + +class TestComputeExpertOffsets: + """Tests for the _prefix_sum_kernel triton kernel.""" + + @pytest.mark.internal + def test_basic_prefix_sum(self): + """Simple known input: [3, 1, 4, 1, 5] -> offsets [0, 3, 4, 8, 9].""" + tokens_per_expert = torch.tensor([3, 1, 4, 1, 5], dtype=torch.int32, device="cuda") + offsets, counters = compute_expert_offsets(tokens_per_expert) + + expected = torch.tensor([0, 3, 4, 8, 9], dtype=torch.int32, device="cuda") + assert torch.equal(offsets, expected) + assert torch.equal(counters, expected) + + @pytest.mark.internal + def test_single_expert(self): + """Single expert: offset is always 0.""" + tokens_per_expert = torch.tensor([42], dtype=torch.int32, device="cuda") + offsets, counters = compute_expert_offsets(tokens_per_expert) + + expected = torch.tensor([0], dtype=torch.int32, device="cuda") + assert torch.equal(offsets, expected) + assert torch.equal(counters, expected) + + @pytest.mark.internal + def test_all_zeros(self): + """All experts have zero tokens: offsets are all 0.""" + tokens_per_expert = torch.zeros(8, dtype=torch.int32, device="cuda") + offsets, counters = compute_expert_offsets(tokens_per_expert) + + expected = torch.zeros(8, dtype=torch.int32, device="cuda") + assert torch.equal(offsets, expected) + + @pytest.mark.internal + def test_one_hot(self): + """Only one expert has tokens.""" + tokens_per_expert = torch.tensor([0, 0, 10, 0, 0], dtype=torch.int32, device="cuda") + offsets, counters = compute_expert_offsets(tokens_per_expert) + + expected = torch.tensor([0, 0, 0, 10, 10], dtype=torch.int32, device="cuda") + assert torch.equal(offsets, expected) + + @pytest.mark.internal + def test_counters_are_independent_copy(self): + """Mutating counters should not affect offsets.""" + tokens_per_expert = torch.tensor([2, 3, 5], dtype=torch.int32, device="cuda") + offsets, 
counters = compute_expert_offsets(tokens_per_expert) + + # Simulate what the permute kernel does + counters[0] += 1 + counters[1] += 2 + + expected_offsets = torch.tensor([0, 2, 5], dtype=torch.int32, device="cuda") + assert torch.equal(offsets, expected_offsets), "offsets should be unmodified" + + @pytest.mark.internal + def test_end_to_end_with_histogram(self): + """Chain histogram -> prefix sum and verify consistency.""" + torch.manual_seed(99) + num_tokens, topk = 128, 8 + num_total_experts = 64 + num_local_experts = 8 + local_expert_start = 16 + + routing_map = torch.randint( + 0, num_total_experts, (num_tokens, topk), dtype=torch.int32, device="cuda" + ) + + tokens_per_expert = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + offsets, counters = compute_expert_offsets(tokens_per_expert) + + # Verify: offsets[i] == sum(tokens_per_expert[0:i]) + ref_offsets = torch.zeros_like(tokens_per_expert) + ref_offsets[1:] = torch.cumsum(tokens_per_expert[:-1], dim=0) + assert torch.equal(offsets, ref_offsets) + + # Verify: last offset + last count == total + assert offsets[-1] + tokens_per_expert[-1] == tokens_per_expert.sum() + + +class TestComputeLocalTokensPerExpertSweep: + """Exhaustive parametrized sweep for _count_local_tokens_kernel.""" + + @pytest.mark.internal + @pytest.mark.parametrize("num_tokens", [1, 7, 32, 137, 256, 512, 2048]) + @pytest.mark.parametrize("topk", [1, 2, 3, 4, 6, 8]) + @pytest.mark.parametrize("num_local_experts", [1, 3, 4, 8, 16]) + @pytest.mark.parametrize("num_total_experts", [8, 32, 64, 128]) + def test_sweep(self, num_tokens, topk, num_local_experts, num_total_experts): + """Sweep across token counts, topk, and expert configurations.""" + if num_local_experts > num_total_experts: + pytest.skip("num_local_experts > num_total_experts") + + torch.manual_seed(num_tokens * 1000 + topk * 100 + num_local_experts) + local_expert_start = (num_total_experts - num_local_experts) // 2 + + routing_map = 
torch.randint( + 0, num_total_experts, (num_tokens, topk), dtype=torch.int32, device="cuda" + ) + + result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) + assert torch.equal(result, ref), ( + f"Mismatch for num_tokens={num_tokens}, topk={topk}, " + f"num_local_experts={num_local_experts}, num_total_experts={num_total_experts}" + ) + + +class TestComputeExpertOffsetsSweep: + """Exhaustive parametrized sweep for _prefix_sum_kernel.""" + + @pytest.mark.internal + @pytest.mark.parametrize("num_local_experts", [1, 2, 3, 4, 5, 7, 8, 13, 16, 32, 64]) + @pytest.mark.parametrize("max_count", [0, 1, 10, 100, 1000]) + def test_sweep(self, num_local_experts, max_count): + """Sweep across expert counts and token magnitudes.""" + torch.manual_seed(num_local_experts * 100 + max_count) + + if max_count == 0: + tokens_per_expert = torch.zeros(num_local_experts, dtype=torch.int32, device="cuda") + else: + tokens_per_expert = torch.randint( + 0, max_count + 1, (num_local_experts,), dtype=torch.int32, device="cuda" + ) + + offsets, counters = compute_expert_offsets(tokens_per_expert) + + # Reference: exclusive prefix sum via torch.cumsum + ref_offsets = torch.zeros_like(tokens_per_expert) + if num_local_experts > 1: + ref_offsets[1:] = torch.cumsum(tokens_per_expert[:-1], dim=0) + + assert torch.equal(offsets, ref_offsets), ( + f"Mismatch for num_local_experts={num_local_experts}, max_count={max_count}\n" + f" tokens_per_expert={tokens_per_expert.tolist()}\n" + f" expected={ref_offsets.tolist()}\n" + f" got={offsets.tolist()}" + ) + assert torch.equal(counters, ref_offsets), "counters should equal offsets initially" From b7821ac66a40e26f9f05ec99f5ec063cfa932376 Mon Sep 17 00:00:00 2001 From: Siddharth Singh Date: Mon, 2 Mar 2026 17:27:41 -0800 Subject: [PATCH 92/92] work with qwen3 on hopper --- megatron/core/models/gpt/gpt_layer_specs.py | 24 +- 
megatron/core/transformer/enums.py | 8 + megatron/core/transformer/moe/experts.py | 64 ++-- .../transformer/moe/moe_inference_utils.py | 265 ++++++++++++++- megatron/core/transformer/moe/moe_layer.py | 36 +- .../moe/token_dispatcher_inference.py | 103 +++++- .../core/transformer/transformer_config.py | 56 +++- .../moe/test_moe_inference_utils.py | 314 ++++++++++++++++-- 8 files changed, 762 insertions(+), 108 deletions(-) diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index aae2d5f3e81..a9a40660f98 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -8,7 +8,10 @@ InferenceSpecProvider, LocalSpecProvider, ) -from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec_for_backend +from megatron.core.models.gpt.moe_module_specs import ( + get_inference_optimized_moe_spec, + get_moe_module_spec_for_backend, +) from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from megatron.core.transformer.enums import AttnMaskType, LayerType from megatron.core.transformer.identity_op import IdentityOp @@ -89,14 +92,17 @@ def get_gpt_layer_with_inference_submodules( assert HAVE_TE, "--transformer-impl inference_optimized requires transformer engine" backend = InferenceSpecProvider() - mlp = get_mlp_module_spec_for_backend( - backend=backend, - num_experts=num_experts, - moe_grouped_gemm=moe_grouped_gemm, - moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, - use_te_op_fuser=False, - use_te_activation_func=False, - ) + if num_experts is not None: + mlp = get_inference_optimized_moe_spec() + else: + mlp = get_mlp_module_spec_for_backend( + backend=backend, + num_experts=None, + moe_grouped_gemm=moe_grouped_gemm, + moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, + use_te_op_fuser=False, + use_te_activation_func=False, + ) if multi_latent_attention: assert qk_l2_norm is False, "qk_l2_norm is not supported with 
MLA." diff --git a/megatron/core/transformer/enums.py b/megatron/core/transformer/enums.py index d57e24887ab..ead3aef6cec 100644 --- a/megatron/core/transformer/enums.py +++ b/megatron/core/transformer/enums.py @@ -67,6 +67,14 @@ class AttnBackend(enum.Enum): auto = 5 +class MoEGroupedGemmBackend(enum.Enum): + """Backend for MoE grouped GEMM operations.""" + + te = 1 # Transformer Engine GroupedGEMM + torch = 2 # torch._grouped_mm + flashinfer = 3 # FlashInfer fused cutlass_fused_moe kernel + + class CudaGraphScope(enum.Enum): """Cuda Graph Scope - defines which parts of the model to capture.""" diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 0fc954db4af..8eda2d3d5a4 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -42,6 +42,7 @@ TEActivationFunctionBuilder, apply_swiglu_sharded_factory, ) +from megatron.core.transformer.enums import MoEGroupedGemmBackend from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.moe import grouped_gemm_util as gg from megatron.core.transformer.moe.moe_utils import ( @@ -741,7 +742,7 @@ def bias_act_func(self, intermediate_parallel, bias_parallel, permuted_probs): original_dtype = intermediate_parallel.dtype intermediate_parallel = intermediate_parallel * permuted_probs intermediate_parallel = intermediate_parallel.to(original_dtype) - elif self.config.bias_activation_fusion: + elif self.config.bias_activation_fusion and permuted_probs is not None: if self.activation_func == F.silu and self.config.gated_linear_unit: # dtype is handled inside the fused kernel intermediate_parallel = weighted_bias_swiglu_impl( @@ -761,7 +762,7 @@ def bias_act_func(self, intermediate_parallel, bias_parallel, permuted_probs): ) else: raise ValueError("Only support fusion of swiglu and quick_gelu in TEGroupedMLP.") - elif self.activation_func == squared_relu and self.config.use_fused_weighted_squared_relu: + elif 
self.activation_func == squared_relu and self.config.use_fused_weighted_squared_relu and permuted_probs is not None: assert bias_parallel is None, "Bias is not supported with fused weighted squared relu." intermediate_parallel = weighted_squared_relu_impl( intermediate_parallel, permuted_probs @@ -782,7 +783,8 @@ def glu(x): else: intermediate_parallel = self.activation_func(intermediate_parallel) original_dtype = intermediate_parallel.dtype - intermediate_parallel = intermediate_parallel * permuted_probs + if permuted_probs is not None: + intermediate_parallel = intermediate_parallel * permuted_probs intermediate_parallel = intermediate_parallel.to(original_dtype) return intermediate_parallel @@ -957,7 +959,6 @@ def __init__( self._torch_grouped_mm_available = ( is_torch_min_version("2.10") and hasattr(torch, '_grouped_mm') - and not config.inference_disable_torch_grouped_mm ) if HAVE_FLASHINFER: @@ -1046,13 +1047,15 @@ def _flashinfer_forward(self, hidden_states, routing_map, probs): return output, None def _torch_grouped_mm_forward( - self, permuted_local_hidden_states, tokens_per_expert, permuted_probs + self, permuted_local_hidden_states, tokens_per_expert, permuted_probs, + inclusive_expert_offsets=None, ): - permuted_probs = permuted_probs.unsqueeze(-1) + if permuted_probs is not None: + permuted_probs = permuted_probs.unsqueeze(-1) if not tokens_per_expert.is_cuda: tokens_per_expert = tokens_per_expert.to('cuda') - if self.config.moe_apply_probs_on_input: + if self.config.moe_apply_probs_on_input and permuted_probs is not None: assert ( self.config.moe_router_topk == 1 ), "`moe_apply_probs_on_input` only works with `moe_router_topk`=1." 
@@ -1062,12 +1065,12 @@ def _torch_grouped_mm_forward( permuted_probs = torch.ones_like(permuted_probs) if permuted_local_hidden_states.nelement() != 0: - # Use pre-concatenated weights (built during init/load) - # _fc1_weight shape: [num_experts, ffn_hidden * (2 if gated else 1), hidden_size] - # _fc2_weight shape: [num_experts, hidden_size, ffn_hidden] - # Compute cumulative offsets on GPU (no host sync!) - # offs[i] = end index of expert i's tokens - offs = tokens_per_expert.cumsum(0).to(torch.int32) + # Reuse precomputed inclusive offsets from the dispatcher if available, + # otherwise compute cumsum on the fly (non-CG eager path). + if inclusive_expert_offsets is not None: + offs = inclusive_expert_offsets + else: + offs = tokens_per_expert.cumsum(0).to(torch.int32) fc1_output = torch._grouped_mm( permuted_local_hidden_states, self._fc1_weight.transpose(1, 2), offs=offs @@ -1091,6 +1094,7 @@ def forward( tokens_per_expert: Optional[torch.Tensor], permuted_probs: torch.Tensor, routing_map: Optional[torch.Tensor] = None, + inclusive_expert_offsets: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Forward pass with three modes: @@ -1107,20 +1111,38 @@ def forward( permuted_probs: [num_tokens, topk] routing probabilities. routing_map: [num_tokens, topk] token-to-expert assignment indices. Required for the FlashInfer CUDA-graphed path, None otherwise. + inclusive_expert_offsets: [num_experts] precomputed inclusive cumsum of + tokens_per_expert (i.e. offs[i] = end index of expert i). When + provided, torch._grouped_mm reuses these directly instead of + recomputing cumsum. None for FlashInfer and non-CG paths. """ if self.training: return super().forward(permuted_local_hidden_states, tokens_per_expert, permuted_probs) elif self.is_inference_cuda_graphed_iteration: - assert routing_map is not None, "routing_map is required for FlashInfer forward pass." 
- assert ( - HAVE_FLASHINFER - ), "FlashInfer is not available; cannot use FlashInfer forward pass." - return self._flashinfer_forward( - permuted_local_hidden_states, routing_map, permuted_probs - ) + if self.config.moe_ggemm_inference_cg == MoEGroupedGemmBackend.flashinfer: + assert routing_map is not None, ( + "routing_map is required for FlashInfer forward pass." + ) + assert HAVE_FLASHINFER, ( + "FlashInfer is not available; cannot use FlashInfer forward pass." + ) + return self._flashinfer_forward( + permuted_local_hidden_states, routing_map, permuted_probs + ) + else: + assert tokens_per_expert is not None, ( + "tokens_per_expert is required for torch grouped_mm forward pass." + ) + return self._torch_grouped_mm_forward( + permuted_local_hidden_states, tokens_per_expert, permuted_probs, + inclusive_expert_offsets=inclusive_expert_offsets, + ) - elif self._torch_grouped_mm_available: + elif ( + self.config.moe_ggemm_inference_no_cg == MoEGroupedGemmBackend.torch + and self._torch_grouped_mm_available + ): return self._torch_grouped_mm_forward( permuted_local_hidden_states, tokens_per_expert, permuted_probs ) diff --git a/megatron/core/transformer/moe/moe_inference_utils.py b/megatron/core/transformer/moe/moe_inference_utils.py index 1c26c04c19c..89726be8929 100644 --- a/megatron/core/transformer/moe/moe_inference_utils.py +++ b/megatron/core/transformer/moe/moe_inference_utils.py @@ -106,21 +106,19 @@ def compute_local_tokens_per_expert( # --------------------------------------------------------------------------- # -# Kernel: Exclusive prefix sum + atomic counters (single block) +# Kernel: Exclusive prefix sum (single block) # --------------------------------------------------------------------------- # @triton.jit def _prefix_sum_kernel( tokens_per_expert_ptr, # [num_local_experts] input expert_offsets_ptr, # [num_local_experts] output - atomic_counters_ptr, # [num_local_experts] output (copy of offsets) num_local_experts, BLOCK_SIZE: tl.constexpr, # 
next_power_of_2(num_local_experts) ): """Compute exclusive prefix sum of tokens_per_expert. Runs as a single block. Reads tokens_per_expert, computes exclusive prefix - sum via tl.cumsum, and writes expert_offsets and a copy as atomic_counters - for use by the permute kernel. + sum via tl.cumsum, and writes expert_offsets. """ expert_range = tl.arange(0, BLOCK_SIZE) mask = expert_range < num_local_experts @@ -131,7 +129,6 @@ def _prefix_sum_kernel( exclusive = inclusive - histogram tl.store(expert_offsets_ptr + expert_range, exclusive, mask=mask) - tl.store(atomic_counters_ptr + expert_range, exclusive, mask=mask) # --------------------------------------------------------------------------- # @@ -139,33 +136,273 @@ def _prefix_sum_kernel( # --------------------------------------------------------------------------- # def compute_expert_offsets( tokens_per_expert: torch.Tensor, -) -> tuple: - """Compute exclusive prefix sum of tokens_per_expert and a mutable copy for atomics. +) -> torch.Tensor: + """Compute exclusive prefix sum of tokens_per_expert. Args: tokens_per_expert (torch.Tensor): Token counts per local expert, shape [num_local_experts], dtype int32. Returns: - tuple: (expert_offsets, atomic_counters) where: - - expert_offsets: [num_local_experts] exclusive prefix sum (read-only). - - atomic_counters: [num_local_experts] same values as expert_offsets, - to be mutated by the permute kernel's atomic adds. + torch.Tensor: expert_offsets, shape [num_local_experts]. + Exclusive prefix sum: expert_offsets[i] is the start index of + expert i's tokens in the permuted buffer. Passed to permute_tokens + which mutates it in-place via atomic adds, turning the exclusive + start offsets into inclusive end offsets (i.e. expert_offsets[i] + becomes the end index of expert i's tokens after permutation). 
""" num_local_experts = tokens_per_expert.shape[0] expert_offsets = torch.empty_like(tokens_per_expert) - atomic_counters = torch.empty_like(tokens_per_expert) BLOCK_SIZE = triton.next_power_of_2(num_local_experts) _prefix_sum_kernel[(1,)]( tokens_per_expert, expert_offsets, - atomic_counters, num_local_experts, BLOCK_SIZE=BLOCK_SIZE, ) - return expert_offsets, atomic_counters + return expert_offsets + + +# --------------------------------------------------------------------------- # +# Kernel: Permute tokens by expert assignment +# --------------------------------------------------------------------------- # +@triton.jit +def _permute_tokens_kernel( + # Input pointers + hidden_states_ptr, # [num_tokens, hidden_dim] + probs_ptr, # [num_tokens, topk] + routing_map_ptr, # [num_tokens, topk] + # Output pointers + permuted_hidden_ptr, # [output_size, hidden_dim] + permuted_probs_ptr, # [output_size] + source_token_indices_ptr, # [output_size] + # Atomic counters (mutated in-place) + atomic_counters_ptr, # [num_local_experts] + # Dimensions + num_tokens, + hidden_dim, + topk: tl.constexpr, + local_expert_start, + num_local_experts: tl.constexpr, + BLOCK_H: tl.constexpr, # tile size for hidden_dim copy loop +): + """Permute tokens into expert-grouped order. + + Each program handles one (token, k) pair. If the assigned expert is local, + it atomically claims a write position and copies the token's hidden state, + routing probability, and source token index to the output buffers. + The hidden dimension is copied in tiles of BLOCK_H to support large hidden sizes. 
+ """ + pair_idx = tl.program_id(0) + token_idx = pair_idx // topk + k_idx = pair_idx % topk + + if token_idx >= num_tokens: + return + + expert_id = tl.load(routing_map_ptr + token_idx * topk + k_idx) + local_idx = expert_id - local_expert_start + + if local_idx < 0 or local_idx >= num_local_experts: + return + + # Atomically claim a write position + write_pos = tl.atomic_add(atomic_counters_ptr + local_idx, 1) + + # Copy hidden state row in tiles of BLOCK_H + src_row_ptr = hidden_states_ptr + token_idx * hidden_dim + dst_row_ptr = permuted_hidden_ptr + write_pos * hidden_dim + for h_start in tl.range(0, hidden_dim, BLOCK_H): + h_offsets = h_start + tl.arange(0, BLOCK_H) + h_mask = h_offsets < hidden_dim + src_vals = tl.load(src_row_ptr + h_offsets, mask=h_mask) + tl.store(dst_row_ptr + h_offsets, src_vals, mask=h_mask) + + # Copy the routing probability for this (token, k) pair + prob_val = tl.load(probs_ptr + token_idx * topk + k_idx) + tl.store(permuted_probs_ptr + write_pos, prob_val) + + # Record source token index for the unpermute kernel + tl.store(source_token_indices_ptr + write_pos, token_idx) + + +# --------------------------------------------------------------------------- # +# Python wrapper +# --------------------------------------------------------------------------- # +def permute_tokens( + hidden_states: torch.Tensor, + probs: torch.Tensor, + routing_map: torch.Tensor, + expert_offsets: torch.Tensor, + local_expert_start: int, + num_local_experts: int, +) -> tuple: + """Permute tokens into expert-grouped order for the torch grouped GEMM. + + Scatters tokens into an output buffer where all tokens for expert 0 come + first, then expert 1, etc. Uses expert_offsets (from compute_expert_offsets) + to assign write positions. + + NOTE: expert_offsets is mutated in-place. On entry it contains exclusive + start offsets (expert_offsets[i] = start index of expert i's region). 
+ The kernel atomically increments each entry as it places tokens, so on + exit expert_offsets[i] = end index (exclusive, one past the last token) of expert i's region. + The last entry equals the total number of routed tokens. + + Args: + hidden_states (torch.Tensor): Input hidden states, shape [num_tokens, hidden_dim]. + probs (torch.Tensor): Routing probabilities, shape [num_tokens, topk]. + routing_map (torch.Tensor): Expert assignments, shape [num_tokens, topk]. + Contains global expert IDs. + expert_offsets (torch.Tensor): Write position counters, shape [num_local_experts]. + Initialized to exclusive prefix sum by compute_expert_offsets. + Mutated in-place to inclusive end offsets by the permute kernel. + local_expert_start (int): First global expert index on this rank. + num_local_experts (int): Number of experts on this rank. + + Returns: + tuple: (permuted_hidden_states, permuted_probs, source_token_indices) where: + - permuted_hidden_states: [output_size, hidden_dim] tokens grouped by expert. + - permuted_probs: [output_size] scalar prob per permuted slot. + - source_token_indices: [output_size] original token index per permuted slot. + output_size = num_tokens * min(topk, num_local_experts). + Slots beyond the actual routed token count contain uninitialized data.
+ """ + num_tokens, hidden_dim = hidden_states.shape + topk = probs.shape[1] + output_size = num_tokens * min(topk, num_local_experts) + + # Allocate output buffers (statically sized for CUDA graph compatibility) + permuted_hidden = torch.empty( + output_size, hidden_dim, dtype=hidden_states.dtype, device=hidden_states.device + ) + permuted_probs = torch.empty(output_size, dtype=probs.dtype, device=probs.device) + source_token_indices = torch.empty(output_size, dtype=torch.int32, device=probs.device) + + total_pairs = num_tokens * topk + BLOCK_H = min(triton.next_power_of_2(hidden_dim), 1024) + # After this kernel, expert_offsets is mutated: exclusive start offsets + # become inclusive end offsets (expert_offsets[-1] = total routed tokens). + _permute_tokens_kernel[(total_pairs,)]( + hidden_states, + probs, + routing_map, + permuted_hidden, + permuted_probs, + source_token_indices, + expert_offsets, + num_tokens, + hidden_dim, + topk, + local_expert_start, + num_local_experts, + BLOCK_H=BLOCK_H, + ) + + return permuted_hidden, permuted_probs, source_token_indices + + +# --------------------------------------------------------------------------- # +# Kernel: Unpermute (accumulate expert outputs back to token positions) +# --------------------------------------------------------------------------- # +@triton.jit +def _unpermute_tokens_kernel( + # Input pointers + expert_output_ptr, # [output_size, hidden_dim] + permuted_probs_ptr, # [output_size] + source_token_indices_ptr, # [output_size] + # GPU-resident valid count (read from last atomic counter after permute) + num_routed_slots_ptr, # scalar tensor on GPU + # Output pointer + output_ptr, # [num_tokens, hidden_dim] - must be zero-initialized + # Dimensions + hidden_dim, + BLOCK_H: tl.constexpr, # tile size for hidden_dim loop +): + """Accumulate weighted expert outputs back into original token positions. + + Each program handles one row of the permuted expert output. 
It reads the + source token index, multiplies the expert output by the routing probability, + and atomically adds the result to the corresponding row in the output buffer. + Multiple experts contributing to the same token are summed via atomic adds. + Rows beyond num_routed_slots (read from GPU) are skipped. + The hidden dimension is processed in tiles of BLOCK_H to support large hidden sizes. + """ + row_idx = tl.program_id(0) + + # Read valid count from GPU — no host sync needed for CUDA graphability + num_valid = tl.load(num_routed_slots_ptr) + if row_idx >= num_valid: + return + + token_idx = tl.load(source_token_indices_ptr + row_idx) + prob = tl.load(permuted_probs_ptr + row_idx) + + src_row_ptr = expert_output_ptr + row_idx * hidden_dim + dst_row_ptr = output_ptr + token_idx * hidden_dim + for h_start in tl.range(0, hidden_dim, BLOCK_H): + h_offsets = h_start + tl.arange(0, BLOCK_H) + h_mask = h_offsets < hidden_dim + expert_vals = tl.load(src_row_ptr + h_offsets, mask=h_mask) + scaled_vals = expert_vals * prob + tl.atomic_add(dst_row_ptr + h_offsets, scaled_vals, mask=h_mask) + + +# --------------------------------------------------------------------------- # +# Python wrapper +# --------------------------------------------------------------------------- # +def unpermute_tokens( + expert_output: torch.Tensor, + permuted_probs: torch.Tensor, + source_token_indices: torch.Tensor, + num_tokens: int, + num_routed_slots: torch.Tensor, +) -> torch.Tensor: + """Unpermute expert outputs back to original token order. + + Accumulates weighted expert outputs into the original token positions. + For each valid permuted row i, computes: + output[source_token_indices[i], :] += expert_output[i, :] * permuted_probs[i] + Multiple experts contributing to the same token are summed via atomic adds. + Rows beyond num_routed_slots are skipped. num_routed_slots is read from GPU + to avoid host synchronization, keeping the pipeline CUDA-graphable. 
+ + Args: + expert_output (torch.Tensor): Expert outputs, shape [output_size, hidden_dim]. + Only the first total_routed rows are valid. + permuted_probs (torch.Tensor): Routing probs, shape [output_size]. + Only the first total_routed entries are valid. + source_token_indices (torch.Tensor): Source token index for each permuted row, + shape [output_size], dtype int32. Only the first total_routed are valid. + num_tokens (int): Number of original tokens (for output allocation). + num_routed_slots (torch.Tensor): 1-element GPU tensor containing the + total number of routed tokens. Must stay on GPU (no host sync). + + Returns: + torch.Tensor: Unpermuted output, shape [num_tokens, hidden_dim]. + Each row is the weighted sum of expert outputs for that token. + """ + output_size, hidden_dim = expert_output.shape + + # Zero-initialized output (atomic adds accumulate into this) + output = torch.zeros( + num_tokens, hidden_dim, dtype=expert_output.dtype, device=expert_output.device + ) + + BLOCK_H = min(triton.next_power_of_2(hidden_dim), 1024) + _unpermute_tokens_kernel[(output_size,)]( + expert_output, + permuted_probs, + source_token_indices, + num_routed_slots, + output, + hidden_dim, + BLOCK_H=BLOCK_H, + ) + + return output diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 8277486b03b..b60e2d6bb1d 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -25,6 +25,7 @@ MoEFlexTokenDispatcher, MoETokenDispatcher, ) +from megatron.core.transformer.enums import MoEGroupedGemmBackend from megatron.core.transformer.moe.token_dispatcher_inference import ( InferenceCUDAGraphTokenDispatcher, ) @@ -268,15 +269,16 @@ def __init__( # Inference-optimized mode setup if config.transformer_impl == "inference_optimized": - assert ( - HAVE_FLASHINFER - ), "flashinfer-python is required for inference-optimized MoE implementation." 
- if not HAVE_FLASHINFER_CUBIN_AND_JIT_CACHE: - warnings.warn( - "flashinfer-cubin and/or flashinfer-jit-cache not found. " - "The FlashInfer cutlass kernel will be JIT compiled," - "which may take a long time." - ) + if config.moe_ggemm_inference_cg == MoEGroupedGemmBackend.flashinfer: + assert ( + HAVE_FLASHINFER + ), "flashinfer-python is required when moe_ggemm_inference_cg='flashinfer'." + if not HAVE_FLASHINFER_CUBIN_AND_JIT_CACHE: + warnings.warn( + "flashinfer-cubin and/or flashinfer-jit-cache not found. " + "The FlashInfer cutlass kernel will be JIT compiled, " + "which may take a long time." + ) self._setup_inference_mode(pg_collection) # Cudagraph tensor store for resuming the forward pass from the end of the cudagraph. @@ -418,10 +420,18 @@ def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tenso hasattr(self, "_inference_token_dispatcher") and self.is_inference_cuda_graphed_iteration ): - routing_map = self.token_dispatcher.routing_map - expert_output, mlp_bias = self.experts( - dispatched_input, tokens_per_expert, permuted_probs, routing_map=routing_map - ) + if self.config.moe_ggemm_inference_cg == MoEGroupedGemmBackend.flashinfer: + routing_map = self.token_dispatcher.routing_map + expert_output, mlp_bias = self.experts( + dispatched_input, tokens_per_expert, permuted_probs, routing_map=routing_map + ) + else: + # torch backend: dispatcher already permuted tokens; tokens_per_expert is set. + # Pass precomputed inclusive offsets to avoid recomputing cumsum.
+ expert_output, mlp_bias = self.experts( + dispatched_input, tokens_per_expert, permuted_probs, + inclusive_expert_offsets=self.token_dispatcher.inclusive_expert_offsets, + ) else: expert_output, mlp_bias = self.experts( dispatched_input, tokens_per_expert, permuted_probs diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 6b851c252c5..9bab4dbd4fa 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -26,6 +26,13 @@ gather_from_sequence_parallel_region, reduce_scatter_to_sequence_parallel_region, ) +from megatron.core.transformer.enums import MoEGroupedGemmBackend +from megatron.core.transformer.moe.moe_inference_utils import ( + compute_expert_offsets, + compute_local_tokens_per_expert, + permute_tokens, + unpermute_tokens, +) from megatron.core.transformer.moe.token_dispatcher import MoEAllGatherTokenDispatcher from megatron.core.transformer.transformer_config import TransformerConfig @@ -70,6 +77,20 @@ def __init__( self.triton_nvls_kernels_allowed = not self.config.inference_disable_triton_nvls_kernels + # Backend selection for CUDA-graphed grouped GEMM + self._moe_backend = config.moe_ggemm_inference_cg + self._local_expert_start = local_expert_indices[0] + + # Intermediate state for torch backend (set in dispatch_postprocess, used in combine_preprocess) + self._source_token_indices = None + self._num_routed_slots = None + self._permuted_probs = None + self._num_tokens_after_dispatch = None + # Inclusive end offsets per expert after permutation. Reused as the + # `offs` arg to torch._grouped_mm to avoid recomputing cumsum. + # None for the flashinfer backend (not needed). 
+ self.inclusive_expert_offsets = None + def _maybe_allocate_ag_buffers( self, routing_map: torch.Tensor, probs: torch.Tensor, hidden_states: torch.Tensor ) -> dict: @@ -242,12 +263,11 @@ def token_dispatch(self, hidden_states, probs): return hidden_states, probs def dispatch_postprocess(self, hidden_states, probs): - """Pass-through: returns inputs directly without permutation. + """Post-process dispatched tokens for expert computation. - Unlike the training dispatcher, this does not permute tokens or compute - tokens_per_expert. The downstream InferenceGroupedMLP (FlashInfer / - CUTLASS fused MoE kernel) operates directly on the routing map stored - in self.routing_map. + For the flashinfer backend, this is a pass-through: the FlashInfer fused + kernel operates directly on self.routing_map. For the torch backend, + uses triton kernels to permute tokens into expert-grouped order. Args: hidden_states (torch.Tensor): Gathered hidden states, @@ -256,25 +276,79 @@ def dispatch_postprocess(self, hidden_states, probs): shape [global_tokens, topk]. Returns: - tuple: (hidden_states, tokens_per_expert, probs) where - tokens_per_expert is always None. + tuple: (hidden_states, tokens_per_expert, probs) where: + - flashinfer: tokens_per_expert=None, probs=original probs + - torch: tokens_per_expert=GPU histogram, probs=None + (probs deferred to combine_preprocess unpermute) """ - return hidden_states, None, probs + if self._moe_backend == MoEGroupedGemmBackend.flashinfer: + return hidden_states, None, probs + + # torch backend: permute tokens into expert-grouped order + self._num_tokens_after_dispatch = hidden_states.size(0) + tokens_per_expert = compute_local_tokens_per_expert( + self.routing_map, self._local_expert_start, self.num_local_experts + ) + exclusive_expert_offsets = compute_expert_offsets(tokens_per_expert) + # permute_tokens mutates the offsets tensor in-place via atomic adds, + # converting exclusive start offsets into inclusive end offsets. 
+ # Example with 3 experts and tokens_per_expert = [2, 3, 1]: + # exclusive_expert_offsets = [0, 2, 5] (start of each expert's region) + # inclusive_expert_offsets = [2, 5, 6] (end of each expert's region) + # The last entry (6) equals the total number of routed tokens. + permuted_hidden, permuted_probs, source_token_indices = permute_tokens( + hidden_states, probs, self.routing_map, exclusive_expert_offsets, + self._local_expert_start, self.num_local_experts, + ) + inclusive_expert_offsets = exclusive_expert_offsets # mutated in-place by permute_tokens + + # Cache state for combine_preprocess. + self._source_token_indices = source_token_indices + # Number of (token, expert) slots actually routed to local experts on + # this rank. With topk > 1, a single token may occupy multiple slots. + # Only permuted_hidden[:num_routed_slots] contains real data; the + # remaining rows are uninitialized padding for static CUDA graph shapes. + # Used later in combine_preprocess/unpermute_tokens to read only the + # valid grouped GEMM outputs and skip garbage-padded rows. + # Stored as a 1-element GPU tensor view ([-1:], not [-1]) to avoid a + # device-to-host sync that would break CUDA graphability. + self._num_routed_slots = inclusive_expert_offsets[-1:] + self._permuted_probs = permuted_probs + # Cache the full inclusive offsets so torch._grouped_mm can reuse them + # directly as `offs` without recomputing cumsum on tokens_per_expert. + self.inclusive_expert_offsets = inclusive_expert_offsets + + # probs=None: expert skips prob-weighting, unpermute applies probs instead + return permuted_hidden, tokens_per_expert, None def combine_preprocess(self, expert_output): - """Pass-through: InferenceGroupedMLP already produces unpermuted output. + """Unpermute expert outputs back to original token order. - No unpermutation is needed because dispatch_postprocess did not permute - the tokens in the first place. 
+ For flashinfer, this is a pass-through (FlashInfer already produces + output in token order). For the torch backend, uses the triton unpermute + kernel to scatter-add expert outputs weighted by routing probs. Args: expert_output (torch.Tensor): Output from InferenceGroupedMLP, - shape [global_tokens, hidden_dim]. + shape [output_size, hidden_dim] (torch) or + [global_tokens, hidden_dim] (flashinfer). Returns: - torch.Tensor: The input tensor unchanged. + torch.Tensor: Output in original token order, + shape [global_tokens, hidden_dim]. """ - return expert_output + if self._moe_backend == MoEGroupedGemmBackend.flashinfer: + return expert_output + + # torch backend: unpermute with routing probs + output = unpermute_tokens( + expert_output, + self._permuted_probs, + self._source_token_indices, + self._num_tokens_after_dispatch, + self._num_routed_slots, + ) + return output def token_combine(self, hidden_states): """Combines expert outputs across EP ranks using Reduce-Scatter. @@ -324,3 +398,4 @@ def token_combine(self, hidden_states): hidden_states, group=self.tp_ep_group ) return hidden_states + diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 559f4226af2..95d9af30a01 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -11,7 +11,7 @@ from megatron.core.enums import Fp4Recipe, Fp8Recipe from megatron.core.quantization.quant_config import RecipeConfig -from megatron.core.transformer.enums import AttnBackend, CudaGraphScope +from megatron.core.transformer.enums import AttnBackend, CudaGraphScope, MoEGroupedGemmBackend from megatron.core.transformer.pipeline_parallel_layer_layout import PipelineParallelLayerLayout from .._rank_utils import log_single_rank @@ -917,9 +917,18 @@ class TransformerConfig(ModelParallelConfig): inference_disable_triton_nvls_kernels: bool = False """ If true, disables the use of Triton NVLS kernels during 
inference. """ - inference_disable_torch_grouped_mm: bool = False - """ If true, disables torch._grouped_mm in InferenceGroupedMLP, - falling back to TE GroupedGEMM. """ + moe_ggemm_training: str = "te" + """Backend for training grouped GEMM. Only 'te' is supported.""" + + moe_ggemm_inference_cg: str = "flashinfer" + """Backend for CUDA-graphed inference grouped GEMM. + 'flashinfer': FlashInfer's fused cutlass_fused_moe kernel (default). + 'torch': triton permute/unpermute kernels + torch._grouped_mm.""" + + moe_ggemm_inference_no_cg: str = "torch" + """Backend for non-CUDA-graphed inference grouped GEMM. + 'torch': torch._grouped_mm with GPU-resident cumsum offsets (default). + 'te': TE GroupedGEMM fallback.""" mrope_section: Optional[List[int]] = None """ Multimodal rope section is for channel dimension of temporal, height and width @@ -1159,17 +1168,23 @@ def __post_init__(self): "Inference-optimized MoE layers do not support padded " "routing map for quantization." ) - if self.moe_router_dtype != "fp32": + if self.moe_router_dtype != "fp32" and self.moe_ggemm_inference_cg == "flashinfer": raise ValueError( - "Inference-optimized MoE requires --moe-router-dtype=fp32 " - "to avoid costly dtype conversions during decode." + "The FlashInfer CUDA-graphed MoE backend requires " + "--moe-router-dtype=fp32. Either set --moe-router-dtype=fp32 " + "or switch to --moe-ggemm-inference-cg=torch." ) - if self.gated_linear_unit and self.cuda_graph_impl != "none": + if ( + self.gated_linear_unit + and self.cuda_graph_impl != "none" + and self.moe_ggemm_inference_cg == "flashinfer" + ): raise ValueError( - "Inference-optimized MoE does not yet support CUDA graphs with gated " + "The FlashInfer CUDA-graphed MoE backend does not support gated " "linear units (SwiGLU/GeGLU) due to differences in weight layouts " - "between the FlashInfer kernel and mcore. Either disable CUDA graphs " - "(--cuda-graph-impl=none) or use a non-gated activation (e.g. squared_relu)." 
+ "between the FlashInfer kernel and mcore. Either switch to the torch " + "backend (--moe-ggemm-inference-cg=torch), disable CUDA graphs " + "(--cuda-graph-impl=none), or use a non-gated activation (e.g. squared_relu)." ) if self.num_moe_experts is not None and self.num_moe_experts <= 0: @@ -2185,11 +2200,20 @@ def __post_init__(self): "for inference_optimized transformer implementation." ) - if self.inference_disable_torch_grouped_mm: - assert self.transformer_impl == "inference_optimized", ( - "inference_disable_torch_grouped_mm is only supported " - "for inference_optimized transformer implementation." - ) + # Convert moe_ggemm_* strings to enums + _ggemm_valid = {b.name for b in MoEGroupedGemmBackend} + for field_name, allowed in [ + ("moe_ggemm_training", {"te"}), + ("moe_ggemm_inference_cg", {"flashinfer", "torch"}), + ("moe_ggemm_inference_no_cg", {"torch", "te"}), + ]: + val = getattr(self, field_name) + if isinstance(val, str): + if val not in allowed: + raise ValueError( + f"{field_name} must be one of {sorted(allowed)}, got '{val}'" + ) + object.__setattr__(self, field_name, MoEGroupedGemmBackend[val]) if self.batch_invariant_mode: assert ( diff --git a/tests/unit_tests/transformer/moe/test_moe_inference_utils.py b/tests/unit_tests/transformer/moe/test_moe_inference_utils.py index 710143a846a..02304b65018 100644 --- a/tests/unit_tests/transformer/moe/test_moe_inference_utils.py +++ b/tests/unit_tests/transformer/moe/test_moe_inference_utils.py @@ -13,9 +13,19 @@ from megatron.core.transformer.moe.moe_inference_utils import ( compute_expert_offsets, compute_local_tokens_per_expert, + permute_tokens, + unpermute_tokens, ) +def _make_routing_map(num_tokens, topk, num_total_experts, device="cuda"): + """Create a routing map with topk distinct experts per token (matching real router).""" + return torch.stack([ + torch.randperm(num_total_experts, device=device)[:topk] + for _ in range(num_tokens) + ]).to(torch.int32) + + def 
_reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts): """PyTorch reference: count (token, k) pairs routed to each local expert.""" flat = routing_map.flatten() @@ -37,9 +47,7 @@ def test_random_routing(self): num_total_experts, num_local_experts = 64, 8 local_expert_start = 16 - routing_map = torch.randint( - 0, num_total_experts, (num_tokens, topk), dtype=torch.int32, device="cuda" - ) + routing_map = _make_routing_map(num_tokens, topk, num_total_experts) result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) @@ -124,9 +132,7 @@ def test_local_expert_start_at_zero(self): num_local_experts = 4 local_expert_start = 0 - routing_map = torch.randint( - 0, num_total_experts, (num_tokens, topk), dtype=torch.int32, device="cuda" - ) + routing_map = _make_routing_map(num_tokens, topk, num_total_experts) result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) @@ -141,9 +147,7 @@ def test_local_experts_at_end(self): num_local_experts = 4 local_expert_start = 28 # experts 28..31 - routing_map = torch.randint( - 0, num_total_experts, (num_tokens, topk), dtype=torch.int32, device="cuda" - ) + routing_map = _make_routing_map(num_tokens, topk, num_total_experts) result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) @@ -158,9 +162,7 @@ def test_large_batch(self): num_local_experts = 16 local_expert_start = 48 - routing_map = torch.randint( - 0, num_total_experts, (num_tokens, topk), dtype=torch.int32, device="cuda" - ) + routing_map = _make_routing_map(num_tokens, topk, num_total_experts) result = compute_local_tokens_per_expert(routing_map, local_expert_start, 
num_local_experts) ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) @@ -192,9 +194,7 @@ def test_non_power_of_two_tokens(self): num_local_experts = 4 local_expert_start = 12 - routing_map = torch.randint( - 0, num_total_experts, (num_tokens, topk), dtype=torch.int32, device="cuda" - ) + routing_map = _make_routing_map(num_tokens, topk, num_total_experts) result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) @@ -264,9 +264,7 @@ def test_end_to_end_with_histogram(self): num_local_experts = 8 local_expert_start = 16 - routing_map = torch.randint( - 0, num_total_experts, (num_tokens, topk), dtype=torch.int32, device="cuda" - ) + routing_map = _make_routing_map(num_tokens, topk, num_total_experts) tokens_per_expert = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) offsets, counters = compute_expert_offsets(tokens_per_expert) @@ -292,13 +290,13 @@ def test_sweep(self, num_tokens, topk, num_local_experts, num_total_experts): """Sweep across token counts, topk, and expert configurations.""" if num_local_experts > num_total_experts: pytest.skip("num_local_experts > num_total_experts") + if topk > num_total_experts: + pytest.skip("topk > num_total_experts") torch.manual_seed(num_tokens * 1000 + topk * 100 + num_local_experts) local_expert_start = (num_total_experts - num_local_experts) // 2 - routing_map = torch.randint( - 0, num_total_experts, (num_tokens, topk), dtype=torch.int32, device="cuda" - ) + routing_map = _make_routing_map(num_tokens, topk, num_total_experts) result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts) @@ -339,3 +337,277 @@ def test_sweep(self, num_local_experts, max_count): f" got={offsets.tolist()}" ) assert 
torch.equal(counters, ref_offsets), "counters should equal offsets initially" + + +class TestPermuteTokens: + """Tests for the _permute_tokens_kernel triton kernel.""" + + def _run_permute( + self, num_tokens, topk, hidden_dim, num_total_experts, num_local_experts, local_expert_start, + seed=42, + ): + """Run the full histogram -> prefix_sum -> permute pipeline and verify correctness. + + Checks: + 1. Each permuted row's hidden state matches hidden_states[source_token_indices[i]] + 2. Each expert group contains exactly the right set of source token indices + 3. Total routed count matches tokens_per_expert.sum() + """ + if num_local_experts > num_total_experts: + pytest.skip("num_local_experts > num_total_experts") + if num_total_experts % num_local_experts != 0: + pytest.skip("num_total_experts not divisible by num_local_experts") + if local_expert_start + num_local_experts > num_total_experts: + pytest.skip("local expert range exceeds num_total_experts") + if topk > num_total_experts: + pytest.skip("topk > num_total_experts") + + torch.manual_seed(seed) + routing_map = _make_routing_map(num_tokens, topk, num_total_experts) + hidden_states = torch.randn(num_tokens, hidden_dim, dtype=torch.float32, device="cuda") + probs = torch.randn(num_tokens, topk, dtype=torch.float32, device="cuda") + + # Triton pipeline + tokens_per_expert = compute_local_tokens_per_expert( + routing_map, local_expert_start, num_local_experts + ) + _, atomic_counters = compute_expert_offsets(tokens_per_expert) + permuted_hidden, permuted_probs, source_token_indices = permute_tokens( + hidden_states, probs, routing_map, atomic_counters, + local_expert_start, num_local_experts, + ) + + total_routed = tokens_per_expert.sum().item() + + # Check 1: every permuted row matches its source token + for i in range(total_routed): + src = source_token_indices[i].item() + assert 0 <= src < num_tokens, ( + f"Row {i}: source_token_indices={src} out of bounds [0, {num_tokens})" + ) + assert 
torch.equal(permuted_hidden[i], hidden_states[src]), ( + f"Row {i}: permuted_hidden doesn't match hidden_states[{src}]" + ) + # We don't store k_idx in the output, so we can't look up the exact + # prob. Instead, verify it's one of the topk probs for this token. + assert permuted_probs[i] in probs[src], ( + f"Row {i}: permuted_probs={permuted_probs[i].item()} not in probs[{src}]" + ) + + # Check 2: each expert group has the right set of source token indices + offset = 0 + for e in range(num_local_experts): + count = tokens_per_expert[e].item() + expert_global_id = local_expert_start + e + + # Expected: all token indices that have this expert in their routing_map + expected_tokens = set() + for t in range(num_tokens): + for k in range(topk): + if routing_map[t, k].item() == expert_global_id: + expected_tokens.add(t) + + # Actual: source token indices in this group + actual_tokens = set(source_token_indices[offset:offset + count].tolist()) + assert actual_tokens == expected_tokens, ( + f"Expert {e}: token set mismatch.\n" + f" expected={sorted(expected_tokens)}\n" + f" actual={sorted(actual_tokens)}" + ) + + offset += count + + @pytest.mark.internal + @pytest.mark.parametrize("topk", [1, 2, 3, 4, 6, 8]) + @pytest.mark.parametrize("hidden_dim", [64, 100, 128, 256, 1024]) + @pytest.mark.parametrize("num_total_experts", [8, 32, 64, 128]) + def test_basic(self, topk, hidden_dim, num_total_experts): + """Basic test with moderate sizes, sweep topk, hidden_dim, num_total_experts.""" + self._run_permute( + num_tokens=64, topk=topk, hidden_dim=hidden_dim, + num_total_experts=num_total_experts, num_local_experts=4, local_expert_start=8, + ) + + @pytest.mark.internal + @pytest.mark.parametrize("topk", [1, 2, 3, 4, 6, 8]) + @pytest.mark.parametrize("hidden_dim", [64, 100, 128, 256, 1024]) + @pytest.mark.parametrize("num_total_experts", [8, 32, 64, 128]) + def test_single_token(self, topk, hidden_dim, num_total_experts): + """Single token, sweep topk, hidden_dim, 
num_total_experts.""" + self._run_permute( + num_tokens=1, topk=topk, hidden_dim=hidden_dim, + num_total_experts=num_total_experts, num_local_experts=8, local_expert_start=0, + ) + + @pytest.mark.internal + def test_no_local_hits(self): + """All tokens routed to non-local experts.""" + torch.manual_seed(0) + num_tokens, topk, hidden_dim = 32, 4, 64 + num_local_experts = 4 + local_expert_start = 100 # way above any expert ID in [0, 16) + + routing_map = _make_routing_map(num_tokens, topk, 16) + hidden_states = torch.randn(num_tokens, hidden_dim, device="cuda") + probs = torch.randn(num_tokens, topk, device="cuda") + + tokens_per_expert = compute_local_tokens_per_expert( + routing_map, local_expert_start, num_local_experts + ) + _, atomic_counters = compute_expert_offsets(tokens_per_expert) + permuted_hidden, permuted_probs, source_token_indices = permute_tokens( + hidden_states, probs, routing_map, atomic_counters, + local_expert_start, num_local_experts, + ) + + assert tokens_per_expert.sum().item() == 0 + + @pytest.mark.internal + @pytest.mark.parametrize("topk", [1, 2, 3, 4, 6, 8]) + @pytest.mark.parametrize("hidden_dim", [64, 100, 128, 256, 1024]) + @pytest.mark.parametrize("num_total_experts", [8, 32, 64, 128]) + def test_all_to_one_expert(self, topk, hidden_dim, num_total_experts): + """All tokens routed to a single local expert — max atomic contention.""" + self._run_permute( + num_tokens=32, topk=topk, hidden_dim=hidden_dim, + num_total_experts=num_total_experts, num_local_experts=1, local_expert_start=0, + seed=99, + ) + + @pytest.mark.internal + @pytest.mark.parametrize("topk", [1, 2, 3, 4, 6, 8]) + @pytest.mark.parametrize("hidden_dim", [100, 200, 350]) + @pytest.mark.parametrize("num_total_experts", [8, 16, 32, 64, 128]) + def test_non_power_of_two_hidden_dim(self, topk, hidden_dim, num_total_experts): + """Non-power-of-2 dims, sweep topk, hidden_dim, num_total_experts.""" + self._run_permute( + num_tokens=32, topk=topk, hidden_dim=hidden_dim, + 
num_total_experts=num_total_experts, num_local_experts=4, local_expert_start=4, + ) + + +class TestUnpermuteTokens: + """Tests for the _unpermute_tokens_kernel triton kernel.""" + + def _run_end_to_end( + self, num_tokens, topk, hidden_dim, num_total_experts, num_local_experts, + local_expert_start, seed=42, + ): + """Run the full permute -> unpermute pipeline and compare to PyTorch reference. + + The reference computes: + output[t, :] = sum over local experts e assigned to token t: + expert_output[permuted_row_for(t, e), :] * prob[t, k_for(e)] + We simulate expert_output as random data (as if the grouped GEMM ran). + """ + if num_local_experts > num_total_experts: + pytest.skip("num_local_experts > num_total_experts") + if num_total_experts % num_local_experts != 0: + pytest.skip("num_total_experts not divisible by num_local_experts") + if local_expert_start + num_local_experts > num_total_experts: + pytest.skip("local expert range exceeds num_total_experts") + if topk > num_total_experts: + pytest.skip("topk > num_total_experts") + + torch.manual_seed(seed) + routing_map = _make_routing_map(num_tokens, topk, num_total_experts) + hidden_states = torch.randn(num_tokens, hidden_dim, dtype=torch.float32, device="cuda") + probs = torch.randn(num_tokens, topk, dtype=torch.float32, device="cuda").abs() + + # Permute + tokens_per_expert = compute_local_tokens_per_expert( + routing_map, local_expert_start, num_local_experts + ) + _, atomic_counters = compute_expert_offsets(tokens_per_expert) + permuted_hidden, permuted_probs, source_token_indices = permute_tokens( + hidden_states, probs, routing_map, atomic_counters, + local_expert_start, num_local_experts, + ) + + total_routed = tokens_per_expert.sum().item() + + # Simulate expert output (random, as if grouped GEMM ran on permuted_hidden) + expert_output = torch.randn( + permuted_hidden.shape[0], hidden_dim, dtype=torch.float32, device="cuda" + ) + + # Triton unpermute — pass atomic_counters so kernel reads total_routed 
from GPU + result = unpermute_tokens( + expert_output, permuted_probs, source_token_indices, num_tokens, + atomic_counters, + ) + + # PyTorch reference: scatter-add expert_output[i] * prob[i] to source token + ref = torch.zeros(num_tokens, hidden_dim, dtype=torch.float32, device="cuda") + for i in range(total_routed): + src = source_token_indices[i].item() + ref[src] += expert_output[i] * permuted_probs[i] + + # Tolerance depends on dtype: bf16 atomic adds have ~0.01 precision + atol = 1e-2 if expert_output.dtype == torch.bfloat16 else 1e-5 + assert torch.allclose(result, ref, atol=atol), ( + f"Unpermute mismatch: max diff={torch.max(torch.abs(result - ref)).item()}" + ) + + @pytest.mark.internal + @pytest.mark.parametrize("topk", [1, 2, 4, 6, 8]) + @pytest.mark.parametrize("hidden_dim", [64, 100, 128, 256]) + def test_basic(self, topk, hidden_dim): + """Basic unpermute test, sweep topk and hidden_dim.""" + self._run_end_to_end( + num_tokens=64, topk=topk, hidden_dim=hidden_dim, + num_total_experts=32, num_local_experts=4, local_expert_start=8, + ) + + @pytest.mark.internal + @pytest.mark.parametrize("topk", [1, 2, 4, 6, 8]) + @pytest.mark.parametrize("hidden_dim", [64, 100, 128, 256]) + def test_single_token(self, topk, hidden_dim): + """Single token unpermute, sweep topk and hidden_dim.""" + self._run_end_to_end( + num_tokens=1, topk=topk, hidden_dim=hidden_dim, + num_total_experts=32, num_local_experts=8, local_expert_start=0, + ) + + @pytest.mark.internal + @pytest.mark.parametrize("topk", [1, 2, 4, 6, 8]) + @pytest.mark.parametrize("hidden_dim", [64, 100, 128, 256]) + @pytest.mark.parametrize("num_total_experts", [8, 32, 64, 128]) + def test_all_to_one_expert(self, topk, hidden_dim, num_total_experts): + """All tokens to one local expert — tests single-expert accumulation.""" + self._run_end_to_end( + num_tokens=32, topk=topk, hidden_dim=hidden_dim, + num_total_experts=num_total_experts, num_local_experts=1, local_expert_start=0, + seed=99, + ) + + 
@pytest.mark.internal + def test_no_local_hits(self): + """No local tokens — unpermute should return all zeros.""" + torch.manual_seed(0) + num_tokens, hidden_dim = 32, 64 + + # No permuted rows — atomic_counters with a single zero entry + expert_output = torch.empty(0, hidden_dim, dtype=torch.float32, device="cuda") + permuted_probs = torch.empty(0, dtype=torch.float32, device="cuda") + source_token_indices = torch.empty(0, dtype=torch.int32, device="cuda") + atomic_counters = torch.zeros(1, dtype=torch.int32, device="cuda") + + result = unpermute_tokens( + expert_output, permuted_probs, source_token_indices, num_tokens, atomic_counters, + ) + + expected = torch.zeros(num_tokens, hidden_dim, dtype=torch.float32, device="cuda") + assert torch.equal(result, expected) + + @pytest.mark.internal + @pytest.mark.parametrize("topk", [1, 2, 4, 6, 8]) + @pytest.mark.parametrize("hidden_dim", [64, 100, 128, 256]) + @pytest.mark.parametrize("num_total_experts", [8, 32, 64, 128]) + def test_sweep(self, topk, hidden_dim, num_total_experts): + """Sweep across topk, hidden_dim, num_total_experts.""" + self._run_end_to_end( + num_tokens=64, topk=topk, hidden_dim=hidden_dim, + num_total_experts=num_total_experts, num_local_experts=4, local_expert_start=8, + seed=topk * 1000 + hidden_dim, + )