diff --git a/gpt_builders.py b/gpt_builders.py index c8b4efa3075..37a23e3434b 100644 --- a/gpt_builders.py +++ b/gpt_builders.py @@ -53,7 +53,6 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_ ) ) elif args.num_experts: - assert not (config.transformer_impl == "inference_optimized") # Define the decoder block spec transformer_layer_spec = get_gpt_decoder_block_spec( config, diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py index 1303f61c9d2..a7b325ca2ba 100644 --- a/megatron/core/inference/batch_dimensions_utils.py +++ b/megatron/core/inference/batch_dimensions_utils.py @@ -133,7 +133,7 @@ def adjust_batch_dims_for_expert_parallelism( strict: bool, decode_only_cuda_graphs: bool, explicit_chunked_prefill: bool, - cuda_graph_mixed_prefill_count: int, + smallest_non_decode_cuda_graph_size: int, ep_group: Optional[torch.distributed.ProcessGroup] = None, ) -> Optional["InferenceBatchDimensions"]: """Adjusted cuda graph batch dimensions for expert parallelism. @@ -176,9 +176,9 @@ def adjust_batch_dims_for_expert_parallelism( is_any_ep_rank_in_non_decode = sync_tensor[1].item() == 1 # We force eager mode for scenarios where some ranks will run with CUDA graphs - # while others will not. Without this check, the all-to-all communication in the + # while others will not. Without this check, communication in the # expert routing layer would pad up to the maximum capacity only for the ranks that - # are using CUDA graphs in this step, leading to a NCCL hang. + # are using CUDA graphs in this step, leading to a hang. # This can happen in the following cases: # 1. If we only allow decode CUDA graphs but some ranks are running non-decode batches # 2. 
Some ranks are running explicit chunked prefill requests @@ -203,7 +203,7 @@ def adjust_batch_dims_for_expert_parallelism( # graph while prefill ranks match a coarser mixed graph, which would # produce inconsistent token counts across EP ranks. if is_any_ep_rank_in_non_decode and not strict: - adjusted_token_count = max(adjusted_token_count, cuda_graph_mixed_prefill_count) + adjusted_token_count = max(adjusted_token_count, smallest_non_decode_cuda_graph_size) adjusted_batch_dim = InferenceBatchDimensions( token_count=adjusted_token_count, @@ -303,7 +303,7 @@ def generate_cuda_graph_batch_dimensions_list( tp_size: int, num_cuda_graphs: Optional[int], cuda_graph_max_tokens: int, - cuda_graph_mixed_prefill_count: Optional[int], + cuda_graph_mixed_prefill_request_count: Optional[int], max_requests: int, max_tokens: int, max_sequence_length: int, @@ -339,7 +339,7 @@ def generate_cuda_graph_batch_dimensions_list( tp_size: Tensor parallel size num_cuda_graphs: Number of CUDA graphs to generate cuda_graph_max_tokens: Maximum tokens for CUDA graphs - cuda_graph_mixed_prefill_count: Number of mixed prefill requests for CUDA graphs + cuda_graph_mixed_prefill_request_count: Number of mixed prefill requests for CUDA graphs max_requests: Maximum number of requests max_tokens: Maximum total tokens max_sequence_length: Maximum sequence length @@ -409,8 +409,8 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int if num_cuda_graphs is None: cuda_graph_batch_dimensions_list = [] elif ( - not cuda_graph_mixed_prefill_count - or cuda_graph_mixed_prefill_count <= 0 + not cuda_graph_mixed_prefill_request_count + or cuda_graph_mixed_prefill_request_count <= 0 or not use_cuda_graphs_for_non_decode_steps ): # decode only # Use decode-specific token counts for decode-only graphs @@ -426,14 +426,14 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int for size in cuda_graph_prefill_token_counts: add_if_valid( token_count=size, - 
prefill_req_count=min(cuda_graph_mixed_prefill_count, max_requests), + prefill_req_count=min(cuda_graph_mixed_prefill_request_count, max_requests), decode_req_count=min(size, max_requests) - - min(cuda_graph_mixed_prefill_count, max_requests), + - min(cuda_graph_mixed_prefill_request_count, max_requests), ) # We need to ensure the prefill requests are shorter than the max sequence length, # considering the one decode token is used for prefill request construction prefill_only_minimal_num = max( - cuda_graph_mixed_prefill_count, + cuda_graph_mixed_prefill_request_count, math.ceil(size / max(1, max_sequence_length - 1)), ) if prefill_only_minimal_num < max_requests: @@ -474,7 +474,7 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int def match_graph_config( real_batch_dim: InferenceBatchDimensions, cuda_graph_batch_dimensions_list: List[InferenceBatchDimensions], - cuda_graph_mixed_prefill_count: int, + smallest_non_decode_cuda_graph_size: int, strict: bool = False, decode_only_cuda_graphs: bool = False, explicit_chunked_prefill: bool = False, @@ -509,7 +509,7 @@ def match_graph_config( decode_only_cuda_graphs=decode_only_cuda_graphs, explicit_chunked_prefill=explicit_chunked_prefill, ep_group=ep_group, - cuda_graph_mixed_prefill_count=cuda_graph_mixed_prefill_count, + smallest_non_decode_cuda_graph_size=smallest_non_decode_cuda_graph_size, ) if adjusted_batch_dim is None: diff --git a/megatron/core/inference/communication/torch_symm_triton/__init__.py b/megatron/core/inference/communication/torch_symm_triton/__init__.py index ca58663d9ec..967dc8329f1 100644 --- a/megatron/core/inference/communication/torch_symm_triton/__init__.py +++ b/megatron/core/inference/communication/torch_symm_triton/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-from .collectives import multimem_all_gather, multimem_reduce_scatter +from .collectives import multimem_all_gather, multimem_all_gather_fused, multimem_reduce_scatter from .fused_collectives import fused_multimem_rs_add_norm_ag +from .utils import are_tensors_nvls_eligible, is_device_nvls_capable diff --git a/megatron/core/inference/communication/torch_symm_triton/collectives.py b/megatron/core/inference/communication/torch_symm_triton/collectives.py index 4bc4dbde42b..cf2003c8595 100644 --- a/megatron/core/inference/communication/torch_symm_triton/collectives.py +++ b/megatron/core/inference/communication/torch_symm_triton/collectives.py @@ -23,43 +23,39 @@ from .barrier import symm_mem_sync from .multimem_asm import ld_128, st_128 -from .utils import get_flat_tid, sync_threads +from .utils import are_tensors_nvls_eligible, get_flat_tid, sync_threads @triton.jit -def _multimem_all_gather_kernel( - local_ptr, - multicast_ptr, - signal_pad_ptrs, - numel, - BLOCK_SIZE: tl.constexpr, - NUMEL_PER_THREAD: tl.constexpr, - RANK: tl.constexpr, - WORLD_SIZE: tl.constexpr, +def _ag_phase( + local_ptr, multicast_ptr, byte_offset, numel, BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE ): """ - Triton kernel to perform multicast all-gather over nvlink using multimem instructions. - """ - # an all-gather is simply a multicast store operation - # we only need a barrier at the end to ensure visibility of writes + Core all-gather phase: load from local memory, multicast-store to symmetric buffer. + This is the building block for both single-tensor and fused multi-tensor all-gathers. + Each thread handles 128-bit (NUMEL_PER_THREAD elements) at a time. + byte_offset locates the tensor within the multicast buffer. 
+ + """ pid = tl.program_id(axis=0) tid = get_flat_tid() - # From this point on, we pretend each element is 128-bit - numel = numel // NUMEL_PER_THREAD - numel_per_rank = tl.cdiv(numel, WORLD_SIZE) + numel_128 = numel // NUMEL_PER_THREAD + numel_per_rank = tl.cdiv(numel_128, WORLD_SIZE) block_start = pid * BLOCK_SIZE while block_start < numel_per_rank: offsets = block_start + tid mask = offsets < numel_per_rank - # Each pointer points to a 128-bit bit pack - # RANK * numel_per_rank -> brings us to the start of our rank's segment - # offsets -> brings us to the right offset within our rank's segment + # byte_offset // 8 -> converts byte offset to uint64 offset + # RANK * numel_per_rank -> start of our rank's segment + # * 2 -> each 128-bit pack is 2 uint64s multicast_ptrs = ( - multicast_ptr.to(tl.pointer_type(tl.uint64)) + (RANK * numel_per_rank + offsets) * 2 + multicast_ptr.to(tl.pointer_type(tl.uint64)) + + byte_offset // 8 + + (RANK * numel_per_rank + offsets) * 2 ) local_ptrs = local_ptr.to(tl.pointer_type(tl.uint64)) + offsets * 2 (x, y, z, w) = ld_128(local_ptrs, mask=mask, multicast_op=False) @@ -67,6 +63,23 @@ def _multimem_all_gather_kernel( block_start += tl.num_programs(axis=0) * BLOCK_SIZE + +@triton.jit +def _multimem_all_gather_kernel( + local_ptr, + multicast_ptr, + signal_pad_ptrs, + numel, + byte_offset, + BLOCK_SIZE: tl.constexpr, + NUMEL_PER_THREAD: tl.constexpr, + RANK: tl.constexpr, + WORLD_SIZE: tl.constexpr, +): + """Single-tensor multicast all-gather kernel.""" + _ag_phase( + local_ptr, multicast_ptr, byte_offset, numel, BLOCK_SIZE, NUMEL_PER_THREAD, RANK, WORLD_SIZE + ) sync_threads() symm_mem_sync( signal_pad_ptrs, @@ -78,54 +91,68 @@ def _multimem_all_gather_kernel( ) -def multimem_all_gather( - output_tensor: torch.Tensor, - input_tensor: torch.Tensor, - symm_mem_hdl: _SymmetricMemory, - **kwargs, -) -> torch.Tensor: +@triton.jit +def _multimem_all_gather_3_kernel( + local_ptr_0, + local_ptr_1, + local_ptr_2, + multicast_ptr, + 
signal_pad_ptrs, + numel_0, + byte_offset_0, + numel_1, + byte_offset_1, + numel_2, + byte_offset_2, + BLOCK_SIZE: tl.constexpr, + NUMEL_PER_THREAD: tl.constexpr, + RANK: tl.constexpr, + WORLD_SIZE: tl.constexpr, +): """ - Calls a multicast all-gather triton kernel on the given tensor. - Output tensor must be a symmetric memory buffer. - Input tensor can be a regular torch tensor - Arguments: - output_tensor: torch.Tensor - output tensor to be all-gathered into - input_tensor: torch.Tensor - input tensor to be all-gathered from - symm_mem_hdl: _SymmetricMemory - handle to the symmetric memory buffer for output_tensor - Returns: - torch.Tensor - all-gathered tensor, which is output_tensor + Fused 3-tensor multicast all-gather. Processes three tensors in sequence + then synchronizes once, eliminating 2 kernel launches and 2 barriers + compared to three separate multimem_all_gather calls. """ - assert HAVE_TRITON, "Triton is required for multimem all-gather." - - config = { - "max_num_blocks": kwargs.get("max_num_blocks", 24), - "num_warps": kwargs.get("num_warps", 32), - "BLOCK_SIZE": kwargs.get("BLOCK_SIZE", 1024), - } - assert input_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." - assert output_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." - numel_per_thread = 128 // (input_tensor.element_size() * 8) - - assert ( - output_tensor.numel() % numel_per_thread == 0 - ), "The number of elements must be 128-bit aligned." 
- - num_threads = triton.cdiv(output_tensor.numel() // numel_per_thread, symm_mem_hdl.world_size) - num_blocks = min(triton.cdiv(num_threads, config["BLOCK_SIZE"]), config["max_num_blocks"]) - - _multimem_all_gather_kernel[(num_blocks, 1, 1)]( - input_tensor.data_ptr(), - symm_mem_hdl.multicast_ptr, - symm_mem_hdl.signal_pad_ptrs_dev, - numel=output_tensor.numel(), - BLOCK_SIZE=config["BLOCK_SIZE"], - NUMEL_PER_THREAD=numel_per_thread, - RANK=symm_mem_hdl.rank, - WORLD_SIZE=symm_mem_hdl.world_size, - num_warps=config["num_warps"], + _ag_phase( + local_ptr_0, + multicast_ptr, + byte_offset_0, + numel_0, + BLOCK_SIZE, + NUMEL_PER_THREAD, + RANK, + WORLD_SIZE, + ) + _ag_phase( + local_ptr_1, + multicast_ptr, + byte_offset_1, + numel_1, + BLOCK_SIZE, + NUMEL_PER_THREAD, + RANK, + WORLD_SIZE, + ) + _ag_phase( + local_ptr_2, + multicast_ptr, + byte_offset_2, + numel_2, + BLOCK_SIZE, + NUMEL_PER_THREAD, + RANK, + WORLD_SIZE, + ) + sync_threads() + symm_mem_sync( + signal_pad_ptrs, + None, + RANK, + WORLD_SIZE, + hasPreviousMemAccess=True, + hasSubsequentMemAccess=True, ) - - return output_tensor @triton.jit @@ -175,47 +202,147 @@ def _multimem_reduce_scatter_kernel( block_start += tl.num_programs(axis=0) * BLOCK_SIZE -def multimem_reduce_scatter( +# ── Python wrappers ───────────────────────────────────────────────────────── + +_DEFAULT_KERNEL_CONFIG = {"max_num_blocks": 128, "num_warps": 32, "BLOCK_SIZE": 1024} + + +def _kernel_launch_config(element_size: int, max_numel: int, world_size: int, **kwargs): + """Compute kernel launch config shared by all collective wrappers. + + Args: + element_size: bytes per element (e.g. 2 for bf16). + max_numel: largest tensor numel (determines grid size). + world_size: number of ranks. + + Returns: + (numel_per_thread, num_blocks, config) tuple. 
+ """ + config = {k: kwargs.get(k, v) for k, v in _DEFAULT_KERNEL_CONFIG.items()} + numel_per_thread = 128 // (element_size * 8) + num_threads = triton.cdiv(max_numel // numel_per_thread, world_size) + num_blocks = min(triton.cdiv(num_threads, config["BLOCK_SIZE"]), config["max_num_blocks"]) + return numel_per_thread, num_blocks, config + + +def multimem_all_gather( output_tensor: torch.Tensor, input_tensor: torch.Tensor, symm_mem_hdl: _SymmetricMemory, + byte_offset: int = 0, **kwargs, ) -> torch.Tensor: """ - Calls a multicast reduce-scatter triton kernel on the given tensor. - Input tensor must be a symmetric memory buffer. - Output tensor can be a regular torch tensor - Arguments: - output_tensor: torch.Tensor - output tensor to be reduce-scattered into - input_tensor: torch.Tensor - input tensor to be reduce-scattered from - symm_mem_hdl: _SymmetricMemory - handle to the symmetric memory buffer for input_tensor - **kwargs: Additional keyword arguments for kernel configuration: - max_num_blocks (int, optional): The maximum number of blocks to launch. - num_warps (int, optional): The number of warps per block. - BLOCK_SIZE (int, optional): The BLOCK_SIZE parameter for the kernel. - Returns: - torch.Tensor - reduce-scattered tensor, which is output_tensor + Multicast all-gather for a single tensor. + Output tensor must be a symmetric memory buffer. + Input tensor can be a regular torch tensor. """ + assert HAVE_TRITON, "Triton is required for multimem all-gather." + assert are_tensors_nvls_eligible( + input_tensor + ), "Input tensor must be 16-byte divisible on Hopper+ for NVLS." + assert ( + output_tensor.numel() % input_tensor.numel() == 0 + and output_tensor.numel() // input_tensor.numel() == symm_mem_hdl.world_size + ), "Output numel must be exactly world_size * input numel for all-gather." - assert HAVE_TRITON, "Triton is required for multimem reduce-scatter." 
+ numel_per_thread, num_blocks, config = _kernel_launch_config( + input_tensor.element_size(), output_tensor.numel(), symm_mem_hdl.world_size, **kwargs + ) + _multimem_all_gather_kernel[(num_blocks, 1, 1)]( + input_tensor.data_ptr(), + symm_mem_hdl.multicast_ptr, + symm_mem_hdl.signal_pad_ptrs_dev, + numel=output_tensor.numel(), + byte_offset=byte_offset, + BLOCK_SIZE=config["BLOCK_SIZE"], + NUMEL_PER_THREAD=numel_per_thread, + RANK=symm_mem_hdl.rank, + WORLD_SIZE=symm_mem_hdl.world_size, + num_warps=config["num_warps"], + ) + + return output_tensor + + +def multimem_all_gather_fused( + output_0: torch.Tensor, + input_0: torch.Tensor, + byte_offset_0: int, + output_1: torch.Tensor, + input_1: torch.Tensor, + byte_offset_1: int, + output_2: torch.Tensor, + input_2: torch.Tensor, + byte_offset_2: int, + symm_mem_hdl: _SymmetricMemory, + **kwargs, +) -> None: + """ + Fused 3-tensor multicast all-gather. Equivalent to calling multimem_all_gather + three times but with a single kernel launch and a single barrier. + + All tensors must share the same symmetric memory handle. + """ + assert HAVE_TRITON, "Triton is required for multimem all-gather." + assert are_tensors_nvls_eligible( + input_0, input_1, input_2 + ), "All input tensors must be 16-byte divisible on Hopper+ for NVLS." + for inp, out in [(input_0, output_0), (input_1, output_1), (input_2, output_2)]: + assert ( + out.numel() % inp.numel() == 0 and out.numel() // inp.numel() == symm_mem_hdl.world_size + ), "Output numel must be exactly world_size * input numel for all-gather." 
+ + max_numel = max(output_0.numel(), output_1.numel(), output_2.numel()) + + numel_per_thread, num_blocks, config = _kernel_launch_config( + input_0.element_size(), max_numel, symm_mem_hdl.world_size, **kwargs + ) + _multimem_all_gather_3_kernel[(num_blocks, 1, 1)]( + input_0.data_ptr(), + input_1.data_ptr(), + input_2.data_ptr(), + symm_mem_hdl.multicast_ptr, + symm_mem_hdl.signal_pad_ptrs_dev, + numel_0=output_0.numel(), + byte_offset_0=byte_offset_0, + numel_1=output_1.numel(), + byte_offset_1=byte_offset_1, + numel_2=output_2.numel(), + byte_offset_2=byte_offset_2, + BLOCK_SIZE=config["BLOCK_SIZE"], + NUMEL_PER_THREAD=numel_per_thread, + RANK=symm_mem_hdl.rank, + WORLD_SIZE=symm_mem_hdl.world_size, + num_warps=config["num_warps"], + ) - config = { - "max_num_blocks": kwargs.get("max_num_blocks", 24), - "num_warps": kwargs.get("num_warps", 32), - "BLOCK_SIZE": kwargs.get("BLOCK_SIZE", 1024), - } +def multimem_reduce_scatter( + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + symm_mem_hdl: _SymmetricMemory, + **kwargs, +) -> torch.Tensor: + """ + Multicast reduce-scatter for a single tensor. + Input tensor must be a symmetric memory buffer. + Output tensor can be a regular torch tensor. + """ + assert HAVE_TRITON, "Triton is required for multimem reduce-scatter." assert input_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." assert output_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." - numel_per_thread = 128 // (output_tensor.element_size() * 8) - + assert are_tensors_nvls_eligible( + output_tensor + ), "Output tensor must be 16-byte divisible on Hopper+ for NVLS." assert ( - input_tensor.numel() % numel_per_thread == 0 - ), "The number of elements must be 128-bit aligned." 
- - num_threads = triton.cdiv(input_tensor.numel() // numel_per_thread, symm_mem_hdl.world_size) - num_blocks = min(triton.cdiv(num_threads, config["BLOCK_SIZE"]), config["max_num_blocks"]) + input_tensor.numel() % output_tensor.numel() == 0 + and input_tensor.numel() // output_tensor.numel() == symm_mem_hdl.world_size + ), "Input numel must be exactly world_size * output numel for reduce-scatter." + numel_per_thread, num_blocks, config = _kernel_launch_config( + output_tensor.element_size(), input_tensor.numel(), symm_mem_hdl.world_size, **kwargs + ) _multimem_reduce_scatter_kernel[(num_blocks, 1, 1)]( output_tensor.data_ptr(), symm_mem_hdl.multicast_ptr, diff --git a/megatron/core/inference/communication/torch_symm_triton/utils.py b/megatron/core/inference/communication/torch_symm_triton/utils.py index 785481dfba6..3cc6dd8dcc0 100644 --- a/megatron/core/inference/communication/torch_symm_triton/utils.py +++ b/megatron/core/inference/communication/torch_symm_triton/utils.py @@ -4,6 +4,8 @@ from unittest.mock import MagicMock +import torch + from megatron.core.utils import null_decorator try: @@ -15,6 +17,27 @@ triton.jit = null_decorator +def is_device_nvls_capable(device: torch.device) -> bool: + """Check if the device supports NVLS (multicast) collectives. + Requires CUDA Hopper+ (SM >= 9).""" + return device.type == "cuda" and torch.cuda.get_device_properties(device).major >= 9 + + +def are_tensors_nvls_eligible(*tensors: torch.Tensor) -> bool: + """Check if tensors are eligible for NVLS (multicast) collectives. + + Requirements: + - Hopper+ GPU (SM >= 9) + - All tensor byte sizes are divisible by 16 (128-bit), since NVLS + kernels process data in 128-bit chunks. 
+ """ + if not tensors: + return False + return is_device_nvls_capable(tensors[0].device) and all( + t.element_size() * t.numel() % 16 == 0 for t in tensors + ) + + @triton.jit def get_tid(): """ diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index b2e9955bf66..a0f39b13075 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -538,7 +538,7 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC tp_size=tp_size, num_cuda_graphs=inference_config.num_cuda_graphs, cuda_graph_max_tokens=self.max_requests, - cuda_graph_mixed_prefill_count=inference_config.cuda_graph_mixed_prefill_count, + cuda_graph_mixed_prefill_request_count=inference_config.cuda_graph_mixed_prefill_count, max_requests=self.max_requests, max_tokens=self.max_tokens, max_sequence_length=self.max_sequence_length, @@ -546,7 +546,10 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC ) ) - self.cuda_graph_mixed_prefill_count = inference_config.cuda_graph_mixed_prefill_count + self.smallest_non_decode_cuda_graph_size = min( + inference_config.cuda_graph_mixed_prefill_count, self.max_requests + ) + self._using_cuda_graph_this_step = False # Deal with chunked prefill self.enable_chunked_prefill = inference_config.enable_chunked_prefill @@ -1343,14 +1346,17 @@ def initialize_attention_state( best_graph = CUDAGraphBatchDimensionBuilder.match_graph_config( batch_dimensions, self.cuda_graph_batch_dimensions_list, + smallest_non_decode_cuda_graph_size=self.smallest_non_decode_cuda_graph_size, strict=self.is_hybrid_model, decode_only_cuda_graphs=(not self.use_cuda_graphs_for_non_decode_steps), explicit_chunked_prefill=self.is_chunked_prefill_enabled() and self.is_hybrid_model, ep_group=self.expert_model_parallel_group, - cuda_graph_mixed_prefill_count=self.cuda_graph_mixed_prefill_count, ) 
self._using_cuda_graph_this_step = best_graph is not None + if construct_graph_dimensions is not None: + assert self._using_cuda_graph_this_step + if is_expert_parallel_dummy_cuda_graph_step and not self.using_cuda_graph_this_step(): # If we are here, this means that CUDAGraphBatchDimensionBuilder.match_graph_config # could not find a compatible cuda graph for the dummy forward step. diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index a10b7e3d60f..d06ab5c53d2 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -41,7 +41,12 @@ from megatron.core.inference.text_generation_controllers.text_generation_controller import ( TextGenerationController, ) -from megatron.core.inference.utils import Counter, await_process_call +from megatron.core.inference.utils import ( + Counter, + await_process_call, + set_inference_cuda_graphed_iteration_for_ep_inference, + unset_inference_cuda_graphed_iteration_for_ep_inference, +) from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.cuda_graphs import delete_cuda_graphs from megatron.core.transformer.enums import CudaGraphScope @@ -291,6 +296,16 @@ def create_cuda_graphs(self, reset_context: bool = True): for graph in context.cuda_graph_batch_dimensions_list: logging.info(graph) + # Enable inference dispatcher for EP during graph capture + model_config = controller.inference_wrapped_model.model.config + is_inference_optimized_ep = ( + model_config.transformer_impl == "inference_optimized" + and model_config.expert_model_parallel_size > 1 + ) + if is_inference_optimized_ep: + unwrapped_model = controller.inference_wrapped_model.model + set_inference_cuda_graphed_iteration_for_ep_inference(unwrapped_model) + tbar = enumerate(context.cuda_graph_batch_dimensions_list) if HAVE_TQDM: tbar = tqdm(tbar, total=len(context.cuda_graph_batch_dimensions_list)) @@ 
-318,6 +333,10 @@ def create_cuda_graphs(self, reset_context: bool = True): context.reset() + # Disable inference dispatcher after graph capture + if is_inference_optimized_ep: + unset_inference_cuda_graphed_iteration_for_ep_inference(unwrapped_model) + # Memory usage. time_end = time.time() mem_stats_end = torch.cuda.memory_stats() @@ -1391,7 +1410,7 @@ async def async_bookkeep( mem = torch.cuda.memory_stats() step_type = "decode" if context_state["is_decode_only"] else "non-decode" output_str = ( - "* rank %d | step %d | %s ... time: %.3f%s ... " + "* rank %d | step %d | %s ... time: %.3f ms%s ... " "reqs: a %d/%d, p %d, w %d, f %d, e %d ... " "blocks: a %d/%d, p %d/%d ... " "mem: tensors %d, alloc %.1f gb, res %.1f gb." @@ -1399,7 +1418,7 @@ async def async_bookkeep( self.rank, step_count, datetime.now().strftime("%H:%M:%S"), - step_time, + step_time * 1000, ( " [%s + real config %s + cuda graph %s]" % ( @@ -1707,7 +1726,7 @@ async def _ep_group_has_work(self, local_work: int) -> bool: # Note that it is important to use a non-blocking asyncio-friendly all-reduce here. # The user may have other tasks running in the event loop that need to be serviced. # Do not using a torch.distributed blocking all-reduce here using nccl/gloo. - # We have tried that and it blocks the event loop is megatron-rl. + # We have tried that and it blocks the event loop in megatron-rl. 
max_global_work = await self.expert_parallel_zmq_communicator.all_reduce_max(local_work) else: max_global_work = local_work diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index e0ab771d01a..7e50f58e3e6 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -530,6 +530,12 @@ def _dynamic_step_context_init( moe_pad_experts_for_cuda_graph_inference = ( self.model_config.moe_pad_experts_for_cuda_graph_inference ) + is_inference_optimized = self.model_config.transformer_impl == "inference_optimized" + if is_inference_optimized: + assert not moe_pad_experts_for_cuda_graph_inference, ( + "moe_pad_experts_for_cuda_graph_inference cannot be True when " + "transformer_impl is 'inference_optimized'" + ) if moe_pad_experts_for_cuda_graph_inference: if context.using_cuda_graph_this_step(): capacity_factor = model_config.num_moe_experts / model_config.moe_router_topk @@ -842,6 +848,11 @@ def dummy_forward(self): context = self.inference_wrapped_model.inference_context # if no cuda graphs, directly use dummy forward if not context.cuda_graph_batch_dimensions_list: + # initialize symmetric memory if needed + unwrapped_model = unwrap_model(self.inference_wrapped_model.model) + model_config = get_model_config(unwrapped_model) + if model_config.transformer_impl == "inference_optimized": + context.maybe_initialize_symmetric_memory() return self.inference_wrapped_model.dummy_forward() # attempt to use cuda-graph if possible diff --git a/megatron/core/inference/utils.py b/megatron/core/inference/utils.py index 0bdaff64be1..770a592e9c8 100644 --- a/megatron/core/inference/utils.py +++ b/megatron/core/inference/utils.py @@ -132,6 +132,34 @@ def set_decode_expert_padding(model, set_to: bool = False, capacity_factor: int 
router.config.moe_pad_expert_input_to_capacity = bool(set_to) +def set_inference_cuda_graphed_iteration_for_ep_inference(model): + """Enable CUDA graph compatibility for expert parallel inference. + + Sets a flag in all MoELayers indicating the current iteration is being + captured/executed in a CUDA graph. This allows the dispatcher to adjust + its behavior for CUDA graph compatibility. + """ + global moe_layer_cache + if moe_layer_cache is None: + _init_moe_expert_cache(model) + + for moe_layer in moe_layer_cache: + moe_layer.set_inference_cuda_graphed_iteration() + + +def unset_inference_cuda_graphed_iteration_for_ep_inference(model): + """Disable CUDA graph compatibility for expert parallel inference. + + Clears the flag in all MoELayers, restoring standard dispatcher behavior. + """ + global moe_layer_cache + if moe_layer_cache is None: + _init_moe_expert_cache(model) + + for moe_layer in moe_layer_cache: + moe_layer.unset_inference_cuda_graphed_iteration() + + def tensor_swap(x, src_idxs, dst_idxs): """ Swap x[src_idxs] and x[dst_idxs] diff --git a/megatron/core/models/backends.py b/megatron/core/models/backends.py index a4fc3165ba5..ebb979772f0 100644 --- a/megatron/core/models/backends.py +++ b/megatron/core/models/backends.py @@ -3,14 +3,24 @@ import warnings from abc import abstractmethod -from typing import Optional, Protocol, cast +from typing import Optional, Protocol, Tuple, cast +from megatron.core.extensions.transformer_engine import ( + TEColumnParallelGroupedLinear, + TERowParallelGroupedLinear, +) from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear from megatron.core.transformer.dot_product_attention import DotProductAttention from megatron.core.transformer.mlp import MLPSubmodules, TEActivationFunctionBuilder -from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLPSubmodules +from megatron.core.transformer.moe.experts import ( + GroupedMLP, + InferenceGroupedMLP, + 
SequentialMLP, + TEGroupedMLPSubmodules, +) from megatron.core.transformer.torch_norm import LayerNormBuilder, WrappedTorchNorm from megatron.core.typed_torch import not_none +from megatron.core.utils import is_te_min_version try: import apex # pylint: disable=unused-import @@ -181,7 +191,8 @@ def activation_func(self) -> TEActivationFunctionBuilder | None: def grouped_mlp_modules( self, moe_use_grouped_gemm: bool, moe_use_legacy_grouped_gemm: bool - ) -> tuple[type, MLPSubmodules | TEGroupedMLPSubmodules | None]: - raise NotImplementedError( - "MOE is not supported with inference optimized transformer implementation." + ) -> Tuple[type, Optional[MLPSubmodules]]: + """Which module and submodules to use for grouped mlp""" + return InferenceGroupedMLP, MLPSubmodules( + linear_fc1=TEColumnParallelGroupedLinear, linear_fc2=TERowParallelGroupedLinear ) diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index 87b97eeac96..a9a40660f98 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -8,7 +8,10 @@ InferenceSpecProvider, LocalSpecProvider, ) -from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec_for_backend +from megatron.core.models.gpt.moe_module_specs import ( + get_inference_optimized_moe_spec, + get_moe_module_spec_for_backend, +) from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from megatron.core.transformer.enums import AttnMaskType, LayerType from megatron.core.transformer.identity_op import IdentityOp @@ -76,6 +79,9 @@ def get_gpt_layer_with_inference_submodules( qk_layernorm: Optional[bool] = False, multi_latent_attention: Optional[bool] = False, qk_l2_norm: Optional[bool] = False, + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + moe_use_legacy_grouped_gemm: Optional[bool] = False, ) -> TransformerLayerSubmodules: """Use these submodules for inference optimized 
linear layers. Args: @@ -86,14 +92,17 @@ def get_gpt_layer_with_inference_submodules( assert HAVE_TE, "--transformer-impl inference_optimized requires transformer engine" backend = InferenceSpecProvider() - mlp = get_mlp_module_spec_for_backend( - backend=backend, - num_experts=None, - moe_grouped_gemm=False, - moe_use_legacy_grouped_gemm=False, - use_te_op_fuser=False, - use_te_activation_func=False, - ) + if num_experts is not None: + mlp = get_inference_optimized_moe_spec() + else: + mlp = get_mlp_module_spec_for_backend( + backend=backend, + num_experts=None, + moe_grouped_gemm=moe_grouped_gemm, + moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, + use_te_op_fuser=False, + use_te_activation_func=False, + ) if multi_latent_attention: assert qk_l2_norm is False, "qk_l2_norm is not supported with MLA." @@ -148,7 +157,7 @@ def get_gpt_layer_with_inference_submodules( ), ), self_attn_bda=get_bias_dropout_add, - pre_mlp_layernorm=IdentityOp, + pre_mlp_layernorm=backend.layer_norm() if num_experts else IdentityOp, mlp=mlp, mlp_bda=get_bias_dropout_add, sharded_state_dict_keys_map={ @@ -557,6 +566,21 @@ def get_gpt_decoder_layer_specs( use_kitchen_attention=config.use_kitchen_attention, kitchen_attention_backend=config.kitchen_attention_backend, ) + elif config.transformer_impl == "inference_optimized": + layer_norm_impl = TENorm + dense_layer_spec = get_gpt_layer_with_inference_spec( + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + qk_l2_norm=qk_l2_norm, + ) + moe_layer_spec = get_gpt_layer_with_inference_spec( + qk_layernorm=config.qk_layernorm, + multi_latent_attention=config.multi_latent_attention, + qk_l2_norm=qk_l2_norm, + num_experts=config.num_moe_experts, + moe_grouped_gemm=config.moe_grouped_gemm, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, + ) else: layer_norm_impl = LNImpl dense_layer_spec = get_gpt_layer_local_spec( @@ -648,6 +672,8 @@ def get_gpt_decoder_block_spec( if 
use_transformer_engine: layer_norm_impl = TENorm + elif config.transformer_impl == "inference_optimized": + layer_norm_impl = TENorm else: layer_norm_impl = LNImpl # Block spec. diff --git a/megatron/core/models/gpt/moe_module_specs.py b/megatron/core/models/gpt/moe_module_specs.py index 62ee4537cfc..4b0d5640b46 100755 --- a/megatron/core/models/gpt/moe_module_specs.py +++ b/megatron/core/models/gpt/moe_module_specs.py @@ -3,9 +3,14 @@ from typing import Optional from megatron.core.extensions.transformer_engine_spec_provider import TESpecProvider -from megatron.core.models.backends import BackendSpecProvider, LocalSpecProvider +from megatron.core.models.backends import ( + BackendSpecProvider, + InferenceSpecProvider, + LocalSpecProvider, +) from megatron.core.transformer.mlp import MLPSubmodules from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules +from megatron.core.transformer.moe.router import InferenceTopKRouter from megatron.core.transformer.moe.shared_experts import SharedExpertMLP from megatron.core.transformer.spec_utils import ModuleSpec @@ -16,7 +21,17 @@ def get_moe_module_spec( moe_grouped_gemm: Optional[bool] = False, moe_use_legacy_grouped_gemm: Optional[bool] = False, ) -> ModuleSpec: - """Helper function to get module spec for MoE""" + """Helper function to get module spec for MoE. + + Called by mamba_layer_specs.py for standard (non-inference) MoE specs. + The GPT layer specs call get_moe_module_spec_for_backend directly. + + Args: + use_te: Whether to use Transformer Engine. + num_experts: Number of experts. + moe_grouped_gemm: Whether to use grouped GEMM. + moe_use_legacy_grouped_gemm: Whether to use legacy grouped GEMM. 
+ """ if use_te is not None and use_te: backend: BackendSpecProvider = TESpecProvider() else: @@ -66,3 +81,38 @@ def get_moe_module_spec_for_backend( metainfo={"fuse_pre_mlp_layernorm": False}, ) return moe_module_spec + + +def get_inference_optimized_moe_spec() -> ModuleSpec: + """MoE module spec for inference-optimized transformer impl. + + Uses InferenceSpecProvider to select inference-optimized modules: + InferenceTopKRouter, InferenceGroupedMLP. MoELayer detects inference mode + via config.transformer_impl and sets up the inference dispatcher internally. + + Called by mamba_layer_specs.py and gpt_layer_specs.py. + """ + backend = InferenceSpecProvider() + activation_func = backend.activation_func() + + expert_module, expert_submodule = backend.grouped_mlp_modules(True, False) + if expert_submodule is not None: + expert_submodule.activation_func = activation_func + + experts = ModuleSpec(module=expert_module, submodules=expert_submodule) + shared_experts = ModuleSpec( + module=SharedExpertMLP, + submodules=MLPSubmodules( + linear_fc1=backend.column_parallel_linear(), + linear_fc2=backend.row_parallel_linear(), + activation_func=activation_func, + ), + ) + + return ModuleSpec( + module=MoELayer, + submodules=MoESubmodules( + router=InferenceTopKRouter, experts=experts, shared_experts=shared_experts + ), + metainfo={"fuse_pre_mlp_layernorm": False}, + ) diff --git a/megatron/core/models/mamba/mamba_layer_specs.py b/megatron/core/models/mamba/mamba_layer_specs.py index 6ca628475be..957f20847fc 100755 --- a/megatron/core/models/mamba/mamba_layer_specs.py +++ b/megatron/core/models/mamba/mamba_layer_specs.py @@ -8,7 +8,10 @@ TERowParallelLinear, ) from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add -from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec +from megatron.core.models.gpt.moe_module_specs import ( + get_inference_optimized_moe_spec, + get_moe_module_spec, +) from megatron.core.ssm.mamba_block import MambaStack, 
MambaStackSubmodules from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules @@ -41,6 +44,9 @@ moe_use_legacy_grouped_gemm=False, ) +# Inference-optimized MoE spec +moe_inference = get_inference_optimized_moe_spec() + # MTP block spec for Mamba - provides norms and projection only. # Inner layers are built by MultiTokenPredictionLayer using nested MambaStack @@ -173,10 +179,10 @@ ), ), moe_layer=ModuleSpec( - # TODO (rwaleffe): change this to be an "MoELayer" to work with CudaGraphs? + # Use inference-optimized MoE layer for end-to-end CUDA graph support module=TransformerLayer, submodules=TransformerLayerSubmodules( - pre_mlp_layernorm=TENorm, mlp=moe, mlp_bda=get_bias_dropout_add + pre_mlp_layernorm=TENorm, mlp=moe_inference, mlp_bda=get_bias_dropout_add ), ), mtp_block_spec=_mamba_mtp_block_spec, diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py index 5bf413dce3f..65003171480 100644 --- a/megatron/core/parallel_state.py +++ b/megatron/core/parallel_state.py @@ -138,8 +138,9 @@ # Memory buffers to avoid dynamic memory allocation _GLOBAL_MEMORY_BUFFER = None -# Global symmetric memory buffer for inference -_GLOBAL_SYMMETRIC_MEMORY_BUFFER = None +# Global symmetric memory buffers for inference +_GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = None +_GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = None # List of all process groups # Used for updating the timeout for all process groups @@ -2014,14 +2015,24 @@ def _set_global_memory_buffer(): def _set_global_symmetric_memory_buffer(): """Initialize global buffer.""" - global _GLOBAL_SYMMETRIC_MEMORY_BUFFER - assert _GLOBAL_SYMMETRIC_MEMORY_BUFFER is None, "global memory buffer is already initialized" + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP, _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP + assert ( + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP is None + ), "global symmetric memory buffer for TP is already initialized" + assert ( + 
_GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP is None + ), "global symmetric memory buffer for EP is already initialized" - _GLOBAL_SYMMETRIC_MEMORY_BUFFER = GlobalSymmetricMemoryBuffer( + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = GlobalSymmetricMemoryBuffer( size_in_mb=256, # todo: set from an argument? process_group=get_tensor_model_parallel_group(), ) + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = GlobalSymmetricMemoryBuffer( + size_in_mb=256, # todo: set from an argument? + process_group=get_expert_model_parallel_group(), + ) + def get_global_memory_buffer(): """Return the global GlobalMemoryBuffer object""" @@ -2029,12 +2040,20 @@ def get_global_memory_buffer(): return _GLOBAL_MEMORY_BUFFER -def get_global_symmetric_memory_buffer(): +def get_global_symmetric_memory_buffer_tp(): """Return the global GlobalSymmetricMemoryBuffer object""" assert ( - _GLOBAL_SYMMETRIC_MEMORY_BUFFER is not None + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP is not None ), "global symmetric memory buffer is not initialized" - return _GLOBAL_SYMMETRIC_MEMORY_BUFFER + return _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP + + +def get_global_symmetric_memory_buffer_ep(): + """Return the global GlobalSymmetricMemoryBuffer object""" + assert ( + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP is not None + ), "global symmetric memory buffer is not initialized" + return _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP def destroy_global_memory_buffer(): @@ -2045,8 +2064,9 @@ def destroy_global_memory_buffer(): def destroy_global_symmetric_memory_buffer(): """Sets the global symmetric memory buffer to None""" - global _GLOBAL_SYMMETRIC_MEMORY_BUFFER - _GLOBAL_SYMMETRIC_MEMORY_BUFFER = None + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP, _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = None + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = None def get_all_ranks(): @@ -2127,8 +2147,11 @@ def destroy_model_parallel(): global _GLOBAL_MEMORY_BUFFER _GLOBAL_MEMORY_BUFFER = None - global _GLOBAL_SYMMETRIC_MEMORY_BUFFER - _GLOBAL_SYMMETRIC_MEMORY_BUFFER = 
None + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = None + + global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP + _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = None global _DATA_PARALLEL_GROUP_GLOO if ( diff --git a/megatron/core/tensor_parallel/inference_layers.py b/megatron/core/tensor_parallel/inference_layers.py index 0addc64a65f..cb367e43bd6 100644 --- a/megatron/core/tensor_parallel/inference_layers.py +++ b/megatron/core/tensor_parallel/inference_layers.py @@ -9,6 +9,7 @@ TERowParallelLinear, ) from megatron.core.inference.communication.torch_symm_triton import ( + are_tensors_nvls_eligible, fused_multimem_rs_add_norm_ag, multimem_all_gather, multimem_reduce_scatter, @@ -16,7 +17,7 @@ from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor from megatron.core.inference.quantization.utils import mm_mxfp8 from megatron.core.model_parallel_config import ModelParallelConfig -from megatron.core.parallel_state import get_global_symmetric_memory_buffer +from megatron.core.parallel_state import get_global_symmetric_memory_buffer_tp from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import get_tensor_model_parallel_group_if_none @@ -108,6 +109,8 @@ def __init__( config.sequence_parallel ), "--transformer-impl=inference_optimized requires --sequence-parallel" + self.triton_nvls_kernels_allowed = not config.inference_disable_triton_nvls_kernels + # Boolean to be toggled externally for skipping norm and all-gather. # This is used when enabling fused reduce-scatter + add + rms-norm + all-gather # in tensor parallelism. 
In this case, the preceding RowParallelLinear layer @@ -120,7 +123,7 @@ def _maybe_allocate_symmetric_buffer(self, x: torch.Tensor): """ symm_mem_buffer_dims = list(x.size()) symm_mem_buffer_dims[0] *= self.tp_size - symm_mem_buffer = get_global_symmetric_memory_buffer().maybe_get_tensor( + symm_mem_buffer = get_global_symmetric_memory_buffer_tp().maybe_get_tensor( symm_mem_buffer_dims, dtype=x.dtype ) return symm_mem_buffer @@ -133,16 +136,14 @@ def _all_gather(self, x: torch.Tensor, symm_mem_buffer: dict) -> None: if self.tp_size == 1: return x - # 1. check if bf16 - is_bf16 = x.dtype == torch.bfloat16 - # 2. check if hopper or newer - is_hopper_or_newer = torch.cuda.get_device_properties(x.device).major >= 9 - # 3. check if symmetric memory buffer is available - has_enough_symmetric_memory = symm_mem_buffer["handle"] is not None - can_use_custom_nvls_collectives = ( - is_bf16 and is_hopper_or_newer and has_enough_symmetric_memory + # Check input only: if input is 16-byte divisible, the output + # (world_size * input) is too. + can_use_nvls = ( + self.triton_nvls_kernels_allowed + and are_tensors_nvls_eligible(x) + and symm_mem_buffer["handle"] is not None ) - if can_use_custom_nvls_collectives: + if can_use_nvls: # do multimem all gather multimem_all_gather(symm_mem_buffer["tensor"], x, symm_mem_buffer["handle"]) return symm_mem_buffer["tensor"] @@ -221,6 +222,10 @@ def __init__( config.sequence_parallel ), "--transformer-impl=inference_optimized requires --sequence-parallel" + self.triton_nvls_kernels_allowed = not getattr( + config, 'inference_disable_triton_nvls_kernels', False + ) + # Placeholder for next layer norm weights for fused # reduce-scatter + add + rms-norm + all-gather self.next_layer_norm_weights = None @@ -233,27 +238,27 @@ def _matmul_reduce_scatter(self, x, residual=None): and perform an NVLS multicast reduce-scatter. If that is not possible, it will revert to torch.dist (NCCL) reduce-scatter. """ - # 1. 
check if bf16 - is_bf16 = x.dtype == torch.bfloat16 - # 2. check if mxfp8 use_mxfp8 = self.config.fp8_recipe == "mxfp8" - # 3. check if hopper or newer - is_hopper_or_newer = torch.cuda.get_device_properties(x.device).major >= 9 - # 4. attempt to ask for symmetric memory symm_mem_buffer_dims = list(x.size()) if use_mxfp8: # Remove batch dimension for FlashInfer mxfp8 del symm_mem_buffer_dims[1] symm_mem_buffer_dims[-1] = self.weight.size(0) - symm_mem_buffer = get_global_symmetric_memory_buffer().maybe_get_tensor( + symm_mem_buffer = get_global_symmetric_memory_buffer_tp().maybe_get_tensor( symm_mem_buffer_dims, dtype=x.dtype ) - has_enough_symmetric_memory = symm_mem_buffer["handle"] is not None - can_use_custom_nvls_collectives = ( - is_bf16 and is_hopper_or_newer and has_enough_symmetric_memory + + # RS requires bf16 (hardware multimem reduce is bf16-only). + # Check the matmul output shape: if it is NVLS-eligible, the RS output + # (world_size times smaller on dim 0) is too. + can_use_nvls = ( + self.triton_nvls_kernels_allowed + and x.dtype == torch.bfloat16 + and are_tensors_nvls_eligible(x) + and symm_mem_buffer["handle"] is not None ) - if can_use_custom_nvls_collectives: + if can_use_nvls: # Write output of matmul directly onto the symmetric memory buffer x = _apply_linear(x, self.weight, self.config, out=symm_mem_buffer["tensor"]) diff --git a/megatron/core/transformer/enums.py b/megatron/core/transformer/enums.py index d57e24887ab..ead3aef6cec 100644 --- a/megatron/core/transformer/enums.py +++ b/megatron/core/transformer/enums.py @@ -67,6 +67,14 @@ class AttnBackend(enum.Enum): auto = 5 +class MoEGroupedGemmBackend(enum.Enum): + """Backend for MoE grouped GEMM operations.""" + + te = 1 # Transformer Engine GroupedGEMM + torch = 2 # torch._grouped_mm + flashinfer = 3 # FlashInfer fused cutlass_fused_moe kernel + + class CudaGraphScope(enum.Enum): """Cuda Graph Scope - defines which parts of the model to capture.""" diff --git 
a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 3671404fcfe..8eda2d3d5a4 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -42,6 +42,7 @@ TEActivationFunctionBuilder, apply_swiglu_sharded_factory, ) +from megatron.core.transformer.enums import MoEGroupedGemmBackend from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.moe import grouped_gemm_util as gg from megatron.core.transformer.moe.moe_utils import ( @@ -55,6 +56,7 @@ sharded_state_dict_default, ) from megatron.core.typed_torch import apply_module, not_none +from megatron.core.utils import is_torch_min_version try: import transformer_engine as te # pylint: disable=unused-import @@ -67,6 +69,14 @@ HAVE_TE = False +try: + import flashinfer.fused_moe as fused_moe + from flashinfer.fused_moe.core import ActivationType + + HAVE_FLASHINFER = True +except ImportError: + HAVE_FLASHINFER = False + logger = logging.getLogger(__name__) @@ -720,6 +730,64 @@ def _apply_bias(intermediate_parallel, bias_parallel, tokens_per_expert, permute .to(intermediate_parallel.dtype) ) + def bias_act_func(self, intermediate_parallel, bias_parallel, permuted_probs): + """ + Applies bias and activation function to the output of linear_fc1. 
+ """ + if self.config.use_te_activation_func: + if bias_parallel is not None: + intermediate_parallel = intermediate_parallel + bias_parallel + intermediate_parallel = self.activation_func(intermediate_parallel) + if permuted_probs is not None: + original_dtype = intermediate_parallel.dtype + intermediate_parallel = intermediate_parallel * permuted_probs + intermediate_parallel = intermediate_parallel.to(original_dtype) + elif self.config.bias_activation_fusion and permuted_probs is not None: + if self.activation_func == F.silu and self.config.gated_linear_unit: + # dtype is handled inside the fused kernel + intermediate_parallel = weighted_bias_swiglu_impl( + intermediate_parallel, + bias_parallel, + permuted_probs, + self.config.activation_func_fp8_input_store, + ) + elif self.activation_func == quick_gelu and self.config.gated_linear_unit: + intermediate_parallel = weighted_bias_quick_geglu_impl( + intermediate_parallel, + bias_parallel, + permuted_probs, + self.config.activation_func_fp8_input_store, + self.config.glu_linear_offset, + self.config.activation_func_clamp_value, + ) + else: + raise ValueError("Only support fusion of swiglu and quick_gelu in TEGroupedMLP.") + elif self.activation_func == squared_relu and self.config.use_fused_weighted_squared_relu and permuted_probs is not None: + assert bias_parallel is None, "Bias is not supported with fused weighted squared relu." 
+ intermediate_parallel = weighted_squared_relu_impl( + intermediate_parallel, permuted_probs + ) + else: + if self.config.gated_linear_unit: + + def glu(x): + x_glu, x_linear = torch.chunk(x, 2, dim=-1) + if (val := self.config.activation_func_clamp_value) is not None: + x_glu = x_glu.clamp(min=None, max=val) + x_linear = x_linear.clamp(min=-val, max=val) + return self.config.activation_func(x_glu) * ( + x_linear + self.config.glu_linear_offset + ) + + intermediate_parallel = glu(intermediate_parallel) + else: + intermediate_parallel = self.activation_func(intermediate_parallel) + original_dtype = intermediate_parallel.dtype + if permuted_probs is not None: + intermediate_parallel = intermediate_parallel * permuted_probs + intermediate_parallel = intermediate_parallel.to(original_dtype) + return intermediate_parallel + def forward( self, permuted_local_hidden_states: torch.Tensor, @@ -772,74 +840,15 @@ def forward( forced_released_tensors=[permuted_local_hidden_states], ) - def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs): - if self.config.use_te_activation_func: - if bias_parallel is not None: - intermediate_parallel = intermediate_parallel + bias_parallel - intermediate_parallel = self.activation_func(intermediate_parallel) - if permuted_probs is not None: - original_dtype = intermediate_parallel.dtype - intermediate_parallel = intermediate_parallel * permuted_probs - intermediate_parallel = intermediate_parallel.to(original_dtype) - elif self.config.bias_activation_fusion: - if self.activation_func == F.silu and self.config.gated_linear_unit: - # dtype is handled inside the fused kernel - intermediate_parallel = weighted_bias_swiglu_impl( - intermediate_parallel, - bias_parallel, - permuted_probs, - self.config.activation_func_fp8_input_store, - ) - elif self.activation_func == quick_gelu and self.config.gated_linear_unit: - intermediate_parallel = weighted_bias_quick_geglu_impl( - intermediate_parallel, - bias_parallel, - permuted_probs, 
- self.config.activation_func_fp8_input_store, - self.config.glu_linear_offset, - self.config.activation_func_clamp_value, - ) - else: - raise ValueError( - "Only support fusion of swiglu and quick_gelu in TEGroupedMLP." - ) - elif ( - self.activation_func == squared_relu and self.config.use_fused_weighted_squared_relu - ): - assert bias_parallel is None - intermediate_parallel = weighted_squared_relu_impl( - intermediate_parallel, permuted_probs - ) - else: - if self.config.gated_linear_unit: - - def glu(x): - x_glu, x_linear = torch.chunk(x, 2, dim=-1) - if (val := self.config.activation_func_clamp_value) is not None: - x_glu = x_glu.clamp(min=None, max=val) - x_linear = x_linear.clamp(min=-val, max=val) - return self.config.activation_func(x_glu) * ( - x_linear + self.config.glu_linear_offset - ) - - intermediate_parallel = glu(intermediate_parallel) - else: - intermediate_parallel = self.activation_func(intermediate_parallel) - original_dtype = intermediate_parallel.dtype - intermediate_parallel = intermediate_parallel * permuted_probs - intermediate_parallel = intermediate_parallel.to(original_dtype) - return intermediate_parallel - if self.activation_recompute: self.activation_checkpoint = tensor_parallel.CheckpointWithoutOutput() with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output: bias_act_output = self.activation_checkpoint.checkpoint( - bias_act_func, fc1_output, bias_parallel, permuted_probs + self.bias_act_func, fc1_output, bias_parallel, permuted_probs ) else: with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output: - bias_act_output = bias_act_func(fc1_output, bias_parallel, permuted_probs) - + bias_act_output = self.bias_act_func(fc1_output, bias_parallel, permuted_probs) output, output_bias = apply_module(self.linear_fc2)(bias_act_output, tokens_per_expert) if self.activation_recompute: self.activation_checkpoint.discard_output_and_register_recompute(output) @@ -912,6 +921,236 @@ def backward_dw(self): 
self.linear_fc1.backward_dw() +class InferenceGroupedMLP(TEGroupedMLP): + """Inference-optimized GroupedMLP with GPU-resident offsets. + + Inherits from TEGroupedMLP to reuse weight initialization and checkpoint compatibility. + Supports three forward paths: + - Training: delegates to parent TEGroupedMLP + - Inference + CUDA graphed: FlashInfer cutlass_fused_moe (fused permute + GEMM) + - Inference + eager: torch._grouped_mm with GPU-resident cumsum offsets + """ + + def __init__( + self, + num_local_experts: int, + config: TransformerConfig, + submodules: MLPSubmodules, + pg_collection: Optional[ProcessGroupCollection] = None, + ): + # Initialize parent TEGroupedMLP (creates linear_fc1, linear_fc2) + super().__init__( + num_local_experts=num_local_experts, + config=config, + submodules=submodules, + pg_collection=pg_collection, + ) + + # TE's GroupedLinear stores per-expert weights as separate parameters + # (weight0, weight1, ..., weight{n-1}). We stack them into contiguous tensors + # of shape [num_experts, out_features, in_features] for torch._grouped_mm and + # FlashInfer's cutlass_fused_moe. Per-expert views are registered so that + # load_state_dict still writes into the contiguous buffers. + self._build_concatenated_weights() + + self.is_inference_cuda_graphed_iteration = False + + # torch._grouped_mm requires PyTorch >= 2.10 + self._torch_grouped_mm_available = ( + is_torch_min_version("2.10") + and hasattr(torch, '_grouped_mm') + ) + + if HAVE_FLASHINFER: + self._flashinfer_activation_type = self._resolve_flashinfer_activation_type() + + def _resolve_flashinfer_activation_type(self): + """Map megatron activation config to FlashInfer ActivationType.""" + assert ( + HAVE_FLASHINFER + ), "flashinfer-python is required to resolve FlashInfer activation type." 
+ func = self.config.activation_func + if func == F.silu: + return ActivationType.Silu + elif func == F.gelu: + return ActivationType.Gelu + elif func == F.relu: + return ActivationType.Relu + elif func == squared_relu: + return ActivationType.Relu2 + raise ValueError(f"No FlashInfer ActivationType mapping for activation_func={func}") + + def set_inference_cuda_graphed_iteration(self): + """Enable CUDA-graphed iteration mode.""" + self.is_inference_cuda_graphed_iteration = True + + def unset_inference_cuda_graphed_iteration(self): + """Disable CUDA-graphed iteration mode.""" + self.is_inference_cuda_graphed_iteration = False + + def _build_concatenated_weights(self): + """Create big contiguous weight tensors with per-expert views for checkpoint compatibility. + + Creates _fc1_weight and _fc2_weight as contiguous tensors of shape + [num_experts, out_features, in_features]. Replaces TE's individual weight{i} + parameters with views into these tensors. + + This allows: + - load_state_dict to load into weight{i} views -> writes into big tensor + - forward() to use big tensor directly with torch._grouped_mm or FlashInfer + """ + # Get device/dtype from existing TE weights + device = self.linear_fc1.weight0.device + dtype = self.linear_fc1.weight0.dtype + + fc1_shape = self.linear_fc1.weight0.shape # [out_features, in_features] + fc2_shape = self.linear_fc2.weight0.shape + + # Create big contiguous tensors + _fc1_weight = torch.empty(self.num_local_experts, *fc1_shape, device=device, dtype=dtype) + _fc2_weight = torch.empty(self.num_local_experts, *fc2_shape, device=device, dtype=dtype) + + # Copy existing TE weights into big tensors, then replace with views + for i in range(self.num_local_experts): + # Copy initialized data + _fc1_weight[i].copy_(getattr(self.linear_fc1, f'weight{i}').data) + _fc2_weight[i].copy_(getattr(self.linear_fc2, f'weight{i}').data) + + # Delete TE's original parameters + delattr(self.linear_fc1, f'weight{i}') + delattr(self.linear_fc2, 
f'weight{i}') + + # Register views as parameters (checkpoint loads will write into big tensor) + self.linear_fc1.register_parameter(f'weight{i}', torch.nn.Parameter(_fc1_weight[i])) + self.linear_fc2.register_parameter(f'weight{i}', torch.nn.Parameter(_fc2_weight[i])) + + # Register big tensors as non-persistent buffers (for .to() device movement, not saved) + self.register_buffer('_fc1_weight', _fc1_weight, persistent=False) + self.register_buffer('_fc2_weight', _fc2_weight, persistent=False) + + def _flashinfer_forward(self, hidden_states, routing_map, probs): + """FlashInfer fused MoE kernel for CUDA-graphed inference iterations.""" + assert HAVE_FLASHINFER, "flashinfer-python is required for FlashInfer forward path." + assert probs.dtype == torch.float32, "FlashInfer forward path requires fp32 probabilities." + output = fused_moe.cutlass_fused_moe( + hidden_states, + routing_map.int(), + probs, + self._fc1_weight, + self._fc2_weight, + hidden_states.dtype, + quant_scales=None, + activation_type=self._flashinfer_activation_type, + ep_size=self.ep_group.size(), + ep_rank=self.ep_group.rank(), + )[0] + return output, None + + def _torch_grouped_mm_forward( + self, permuted_local_hidden_states, tokens_per_expert, permuted_probs, + inclusive_expert_offsets=None, + ): + if permuted_probs is not None: + permuted_probs = permuted_probs.unsqueeze(-1) + if not tokens_per_expert.is_cuda: + tokens_per_expert = tokens_per_expert.to('cuda') + + if self.config.moe_apply_probs_on_input and permuted_probs is not None: + assert ( + self.config.moe_router_topk == 1 + ), "`moe_apply_probs_on_input` only works with `moe_router_topk`=1." 
+ original_dtype = permuted_local_hidden_states.dtype + permuted_local_hidden_states = permuted_probs * permuted_local_hidden_states + permuted_local_hidden_states = permuted_local_hidden_states.to(original_dtype) + permuted_probs = torch.ones_like(permuted_probs) + + if permuted_local_hidden_states.nelement() != 0: + # Reuse precomputed inclusive offsets from the dispatcher if available, + # otherwise compute cumsum on the fly (non-CG eager path). + if inclusive_expert_offsets is not None: + offs = inclusive_expert_offsets + else: + offs = tokens_per_expert.cumsum(0).to(torch.int32) + + fc1_output = torch._grouped_mm( + permuted_local_hidden_states, self._fc1_weight.transpose(1, 2), offs=offs + ) + + # Activation with routing probabilities + bias_act_output = self.bias_act_func(fc1_output, None, permuted_probs) + + fc2_output = torch._grouped_mm( + bias_act_output, self._fc2_weight.transpose(1, 2), offs=offs + ) + else: + # No tokens allocated - return empty tensor with correct shape + fc2_output = permuted_local_hidden_states + + return fc2_output, None + + def forward( + self, + permuted_local_hidden_states: torch.Tensor, + tokens_per_expert: Optional[torch.Tensor], + permuted_probs: torch.Tensor, + routing_map: Optional[torch.Tensor] = None, + inclusive_expert_offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Forward pass with three modes: + + - Training: delegates to parent TEGroupedMLP. + - Inference + CUDA graphed: FlashInfer cutlass_fused_moe. tokens_per_expert + is not used in this path; the FlashInfer kernel operates directly on + routing_map. + - Inference + eager: torch._grouped_mm with GPU-resident cumsum offsets. + + Args: + permuted_local_hidden_states: [num_tokens, hidden_size] input hidden states. + tokens_per_expert: [num_experts] number of tokens routed to each expert. + None when using the CUDA-graphed FlashInfer path. + permuted_probs: [num_tokens, topk] routing probabilities. 
+ routing_map: [num_tokens, topk] token-to-expert assignment indices. + Required for the FlashInfer CUDA-graphed path, None otherwise. + inclusive_expert_offsets: [num_experts] precomputed inclusive cumsum of + tokens_per_expert (i.e. offs[i] = end index of expert i). When + provided, torch._grouped_mm reuses these directly instead of + recomputing cumsum. None for FlashInfer and non-CG paths. + """ + if self.training: + return super().forward(permuted_local_hidden_states, tokens_per_expert, permuted_probs) + + elif self.is_inference_cuda_graphed_iteration: + if self.config.moe_ggemm_inference_cg == MoEGroupedGemmBackend.flashinfer: + assert routing_map is not None, ( + "routing_map is required for FlashInfer forward pass." + ) + assert HAVE_FLASHINFER, ( + "FlashInfer is not available; cannot use FlashInfer forward pass." + ) + return self._flashinfer_forward( + permuted_local_hidden_states, routing_map, permuted_probs + ) + else: + assert tokens_per_expert is not None, ( + "tokens_per_expert is required for torch grouped_mm forward pass." + ) + return self._torch_grouped_mm_forward( + permuted_local_hidden_states, tokens_per_expert, permuted_probs, + inclusive_expert_offsets=inclusive_expert_offsets, + ) + + elif ( + self.config.moe_ggemm_inference_no_cg == MoEGroupedGemmBackend.torch + and self._torch_grouped_mm_available + ): + return self._torch_grouped_mm_forward( + permuted_local_hidden_states, tokens_per_expert, permuted_probs + ) + + else: + return super().forward(permuted_local_hidden_states, tokens_per_expert, permuted_probs) + + class SequentialMLP(MegatronModule): """An implementation of the Experts layer using a sequence of MLP layers. diff --git a/megatron/core/transformer/moe/moe_inference_utils.py b/megatron/core/transformer/moe/moe_inference_utils.py new file mode 100644 index 00000000000..89726be8929 --- /dev/null +++ b/megatron/core/transformer/moe/moe_inference_utils.py @@ -0,0 +1,408 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + +""" +Triton kernels for CUDA-graph-compatible MoE token permutation and unpermutation. + +These kernels enable the torch grouped GEMM path to work under CUDA graphs +by keeping all metadata (tokens_per_expert, permutation indices) GPU-resident. +""" + +from unittest.mock import MagicMock + +import torch +from packaging import version + +from megatron.core.utils import null_decorator + +try: + import triton + import triton.language as tl + + if version.parse(triton.__version__) < version.parse("3.4.0") and not torch.cuda.is_available(): + HAVE_TRITON = False + else: + HAVE_TRITON = tl.constexpr(version.parse(triton.__version__) >= version.parse("2.0.0")) +except ImportError: + HAVE_TRITON = False + +if not HAVE_TRITON: + triton = MagicMock() + triton.jit = null_decorator + tl = MagicMock() + + +# --------------------------------------------------------------------------- # +# Kernel: Count tokens per local expert +# --------------------------------------------------------------------------- # +@triton.jit +def _count_local_tokens_kernel( + routing_map_ptr, # [num_tokens, topk] - global expert IDs + tokens_per_expert_ptr, # [num_local_experts] output (must be zero-initialized) + total_pairs, + local_expert_start, + num_local_experts: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Count tokens assigned to each local expert, filtering out non-local experts. + + Each program handles BLOCK_SIZE (token, k) pairs from the routing_map. + Pairs whose assigned expert is not on this rank are ignored. For local + experts, atomically increments the corresponding tokens_per_expert counter. 
+ """ + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < total_pairs + + expert_ids = tl.load(routing_map_ptr + offsets, mask=mask, other=-1) + local_ids = expert_ids - local_expert_start + is_local = (local_ids >= 0) & (local_ids < num_local_experts) & mask + + # Scatter atomic add: each element adds 1 to its expert's counter + tl.atomic_add(tokens_per_expert_ptr + local_ids, 1, mask=is_local) + + +# --------------------------------------------------------------------------- # +# Python wrapper +# --------------------------------------------------------------------------- # +def compute_local_tokens_per_expert( + routing_map: torch.Tensor, + local_expert_start: int, + num_local_experts: int, +) -> torch.Tensor: + """Count tokens routed to each local expert, filtering out non-local assignments. + + Scans the routing_map for (token, k) pairs whose assigned expert lives on + this rank (global ID in [local_expert_start, local_expert_start + num_local_experts)). + Pairs routed to experts on other ranks are ignored. + + Args: + routing_map (torch.Tensor): Expert assignments, shape [num_tokens, topk]. + Contains global expert IDs. + local_expert_start (int): First global expert index on this rank. + num_local_experts (int): Number of experts on this rank. + + Returns: + torch.Tensor: tokens_per_expert, shape [num_local_experts], dtype int32. + Count of (token, k) pairs assigned to each local expert. 
+ """ + total_pairs = routing_map.numel() + + tokens_per_expert = torch.zeros( + num_local_experts, dtype=torch.int32, device=routing_map.device + ) + + HIST_BLOCK = 256 + hist_grid = ((total_pairs + HIST_BLOCK - 1) // HIST_BLOCK,) + _count_local_tokens_kernel[hist_grid]( + routing_map, + tokens_per_expert, + total_pairs, + local_expert_start, + num_local_experts, + BLOCK_SIZE=HIST_BLOCK, + ) + + return tokens_per_expert + + +# --------------------------------------------------------------------------- # +# Kernel: Exclusive prefix sum (single block) +# --------------------------------------------------------------------------- # +@triton.jit +def _prefix_sum_kernel( + tokens_per_expert_ptr, # [num_local_experts] input + expert_offsets_ptr, # [num_local_experts] output + num_local_experts, + BLOCK_SIZE: tl.constexpr, # next_power_of_2(num_local_experts) +): + """Compute exclusive prefix sum of tokens_per_expert. + + Runs as a single block. Reads tokens_per_expert, computes exclusive prefix + sum via tl.cumsum, and writes expert_offsets. + """ + expert_range = tl.arange(0, BLOCK_SIZE) + mask = expert_range < num_local_experts + histogram = tl.load(tokens_per_expert_ptr + expert_range, mask=mask, other=0) + + # Inclusive prefix sum, then shift to exclusive + inclusive = tl.cumsum(histogram, axis=0) + exclusive = inclusive - histogram + + tl.store(expert_offsets_ptr + expert_range, exclusive, mask=mask) + + +# --------------------------------------------------------------------------- # +# Python wrapper +# --------------------------------------------------------------------------- # +def compute_expert_offsets( + tokens_per_expert: torch.Tensor, +) -> torch.Tensor: + """Compute exclusive prefix sum of tokens_per_expert. + + Args: + tokens_per_expert (torch.Tensor): Token counts per local expert, + shape [num_local_experts], dtype int32. + + Returns: + torch.Tensor: expert_offsets, shape [num_local_experts]. 
+ Exclusive prefix sum: expert_offsets[i] is the start index of + expert i's tokens in the permuted buffer. Passed to permute_tokens + which mutates it in-place via atomic adds, turning the exclusive + start offsets into inclusive end offsets (i.e. expert_offsets[i] + becomes the end index of expert i's tokens after permutation). + """ + num_local_experts = tokens_per_expert.shape[0] + + expert_offsets = torch.empty_like(tokens_per_expert) + + BLOCK_SIZE = triton.next_power_of_2(num_local_experts) + _prefix_sum_kernel[(1,)]( + tokens_per_expert, + expert_offsets, + num_local_experts, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return expert_offsets + + +# --------------------------------------------------------------------------- # +# Kernel: Permute tokens by expert assignment +# --------------------------------------------------------------------------- # +@triton.jit +def _permute_tokens_kernel( + # Input pointers + hidden_states_ptr, # [num_tokens, hidden_dim] + probs_ptr, # [num_tokens, topk] + routing_map_ptr, # [num_tokens, topk] + # Output pointers + permuted_hidden_ptr, # [output_size, hidden_dim] + permuted_probs_ptr, # [output_size] + source_token_indices_ptr, # [output_size] + # Atomic counters (mutated in-place) + atomic_counters_ptr, # [num_local_experts] + # Dimensions + num_tokens, + hidden_dim, + topk: tl.constexpr, + local_expert_start, + num_local_experts: tl.constexpr, + BLOCK_H: tl.constexpr, # tile size for hidden_dim copy loop +): + """Permute tokens into expert-grouped order. + + Each program handles one (token, k) pair. If the assigned expert is local, + it atomically claims a write position and copies the token's hidden state, + routing probability, and source token index to the output buffers. + The hidden dimension is copied in tiles of BLOCK_H to support large hidden sizes. 
+ """ + pair_idx = tl.program_id(0) + token_idx = pair_idx // topk + k_idx = pair_idx % topk + + if token_idx >= num_tokens: + return + + expert_id = tl.load(routing_map_ptr + token_idx * topk + k_idx) + local_idx = expert_id - local_expert_start + + if local_idx < 0 or local_idx >= num_local_experts: + return + + # Atomically claim a write position + write_pos = tl.atomic_add(atomic_counters_ptr + local_idx, 1) + + # Copy hidden state row in tiles of BLOCK_H + src_row_ptr = hidden_states_ptr + token_idx * hidden_dim + dst_row_ptr = permuted_hidden_ptr + write_pos * hidden_dim + for h_start in tl.range(0, hidden_dim, BLOCK_H): + h_offsets = h_start + tl.arange(0, BLOCK_H) + h_mask = h_offsets < hidden_dim + src_vals = tl.load(src_row_ptr + h_offsets, mask=h_mask) + tl.store(dst_row_ptr + h_offsets, src_vals, mask=h_mask) + + # Copy the routing probability for this (token, k) pair + prob_val = tl.load(probs_ptr + token_idx * topk + k_idx) + tl.store(permuted_probs_ptr + write_pos, prob_val) + + # Record source token index for the unpermute kernel + tl.store(source_token_indices_ptr + write_pos, token_idx) + + +# --------------------------------------------------------------------------- # +# Python wrapper +# --------------------------------------------------------------------------- # +def permute_tokens( + hidden_states: torch.Tensor, + probs: torch.Tensor, + routing_map: torch.Tensor, + expert_offsets: torch.Tensor, + local_expert_start: int, + num_local_experts: int, +) -> tuple: + """Permute tokens into expert-grouped order for the torch grouped GEMM. + + Scatters tokens into an output buffer where all tokens for expert 0 come + first, then expert 1, etc. Uses expert_offsets (from compute_expert_offsets) + to assign write positions. + + NOTE: expert_offsets is mutated in-place. On entry it contains exclusive + start offsets (expert_offsets[i] = start index of expert i's region). 
The kernel atomically increments each entry as it places tokens, so on
+    exit expert_offsets[i] = one-past-the-end index of expert i's region
+    (the inclusive cumsum). The last entry equals the total number of routed tokens.
+ """ + num_tokens, hidden_dim = hidden_states.shape + topk = probs.shape[1] + output_size = num_tokens * min(topk, num_local_experts) + + # Allocate output buffers (statically sized for CUDA graph compatibility) + permuted_hidden = torch.empty( + output_size, hidden_dim, dtype=hidden_states.dtype, device=hidden_states.device + ) + permuted_probs = torch.empty(output_size, dtype=probs.dtype, device=probs.device) + source_token_indices = torch.empty(output_size, dtype=torch.int32, device=probs.device) + + total_pairs = num_tokens * topk + BLOCK_H = min(triton.next_power_of_2(hidden_dim), 1024) + # After this kernel, expert_offsets is mutated: exclusive start offsets + # become inclusive end offsets (expert_offsets[-1] = total routed tokens). + _permute_tokens_kernel[(total_pairs,)]( + hidden_states, + probs, + routing_map, + permuted_hidden, + permuted_probs, + source_token_indices, + expert_offsets, + num_tokens, + hidden_dim, + topk, + local_expert_start, + num_local_experts, + BLOCK_H=BLOCK_H, + ) + + return permuted_hidden, permuted_probs, source_token_indices + + +# --------------------------------------------------------------------------- # +# Kernel: Unpermute (accumulate expert outputs back to token positions) +# --------------------------------------------------------------------------- # +@triton.jit +def _unpermute_tokens_kernel( + # Input pointers + expert_output_ptr, # [output_size, hidden_dim] + permuted_probs_ptr, # [output_size] + source_token_indices_ptr, # [output_size] + # GPU-resident valid count (read from last atomic counter after permute) + num_routed_slots_ptr, # scalar tensor on GPU + # Output pointer + output_ptr, # [num_tokens, hidden_dim] - must be zero-initialized + # Dimensions + hidden_dim, + BLOCK_H: tl.constexpr, # tile size for hidden_dim loop +): + """Accumulate weighted expert outputs back into original token positions. + + Each program handles one row of the permuted expert output. 
It reads the + source token index, multiplies the expert output by the routing probability, + and atomically adds the result to the corresponding row in the output buffer. + Multiple experts contributing to the same token are summed via atomic adds. + Rows beyond num_routed_slots (read from GPU) are skipped. + The hidden dimension is processed in tiles of BLOCK_H to support large hidden sizes. + """ + row_idx = tl.program_id(0) + + # Read valid count from GPU — no host sync needed for CUDA graphability + num_valid = tl.load(num_routed_slots_ptr) + if row_idx >= num_valid: + return + + token_idx = tl.load(source_token_indices_ptr + row_idx) + prob = tl.load(permuted_probs_ptr + row_idx) + + src_row_ptr = expert_output_ptr + row_idx * hidden_dim + dst_row_ptr = output_ptr + token_idx * hidden_dim + for h_start in tl.range(0, hidden_dim, BLOCK_H): + h_offsets = h_start + tl.arange(0, BLOCK_H) + h_mask = h_offsets < hidden_dim + expert_vals = tl.load(src_row_ptr + h_offsets, mask=h_mask) + scaled_vals = expert_vals * prob + tl.atomic_add(dst_row_ptr + h_offsets, scaled_vals, mask=h_mask) + + +# --------------------------------------------------------------------------- # +# Python wrapper +# --------------------------------------------------------------------------- # +def unpermute_tokens( + expert_output: torch.Tensor, + permuted_probs: torch.Tensor, + source_token_indices: torch.Tensor, + num_tokens: int, + num_routed_slots: torch.Tensor, +) -> torch.Tensor: + """Unpermute expert outputs back to original token order. + + Accumulates weighted expert outputs into the original token positions. + For each valid permuted row i, computes: + output[source_token_indices[i], :] += expert_output[i, :] * permuted_probs[i] + Multiple experts contributing to the same token are summed via atomic adds. + Rows beyond num_routed_slots are skipped. num_routed_slots is read from GPU + to avoid host synchronization, keeping the pipeline CUDA-graphable. 
+ + Args: + expert_output (torch.Tensor): Expert outputs, shape [output_size, hidden_dim]. + Only the first total_routed rows are valid. + permuted_probs (torch.Tensor): Routing probs, shape [output_size]. + Only the first total_routed entries are valid. + source_token_indices (torch.Tensor): Source token index for each permuted row, + shape [output_size], dtype int32. Only the first total_routed are valid. + num_tokens (int): Number of original tokens (for output allocation). + num_routed_slots (torch.Tensor): 1-element GPU tensor containing the + total number of routed tokens. Must stay on GPU (no host sync). + + Returns: + torch.Tensor: Unpermuted output, shape [num_tokens, hidden_dim]. + Each row is the weighted sum of expert outputs for that token. + """ + output_size, hidden_dim = expert_output.shape + + # Zero-initialized output (atomic adds accumulate into this) + output = torch.zeros( + num_tokens, hidden_dim, dtype=expert_output.dtype, device=expert_output.device + ) + + BLOCK_H = min(triton.next_power_of_2(hidden_dim), 1024) + _unpermute_tokens_kernel[(output_size,)]( + expert_output, + permuted_probs, + source_token_indices, + num_routed_slots, + output, + hidden_dim, + BLOCK_H=BLOCK_H, + ) + + return output + + diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 3d9d0b092aa..b60e2d6bb1d 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -2,6 +2,7 @@ from __future__ import annotations +import warnings from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Optional, Protocol, Union @@ -24,11 +25,31 @@ MoEFlexTokenDispatcher, MoETokenDispatcher, ) +from megatron.core.transformer.enums import MoEGroupedGemmBackend +from megatron.core.transformer.moe.token_dispatcher_inference import ( + InferenceCUDAGraphTokenDispatcher, +) from megatron.core.transformer.spec_utils import ModuleSpec, build_module from 
megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.typed_torch import apply_module from megatron.core.utils import internal_api +try: + import flashinfer # pylint: disable=unused-import + + HAVE_FLASHINFER = True +except ImportError: + HAVE_FLASHINFER = False + +if HAVE_FLASHINFER: + try: + import flashinfer_cubin # pylint: disable=unused-import + import flashinfer_jit_cache # pylint: disable=unused-import + + HAVE_FLASHINFER_CUBIN_AND_JIT_CACHE = True + except ImportError: + HAVE_FLASHINFER_CUBIN_AND_JIT_CACHE = False + try: import transformer_engine as te # pylint: disable=unused-import @@ -246,10 +267,77 @@ def __init__( if self.shared_expert_overlap: self.token_dispatcher.set_shared_experts(self.shared_experts) + # Inference-optimized mode setup + if config.transformer_impl == "inference_optimized": + if config.moe_ggemm_inference_cg == MoEGroupedGemmBackend.flashinfer: + assert ( + HAVE_FLASHINFER + ), "flashinfer-python is required when moe_ggemm_inference_cg='flashinfer'." + if not HAVE_FLASHINFER_CUBIN_AND_JIT_CACHE: + warnings.warn( + "flashinfer-cubin and/or flashinfer-jit-cache not found. " + "The FlashInfer cutlass kernel will be JIT compiled," + "which may take a long time." + ) + self._setup_inference_mode(pg_collection) + # Cudagraph tensor store for resuming the forward pass from the end of the cudagraph. self.cudagraph_tensor_store = MoECudaGraphTensorStore() self.fwd_execution_map = ["route", "expert_compute", "postprocess"] + def _setup_inference_mode(self, pg_collection): + """Set up inference-optimized token dispatcher and state. + + Called from __init__ when config.transformer_impl == "inference_optimized". + Creates an InferenceCUDAGraphTokenDispatcher alongside the standard dispatcher, + which is swapped in during CUDA-graphed forward passes. 
+ """ + + assert self.config.moe_token_dispatcher_type == "alltoall", ( + f"Inference-optimized MoE requires 'alltoall' dispatcher, " + f"got '{self.config.moe_token_dispatcher_type}'" + ) + self.is_inference_cuda_graphed_iteration = False + self._inference_token_dispatcher = InferenceCUDAGraphTokenDispatcher( + self.num_local_experts, + self.local_expert_indices, + config=self.config, + pg_collection=pg_collection, + ) + + def set_inference_cuda_graphed_iteration(self): + """Enable CUDA-graphed iteration mode on this layer, its router, and its experts. + + Swaps in the inference-optimized token dispatcher and disables + shared expert overlap. + """ + self.is_inference_cuda_graphed_iteration = True + if hasattr(self.router, "set_inference_cuda_graphed_iteration"): + self.router.set_inference_cuda_graphed_iteration() + if hasattr(self.experts, "set_inference_cuda_graphed_iteration"): + self.experts.set_inference_cuda_graphed_iteration() + + if self._inference_token_dispatcher is not None: + self._saved_token_dispatcher = self.token_dispatcher + self.token_dispatcher = self._inference_token_dispatcher + self._saved_shared_expert_overlap = self.shared_expert_overlap + self.shared_expert_overlap = False + + def unset_inference_cuda_graphed_iteration(self): + """Disable CUDA-graphed iteration mode on this layer, its router, and its experts. + + Restores the standard token dispatcher and shared expert overlap setting. 
+ """ + self.is_inference_cuda_graphed_iteration = False + if hasattr(self.router, "unset_inference_cuda_graphed_iteration"): + self.router.unset_inference_cuda_graphed_iteration() + if hasattr(self.experts, "unset_inference_cuda_graphed_iteration"): + self.experts.unset_inference_cuda_graphed_iteration() + + if hasattr(self, "_saved_token_dispatcher"): + self.token_dispatcher = self._saved_token_dispatcher + self.shared_expert_overlap = self._saved_shared_expert_overlap + @maybe_skip_or_early_return_by_cudagraph("route") def route(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): """Compute token routing for preprocessing. @@ -328,7 +416,26 @@ def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tenso dispatched_input, tokens_per_expert, permuted_probs = ( self.token_dispatcher.dispatch_postprocess(hidden_states, probs) ) - expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert, permuted_probs) + if ( + hasattr(self, "_inference_token_dispatcher") + and self.is_inference_cuda_graphed_iteration + ): + if self.config.moe_ggemm_inference_cg == MoEGroupedGemmBackend.flashinfer: + routing_map = self.token_dispatcher.routing_map + expert_output, mlp_bias = self.experts( + dispatched_input, tokens_per_expert, permuted_probs, routing_map=routing_map + ) + else: + # torch backend: dispatcher already permuted tokens; tokens_per_expert is set. + # Pass precomputed inclusive offsets to avoid recomputing cumsum. 
+ expert_output, mlp_bias = self.experts( + dispatched_input, tokens_per_expert, permuted_probs, + inclusive_expert_offsets=self.token_dispatcher.inclusive_expert_offsets, + ) + else: + expert_output, mlp_bias = self.experts( + dispatched_input, tokens_per_expert, permuted_probs + ) assert mlp_bias is None, f"mlp_bias is not supported for {type(self.token_dispatcher)}" output = self.token_dispatcher.combine_preprocess(expert_output) @@ -437,7 +544,7 @@ def custom_forward(hidden_states, intermediate_tensors=None, padding_mask=None): return output, mlp_bias - if self.moe_layer_recompute: + if self.moe_layer_recompute and self.training: if self.config.fp8 or self.config.fp4: outputs = te_checkpoint( custom_forward, diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index 69fcc21d983..51c8b51134f 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -662,6 +662,7 @@ def topk_routing_with_score_function( expert_bias: Optional[torch.Tensor] = None, fused: bool = False, router_replay: Optional['RouterReplay'] = None, + dense_output: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: """Compute the routing probabilities and map for top-k selection with score function. @@ -684,14 +685,23 @@ def topk_routing_with_score_function( recorded routing sequence. Defaults to None. + dense_output (bool, optional): If True, return dense tensors [num_tokens, topk] instead of + sparse tensors [num_tokens, num_experts]. Defaults to False. Returns: Tuple[torch.Tensor, torch.Tensor]: - - routing_probs (torch.Tensor): A tensor of shape [num_tokens, num_experts] containing - the routing probabilities for each token to each expert. - - routing_map (torch.Tensor): A mask tensor of shape [num_tokens, num_experts] - indicating which experts were selected for each token. True values represent - the selected experts. 
+ When dense_output=False (default): + - routing_probs (torch.Tensor): Shape [num_tokens, num_experts]. Sparse tensor + containing the normalized routing probability for each token-expert pair. Non-zero + entries correspond to the top-k selected experts per token. + - routing_map (torch.Tensor): Shape [num_tokens, num_experts]. Boolean mask where + True indicates the token is routed to that expert (i.e. the expert was in the + token's top-k selection). + When dense_output=True: + - probs (torch.Tensor): Shape [num_tokens, topk]. The normalized routing + probabilities for each token's top-k selected experts. + - top_indices (torch.Tensor): Shape [num_tokens, topk]. The expert indices + selected for each token. """ assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}." num_tokens, num_experts = logits.shape @@ -774,6 +784,9 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): if scaling_factor: probs = probs * scaling_factor + if dense_output: + return probs, top_indices + if torch.are_deterministic_algorithms_enabled(): # build [num_tokens, num_experts] from [num_tokens, topk] routing_probs = torch.zeros_like(logits) diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index e42fd1ca8aa..45cfb59cd40 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -710,3 +710,115 @@ def _save_to_state_dict(self, *args, **kwargs): """Save the state dict of the router.""" self._maintain_float32_expert_bias() # switch to float32 before saving return super()._save_to_state_dict(*args, **kwargs) + + +class InferenceTopKRouter(TopKRouter): + """Inference-only top-k router that strips out training-specific overhead. + + A stripped-down version of TopKRouter that skips z-loss, auxiliary load + balancing losses, token dropping, and expert bias updates. 
The _forward()
+    method uses a @torch.compile()'d routing helper and returns dense [num_tokens, topk] tensors
expert_bias=expert_bias, + fused=fused, + router_replay=router_replay, + dense_output=dense_output, + ) + + def _forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): + logits = self.gating(input).squeeze(1) # [num_tokens, num_experts] + + probs, top_indices = self._compiled_topk_routing( + logits, + self.topk, + use_pre_softmax=self.config.moe_router_pre_softmax, + num_groups=self.config.moe_router_num_groups, + group_topk=self.config.moe_router_group_topk, + scaling_factor=self.config.moe_router_topk_scaling_factor, + score_function=self.score_function, + expert_bias=self.expert_bias, + fused=self.config.moe_router_fusion, + router_replay=self.router_replay, + dense_output=True, + ) + return probs.squeeze(1), top_indices.squeeze(1) + + def forward(self, input: torch.Tensor, padding_mask: Optional[torch.Tensor] = None): + """Simplified forward pass for inference - returns dense tensors only. + + Args: + input (torch.Tensor): Input tensor of shape [seq_length, bsz, hidden_size]. + padding_mask (torch.Tensor, optional): Not used in inference. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - probs: Normalized routing probabilities [num_tokens, topk] + - top_indices: Selected expert indices [num_tokens, topk] + """ + + if self.training or not self.is_inference_cuda_graphed_iteration: + return super().forward(input, padding_mask) + + return self._forward(input, padding_mask) diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py new file mode 100644 index 00000000000..9bab4dbd4fa --- /dev/null +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -0,0 +1,401 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +""" +CUDA-graph-compatible token dispatcher for inference. + +This dispatcher is only used during CUDA-graphed inference iterations. 
It replaces +AlltoAll with AllGather/ReduceScatter for token exchange, keeping all metadata +GPU-resident to avoid host synchronizations that would break CUDA graph capture. + +Supports latency-optimized NVLS collectives (multimem all-gather/reduce-scatter) +on Hopper+ GPUs with BF16, with automatic fallback to NCCL. +""" + +from typing import List, Optional + +import torch + +from megatron.core.inference.communication.torch_symm_triton import ( + are_tensors_nvls_eligible, + multimem_all_gather_fused, + multimem_reduce_scatter, +) +from megatron.core.parallel_state import get_global_symmetric_memory_buffer_ep +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.tensor_parallel import ( + gather_from_sequence_parallel_region, + reduce_scatter_to_sequence_parallel_region, +) +from megatron.core.transformer.enums import MoEGroupedGemmBackend +from megatron.core.transformer.moe.moe_inference_utils import ( + compute_expert_offsets, + compute_local_tokens_per_expert, + permute_tokens, + unpermute_tokens, +) +from megatron.core.transformer.moe.token_dispatcher import MoEAllGatherTokenDispatcher +from megatron.core.transformer.transformer_config import TransformerConfig + + +class InferenceCUDAGraphTokenDispatcher(MoEAllGatherTokenDispatcher): + """ + CUDA-graph-compatible AllGather token dispatcher for inference. + + Only used during CUDA-graphed inference iterations. Swapped in by + MoELayer.set_inference_cuda_graphed_iteration() before graph capture + and swapped out by MoELayer.unset_inference_cuda_graphed_iteration() after. 
+ + Key features: + - AllGather/ReduceScatter instead of AlltoAll for CUDA graph compatibility + - GPU-resident metadata (no host synchronization) + - NVLS collectives on Hopper+ with automatic NCCL fallback + """ + + def __init__( + self, + num_local_experts: int, + local_expert_indices: List[int], + config: TransformerConfig, + pg_collection: Optional[ProcessGroupCollection] = None, + ) -> None: + """ + Initialize the InferenceCUDAGraphTokenDispatcher. + + Args: + num_local_experts: Number of experts on this rank. + local_expert_indices: Global indices of experts on this rank. + config: Transformer configuration. + pg_collection: Process group collection for distributed ops. + """ + super().__init__( + num_local_experts=num_local_experts, + local_expert_indices=local_expert_indices, + config=config, + pg_collection=pg_collection, + ) + self.topk = config.moe_router_topk + + self.triton_nvls_kernels_allowed = not self.config.inference_disable_triton_nvls_kernels + + # Backend selection for CUDA-graphed grouped GEMM + self._moe_backend = config.moe_ggemm_inference_cg + self._local_expert_start = local_expert_indices[0] + + # Intermediate state for torch backend (set in dispatch_postprocess, used in combine_preprocess) + self._source_token_indices = None + self._num_routed_slots = None + self._permuted_probs = None + self._num_tokens_after_dispatch = None + # Inclusive end offsets per expert after permutation. Reused as the + # `offs` arg to torch._grouped_mm to avoid recomputing cumsum. + # None for the flashinfer backend (not needed). + self.inclusive_expert_offsets = None + + def _maybe_allocate_ag_buffers( + self, routing_map: torch.Tensor, probs: torch.Tensor, hidden_states: torch.Tensor + ) -> dict: + """Allocate a single symmetric memory output buffer for fused all-gather. + + Creates one contiguous symmetric memory buffer sized for the gathered + (global) routing_map, probs, and hidden_states, then returns sliced views + into it. 
This allows a single fused NVLS all-gather kernel to write all + three outputs in one launch. + + Args: + routing_map (torch.Tensor): Local routing map, shape [local_tokens, topk]. + Boolean or integer tensor mapping each token to its selected experts. + probs (torch.Tensor): Local routing probabilities, shape [local_tokens, topk]. + Normalized weights for each token's selected experts. + hidden_states (torch.Tensor): Local hidden states, shape [local_tokens, hidden_dim]. + + Returns: + dict: A dictionary with the following keys: + - "handle": Symmetric memory handle for NVLS ops, or None if + symmetric memory is unavailable. + - "routing_map": Raw byte view for the gathered routing map output. + - "routing_map_offset": Byte offset of routing_map within the buffer. + - "probs": Raw byte view for the gathered probs output. + - "probs_offset": Byte offset of probs within the buffer. + - "hidden_states": Raw byte view for the gathered hidden states output. + - "hidden_states_offset": Byte offset of hidden_states within the buffer. + When allocation fails, all tensor views are None and offsets are 0. 
+ """ + _NONE = { + "handle": None, + "routing_map": None, + "routing_map_offset": 0, + "probs": None, + "probs_offset": 0, + "hidden_states": None, + "hidden_states_offset": 0, + } + + local_tokens = probs.size(0) + global_tokens = local_tokens * self.ep_size + topk = probs.size(-1) + hidden_dim = hidden_states.size(-1) + + result = get_global_symmetric_memory_buffer_ep().maybe_get_tensors( + [ + (global_tokens * topk, routing_map.dtype), + (global_tokens * topk, probs.dtype), + (global_tokens * hidden_dim, hidden_states.dtype), + ] + ) + + if result["handle"] is None: + return _NONE + + (rm_buf, rm_off), (p_buf, p_off), (hs_buf, hs_off) = result["tensors"] + return { + "handle": result["handle"], + "routing_map": rm_buf, + "routing_map_offset": rm_off, + "probs": p_buf, + "probs_offset": p_off, + "hidden_states": hs_buf, + "hidden_states_offset": hs_off, + } + + def _maybe_allocate_rs_buffer(self, x: torch.Tensor) -> dict: + """Allocate a symmetric memory buffer for reduce-scatter input. + + The buffer has the same shape and dtype as x so that x can be copied + into it before the NVLS reduce-scatter kernel. + + Args: + x (torch.Tensor): The global hidden states to be reduce-scattered, + shape [global_tokens, hidden_dim]. + + Returns: + dict: A dictionary with keys "handle" (symmetric memory handle, or + None if unavailable) and "tensor" (the allocated buffer, or None). + """ + symm_mem_buffer = get_global_symmetric_memory_buffer_ep().maybe_get_tensor( + list(x.size()), dtype=x.dtype + ) + return symm_mem_buffer + + def token_dispatch(self, hidden_states, probs): + """Gathers tokens from all EP ranks using AllGather. + + Performs all-gather on routing_map (stored in self.routing_map), probs, + and hidden_states so that every rank holds the full global view. + Uses latency-optimized fused NVLS multimem_all_gather on Hopper+ GPUs + with BF16 when symmetric memory is available. Falls back to NCCL otherwise. 
+ + Args: + hidden_states (torch.Tensor): Local hidden states, + shape [local_tokens, hidden_dim]. + probs (torch.Tensor): Local routing probabilities, + shape [local_tokens, topk]. Normalized weights for each token's + selected experts. + + Returns: + tuple: (hidden_states, probs) gathered across all EP ranks. + - hidden_states (torch.Tensor): Shape [global_tokens, hidden_dim]. + - probs (torch.Tensor): Shape [global_tokens, topk]. + Also updates self.routing_map in-place to the gathered + shape [global_tokens, topk]. + """ + if self.ep_size == 1: + return hidden_states, probs + + # 1. Check inputs only: if inputs are 16-byte divisible, + # outputs (world_size * input) are too. + nvls_eligible = self.triton_nvls_kernels_allowed and are_tensors_nvls_eligible( + hidden_states, probs, self.routing_map + ) + ag_buffers = None + + if nvls_eligible: + # 2. Now attempt to allocate symmetric memory buffers for + # all-gather outputs. If allocation fails, fallback to NCCL. + ag_buffers = self._maybe_allocate_ag_buffers(self.routing_map, probs, hidden_states) + + # 3. 
Can use NVLS if eligible and buffers allocated successfully (handle is not None) + can_use_nvls = nvls_eligible and ag_buffers["handle"] is not None + + if can_use_nvls: + # Capture shapes for reshaping after all-gather + # Output shape: [local_tokens * ep_size, dim] + local_tokens = probs.size(0) + global_tokens = local_tokens * self.ep_size + topk = probs.size(1) + hidden_dim = hidden_states.size(1) + routing_map_dtype = self.routing_map.dtype + probs_dtype = probs.dtype + hidden_dtype = hidden_states.dtype + + # Fused NVLS all-gather: single kernel launch + single barrier for all 3 tensors + multimem_all_gather_fused( + ag_buffers["routing_map"].view( + torch.bfloat16 + ), # .view does not change the underlying data + self.routing_map.view(torch.bfloat16), + ag_buffers["routing_map_offset"], + ag_buffers["probs"].view(torch.bfloat16), + probs.view(torch.bfloat16), + ag_buffers["probs_offset"], + ag_buffers["hidden_states"].view(torch.bfloat16), + hidden_states.view(torch.bfloat16), + ag_buffers["hidden_states_offset"], + ag_buffers["handle"], + ) + self.routing_map = ( + ag_buffers["routing_map"].view(routing_map_dtype).view(global_tokens, topk) + ) + probs = ag_buffers["probs"].view(probs_dtype).view(global_tokens, topk) + hidden_states = ( + ag_buffers["hidden_states"].view(hidden_dtype).view(global_tokens, hidden_dim) + ) + else: + # Fallback to NCCL for all tensors + with torch.no_grad(): + self.routing_map = gather_from_sequence_parallel_region( + self.routing_map, group=self.tp_ep_group + ) + probs = gather_from_sequence_parallel_region(probs, group=self.tp_ep_group) + hidden_states = gather_from_sequence_parallel_region( + hidden_states, group=self.tp_ep_group + ) + + return hidden_states, probs + + def dispatch_postprocess(self, hidden_states, probs): + """Post-process dispatched tokens for expert computation. + + For the flashinfer backend, this is a pass-through: the FlashInfer fused + kernel operates directly on self.routing_map. 
For the torch backend, + uses triton kernels to permute tokens into expert-grouped order. + + Args: + hidden_states (torch.Tensor): Gathered hidden states, + shape [global_tokens, hidden_dim]. + probs (torch.Tensor): Gathered routing probabilities, + shape [global_tokens, topk]. + + Returns: + tuple: (hidden_states, tokens_per_expert, probs) where: + - flashinfer: tokens_per_expert=None, probs=original probs + - torch: tokens_per_expert=GPU histogram, probs=None + (probs deferred to combine_preprocess unpermute) + """ + if self._moe_backend == MoEGroupedGemmBackend.flashinfer: + return hidden_states, None, probs + + # torch backend: permute tokens into expert-grouped order + self._num_tokens_after_dispatch = hidden_states.size(0) + tokens_per_expert = compute_local_tokens_per_expert( + self.routing_map, self._local_expert_start, self.num_local_experts + ) + exclusive_expert_offsets = compute_expert_offsets(tokens_per_expert) + # permute_tokens mutates the offsets tensor in-place via atomic adds, + # converting exclusive start offsets into inclusive end offsets. + # Example with 3 experts and tokens_per_expert = [2, 3, 1]: + # exclusive_expert_offsets = [0, 2, 5] (start of each expert's region) + # inclusive_expert_offsets = [2, 5, 6] (end of each expert's region) + # The last entry (6) equals the total number of routed tokens. + permuted_hidden, permuted_probs, source_token_indices = permute_tokens( + hidden_states, probs, self.routing_map, exclusive_expert_offsets, + self._local_expert_start, self.num_local_experts, + ) + inclusive_expert_offsets = exclusive_expert_offsets # mutated in-place by permute_tokens + + # Cache state for combine_preprocess. + self._source_token_indices = source_token_indices + # Number of (token, expert) slots actually routed to local experts on + # this rank. With topk > 1, a single token may occupy multiple slots. 
+ # Only permuted_hidden[:num_routed_slots] contains real data; the + # remaining rows are uninitialized padding for static CUDA graph shapes. + # Used later in combine_preprocess/unpermute_tokens to read only the + # valid grouped GEMM outputs and skip garbage-padded rows. + # Stored as a 1-element GPU tensor view ([-1:], not [-1]) to avoid a + # device-to-host sync that would break CUDA graphability. + self._num_routed_slots = inclusive_expert_offsets[-1:] + self._permuted_probs = permuted_probs + # Cache the full inclusive offsets so torch._grouped_mm can reuse them + # directly as `offs` without recomputing cumsum on tokens_per_expert. + self.inclusive_expert_offsets = inclusive_expert_offsets + + # probs=None: expert skips prob-weighting, unpermute applies probs instead + return permuted_hidden, tokens_per_expert, None + + def combine_preprocess(self, expert_output): + """Unpermute expert outputs back to original token order. + + For flashinfer, this is a pass-through (FlashInfer already produces + output in token order). For the torch backend, uses the triton unpermute + kernel to scatter-add expert outputs weighted by routing probs. + + Args: + expert_output (torch.Tensor): Output from InferenceGroupedMLP, + shape [output_size, hidden_dim] (torch) or + [global_tokens, hidden_dim] (flashinfer). + + Returns: + torch.Tensor: Output in original token order, + shape [global_tokens, hidden_dim]. + """ + if self._moe_backend == MoEGroupedGemmBackend.flashinfer: + return expert_output + + # torch backend: unpermute with routing probs + output = unpermute_tokens( + expert_output, + self._permuted_probs, + self._source_token_indices, + self._num_tokens_after_dispatch, + self._num_routed_slots, + ) + return output + + def token_combine(self, hidden_states): + """Combines expert outputs across EP ranks using Reduce-Scatter. + + Reduces the global expert output (summing contributions from each rank) + and scatters the result so each rank receives its local token slice. 
+ Uses latency-optimized NVLS multimem_reduce_scatter on Hopper+ GPUs + with BF16 when symmetric memory is available. Falls back to NCCL otherwise. + + Args: + hidden_states (torch.Tensor): Combined expert output after routing + weights have been applied, shape [global_tokens, hidden_dim]. + + Returns: + torch.Tensor: Local slice of the reduced output, + shape [local_tokens, hidden_dim] where + local_tokens = global_tokens // ep_size. + """ + if self.ep_size == 1: + return hidden_states + + # Compute output shape first — check NVLS eligibility on the output, + # since if the smaller output is 16-byte divisible, the input is too. + output_shape = list(hidden_states.size()) + output_shape[0] = hidden_states.size(0) // self.ep_size + output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device) + + # Check output only: if output is 16-byte divisible, input (world_size * output) is too. + nvls_eligible = self.triton_nvls_kernels_allowed and are_tensors_nvls_eligible(output) + rs_buffer = None + + if nvls_eligible: + rs_buffer = self._maybe_allocate_rs_buffer(hidden_states) + + can_use_nvls = nvls_eligible and rs_buffer["handle"] is not None + + if can_use_nvls: + # Copy input to symmetric memory for reduce-scatter + rs_buffer["tensor"].copy_(hidden_states) + + # Use latency-optimized NVLS reduce-scatter + multimem_reduce_scatter(output, rs_buffer["tensor"], rs_buffer["handle"]) + return output + else: + # Fallback to NCCL + hidden_states = reduce_scatter_to_sequence_parallel_region( + hidden_states, group=self.tp_ep_group + ) + return hidden_states + diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 3d82a6e39c6..95d9af30a01 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -11,7 +11,7 @@ from megatron.core.enums import Fp4Recipe, Fp8Recipe from megatron.core.quantization.quant_config import RecipeConfig 
-from megatron.core.transformer.enums import AttnBackend, CudaGraphScope +from megatron.core.transformer.enums import AttnBackend, CudaGraphScope, MoEGroupedGemmBackend from megatron.core.transformer.pipeline_parallel_layer_layout import PipelineParallelLayerLayout from .._rank_utils import log_single_rank @@ -914,6 +914,22 @@ class TransformerConfig(ModelParallelConfig): inference_fuse_tp_communication: bool = False """ If true, uses a fused reduce-scatter-residual-norm-allgather kernel during inference. """ + inference_disable_triton_nvls_kernels: bool = False + """ If true, disables the use of Triton NVLS kernels during inference. """ + + moe_ggemm_training: str = "te" + """Backend for training grouped GEMM. Only 'te' is supported.""" + + moe_ggemm_inference_cg: str = "flashinfer" + """Backend for CUDA-graphed inference grouped GEMM. + 'flashinfer': FlashInfer's fused cutlass_fused_moe kernel (default). + 'torch': triton permute/unpermute kernels + torch._grouped_mm.""" + + moe_ggemm_inference_no_cg: str = "torch" + """Backend for non-CUDA-graphed inference grouped GEMM. + 'torch': torch._grouped_mm with GPU-resident cumsum offsets (default). + 'te': TE GroupedGEMM fallback.""" + mrope_section: Optional[List[int]] = None """ Multimodal rope section is for channel dimension of temporal, height and width in rope calculation. """ @@ -1140,6 +1156,37 @@ def __post_init__(self): if self.expert_model_parallel_size > 1 and self.num_moe_experts is None: raise ValueError("num_moe_experts must be non None to use expert-parallel.") + if self.transformer_impl == "inference_optimized" and self.num_moe_experts is not None: + if self.expert_tensor_parallel_size > 1: + raise ValueError( + "Inference-optimized MoE layers does not support expert tensor parallelism." 
+ ) + if self.moe_expert_capacity_factor is not None: + raise ValueError("Inference-optimized MoE layers only support dropless MoE ") + if self.moe_router_padding_for_quantization: + raise ValueError( + "Inference-optimized MoE layers do not support padded " + "routing map for quantization." + ) + if self.moe_router_dtype != "fp32" and self.moe_ggemm_inference_cg == "flashinfer": + raise ValueError( + "The FlashInfer CUDA-graphed MoE backend requires " + "--moe-router-dtype=fp32. Either set --moe-router-dtype=fp32 " + "or switch to --moe-ggemm-inference-cg=torch." + ) + if ( + self.gated_linear_unit + and self.cuda_graph_impl != "none" + and self.moe_ggemm_inference_cg == "flashinfer" + ): + raise ValueError( + "The FlashInfer CUDA-graphed MoE backend does not support gated " + "linear units (SwiGLU/GeGLU) due to differences in weight layouts " + "between the FlashInfer kernel and mcore. Either switch to the torch " + "backend (--moe-ggemm-inference-cg=torch), disable CUDA graphs " + "(--cuda-graph-impl=none), or use a non-gated activation (e.g. squared_relu)." + ) + if self.num_moe_experts is not None and self.num_moe_experts <= 0: raise ValueError("num_moe_experts must be non-negative.") @@ -2143,6 +2190,30 @@ def __post_init__(self): "inference_fuse_tp_communication is only supported " "for inference_optimized transformer implementation." ) + assert ( + self.num_moe_experts is None + ), "--inference-fuse-tp-communication is not supported for MoE models." + + if self.inference_disable_triton_nvls_kernels: + assert self.transformer_impl == "inference_optimized", ( + "inference_disable_triton_nvls_kernels is only supported " + "for inference_optimized transformer implementation." 
+ ) + + # Convert moe_ggemm_* strings to enums + _ggemm_valid = {b.name for b in MoEGroupedGemmBackend} + for field_name, allowed in [ + ("moe_ggemm_training", {"te"}), + ("moe_ggemm_inference_cg", {"flashinfer", "torch"}), + ("moe_ggemm_inference_no_cg", {"torch", "te"}), + ]: + val = getattr(self, field_name) + if isinstance(val, str): + if val not in allowed: + raise ValueError( + f"{field_name} must be one of {sorted(allowed)}, got '{val}'" + ) + object.__setattr__(self, field_name, MoEGroupedGemmBackend[val]) if self.batch_invariant_mode: assert ( diff --git a/megatron/core/utils.py b/megatron/core/utils.py index f176f3c2076..2094d60ae68 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -706,6 +706,44 @@ def _allocate(self, numel, dtype) -> torch.Tensor: required_bytes = numel * torch.tensor([], dtype=dtype).element_size() return self.symm_buffer[0:required_bytes].view(dtype).view(numel) + def maybe_get_tensors(self, tensor_specs, alignment=16): + """ + Pack multiple tensors contiguously in the symmetric buffer with alignment. + + Each tensor's starting offset is aligned to `alignment` bytes (default 16 + for 128-bit multimem access). + + Args: + tensor_specs: list of (numel, dtype) tuples. + alignment: byte alignment for each tensor's start offset (default 16). + + Returns: + {"handle": None, "tensors": None} if unavailable or insufficient space. + {"handle": symm_mem_hdl, "tensors": [(raw_byte_view, byte_offset), ...]} + on success, where raw_byte_view is a uint8 slice of the buffer. 
+ """ + _NONE_RESULT = {"handle": None, "tensors": None} + if self.symm_mem_hdl is None: + return _NONE_RESULT + + # Compute aligned byte sizes and running offsets + slices = [] + current_offset = 0 + for numel, dtype in tensor_specs: + nbytes = numel * torch.tensor([], dtype=dtype).element_size() + aligned_nbytes = ((nbytes + alignment - 1) // alignment) * alignment + slices.append((current_offset, nbytes)) + current_offset += aligned_nbytes + + if not self._can_allocate(current_offset, torch.uint8): + return _NONE_RESULT + + tensors = [] + for offset, nbytes in slices: + tensors.append((self.symm_buffer[offset : offset + nbytes], offset)) + + return {"handle": self.symm_mem_hdl, "tensors": tensors} + def maybe_get_tensor(self, tensor_shape, dtype): """ Returns (potentially) a sub-tensor from the self.symm_buffer for the given shape. diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index b3ffbc59b51..43d589d0f20 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1404,7 +1404,7 @@ def validate_args(args, defaults={}): assert args.inference_dynamic_batching_buffer_size_gb is not None assert args.inference_dynamic_batching_block_size % 256 == 0, "block size should be a multiple of 256" - if args.cuda_graph_impl == "local" and args.expert_model_parallel_size > 1: + if args.cuda_graph_impl == "local" and args.expert_model_parallel_size > 1 and args.transformer_impl != "inference_optimized": assert args.moe_pad_experts_for_cuda_graph_inference, \ "--moe-pad-experts-for-cuda-graph-inference must be set when using CUDA graphs with expert parallelism" diff --git a/tests/unit_tests/inference/test_batch_dimension_utils.py b/tests/unit_tests/inference/test_batch_dimension_utils.py index d67c390068a..6d0ed756fc6 100644 --- a/tests/unit_tests/inference/test_batch_dimension_utils.py +++ b/tests/unit_tests/inference/test_batch_dimension_utils.py @@ -31,7 +31,7 @@ def _generate_graphs(num_cuda_graphs, 
use_non_decode=True): tp_size=TP_SIZE, num_cuda_graphs=num_cuda_graphs, cuda_graph_max_tokens=MAX_REQUESTS, - cuda_graph_mixed_prefill_count=MIXED_PREFILL_COUNT, + cuda_graph_mixed_prefill_request_count=min(MIXED_PREFILL_COUNT, MAX_REQUESTS), max_requests=MAX_REQUESTS, max_tokens=MAX_TOKENS, max_sequence_length=MAX_SEQ_LEN, @@ -50,7 +50,7 @@ def _match( decode_only_cuda_graphs=decode_only, explicit_chunked_prefill=explicit_chunked_prefill, ep_group=ep_group, - cuda_graph_mixed_prefill_count=MIXED_PREFILL_COUNT, + smallest_non_decode_cuda_graph_size=MIXED_PREFILL_COUNT, ) diff --git a/tests/unit_tests/inference/test_moe_inference.py b/tests/unit_tests/inference/test_moe_inference.py new file mode 100644 index 00000000000..4d515db2d30 --- /dev/null +++ b/tests/unit_tests/inference/test_moe_inference.py @@ -0,0 +1,382 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Unit tests for inference-optimized MoE components. + +Config is modeled after nanov3 (Nemotron-6 3B Hybrid MoE) with smaller +dimensions for fast unit test execution: +- squared_relu activation (not swiglu/gated) +- sigmoid router score function with expert bias +- topk=6, topk_scaling_factor=2.5 +- shared experts +""" + +import pytest +import torch + +from megatron.core.activations import squared_relu +from megatron.core.inference.communication.torch_symm_triton import are_tensors_nvls_eligible +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import is_te_min_version, is_torch_min_version +from megatron.training.initialize import _set_random_seed +from tests.unit_tests.test_utilities import Utils + +# Reusable skip decorators +requires_te = pytest.mark.skipif( + not is_te_min_version("1.7.0.dev0"), reason="Requires transformer-engine >= 1.7.0" +) +requires_torch_grouped_mm = pytest.mark.skipif( + not is_torch_min_version("2.10") or not hasattr(torch, '_grouped_mm'), + reason="Requires PyTorch >= 2.10 with torch._grouped_mm", +) + 
+# ────────────────────────────────────────────────────────────────────── +# NanoV3-like config (scaled down from 2688→128 hidden, 128→8 experts) +# ────────────────────────────────────────────────────────────────────── + +NANOV3_BASE = dict( + num_layers=1, + hidden_size=128, + ffn_hidden_size=128, + num_attention_heads=4, + num_query_groups=2, + num_moe_experts=8, + moe_ffn_hidden_size=128, + moe_router_topk=6, + moe_router_score_function="sigmoid", + moe_router_enable_expert_bias=True, + moe_router_topk_scaling_factor=2.5, + moe_shared_expert_intermediate_size=256, + moe_router_dtype='fp32', + moe_shared_expert_overlap=False, + moe_grouped_gemm=True, + moe_token_dispatcher_type="alltoall", + moe_aux_loss_coeff=0.01, + activation_func=squared_relu, + normalization="RMSNorm", + add_bias_linear=False, + bf16=True, + params_dtype=torch.bfloat16, + transformer_impl="inference_optimized", +) + + +def _make_base_config(**overrides): + """Create a TransformerConfig with nanov3-like defaults.""" + params = {**NANOV3_BASE, **overrides} + return TransformerConfig(**params) + + +# ────────────────────────────────────────────────────────────────────── +# InferenceTopKRouter +# ────────────────────────────────────────────────────────────────────── + + +@pytest.mark.internal +class TestInferenceTopKRouter: + + @classmethod + def setup_class(cls): + Utils.initialize_model_parallel(1, 1) + _set_random_seed(seed_=123, data_parallel_random_init=False) + + @classmethod + def teardown_class(cls): + Utils.destroy_model_parallel() + + def _make_router(self, **config_overrides): + from megatron.core.transformer.moe.moe_utils import get_default_pg_collection + from megatron.core.transformer.moe.router import InferenceTopKRouter + + config = _make_base_config(**config_overrides) + return ( + InferenceTopKRouter(config=config, pg_collection=get_default_pg_collection()) + .cuda() + .to(torch.bfloat16) + ) + + def test_init_rejects_num_groups(self): + """InferenceTopKRouter requires 
moe_router_num_groups=None.""" + with pytest.raises(AssertionError, match="moe_router_num_groups"): + self._make_router(moe_router_num_groups=2) + + def test_config_rejects_non_fp32_router_dtype(self): + """inference_optimized config requires moe_router_dtype='fp32'.""" + with pytest.raises(ValueError, match="moe-router-dtype"): + _make_base_config( + transformer_impl="inference_optimized", add_qkv_bias=False, moe_router_dtype=None + ) + + @pytest.mark.parametrize("score_fn", ["none", "invalid"]) + def test_init_rejects_unsupported_score_function(self, score_fn): + """InferenceTopKRouter requires sigmoid or softmax score function.""" + with pytest.raises(AssertionError, match="moe_router_score_function"): + self._make_router( + moe_router_score_function=score_fn, moe_router_enable_expert_bias=False + ) + + @pytest.mark.parametrize("score_fn", ["sigmoid", "softmax"]) + def test_init_accepts_valid_score_function(self, score_fn): + """InferenceTopKRouter accepts sigmoid and softmax.""" + # Expert bias only valid with sigmoid; disable it for softmax + enable_bias = score_fn == "sigmoid" + router = self._make_router( + moe_router_score_function=score_fn, moe_router_enable_expert_bias=enable_bias + ) + assert router is not None + + def test_set_unset_inference_mode(self): + """Toggle is_inference_cuda_graphed_iteration flag.""" + router = self._make_router() + assert not router.is_inference_cuda_graphed_iteration + + router.set_inference_cuda_graphed_iteration() + assert router.is_inference_cuda_graphed_iteration + + router.unset_inference_cuda_graphed_iteration() + assert not router.is_inference_cuda_graphed_iteration + + def test_training_mode_forward_returns_sparse(self): + """In training mode, forward delegates to parent and returns sparse tensors.""" + router = self._make_router() + router.train() + num_tokens = 16 + num_experts = NANOV3_BASE["num_moe_experts"] + + input_tensor = torch.randn( + num_tokens, NANOV3_BASE["hidden_size"], device="cuda", 
dtype=torch.bfloat16 + ) + probs, routing_map = router(input_tensor) + + # Parent TopKRouter returns [num_tokens, num_experts] sparse routing_map + assert routing_map.shape == (num_tokens, num_experts) + + def test_inference_vs_training_selects_same_experts(self): + """Inference and training modes should select the same top-k experts.""" + router = self._make_router() + num_tokens = 16 + topk = NANOV3_BASE["moe_router_topk"] + + input_tensor = torch.randn( + num_tokens, NANOV3_BASE["hidden_size"], device="cuda", dtype=torch.bfloat16 + ) + + # Training mode: get routing_map (sparse) and extract top expert indices + router.train() + _, routing_map = router(input_tensor.clone()) + # routing_map is [num_tokens, num_experts] bool + training_experts = set() + for i in range(num_tokens): + experts_for_token = routing_map[i].nonzero(as_tuple=True)[0] + for e in experts_for_token: + training_experts.add((i, e.item())) + + # Inference mode: get top_indices (dense) + router.eval() + router.set_inference_cuda_graphed_iteration() + _, top_indices = router(input_tensor.clone()) + + inference_experts = set() + for i in range(num_tokens): + for k in range(topk): + inference_experts.add((i, top_indices[i, k].item())) + + # Same expert selections + assert training_experts == inference_experts + + +# ────────────────────────────────────────────────────────────────────── +# InferenceCUDAGraphTokenDispatcher +# ────────────────────────────────────────────────────────────────────── + + +@pytest.mark.internal +class TestInferenceCUDAGraphTokenDispatcher: + + @classmethod + def setup_class(cls): + from megatron.core.parallel_state import _set_global_symmetric_memory_buffer + + Utils.initialize_model_parallel(1, 1, expert_model_parallel_size=Utils.world_size) + _set_random_seed(seed_=123, data_parallel_random_init=False) + _set_global_symmetric_memory_buffer() + + @classmethod + def teardown_class(cls): + from megatron.core.parallel_state import destroy_global_symmetric_memory_buffer + + 
destroy_global_symmetric_memory_buffer() + Utils.destroy_model_parallel() + + def _make_dispatcher(self, **config_overrides): + from megatron.core.transformer.moe.moe_utils import get_default_pg_collection + from megatron.core.transformer.moe.token_dispatcher_inference import ( + InferenceCUDAGraphTokenDispatcher, + ) + + config_overrides.setdefault("expert_model_parallel_size", Utils.world_size) + config = _make_base_config(**config_overrides) + num_local_experts = config.num_moe_experts // Utils.world_size + ep_rank = torch.distributed.get_rank() if Utils.world_size > 1 else 0 + local_expert_indices = [ep_rank * num_local_experts + i for i in range(num_local_experts)] + + return InferenceCUDAGraphTokenDispatcher( + num_local_experts=num_local_experts, + local_expert_indices=local_expert_indices, + config=config, + pg_collection=get_default_pg_collection(), + ) + + def test_init(self): + """Dispatcher can be constructed with nanov3-like config and EP=world_size.""" + dispatcher = self._make_dispatcher() + assert dispatcher.topk == NANOV3_BASE["moe_router_topk"] + assert dispatcher.ep_size == Utils.world_size + + def test_symmetric_memory_buffer_initialized(self): + """EP symmetric memory buffer is accessible after _set_global_symmetric_memory_buffer.""" + from megatron.core.parallel_state import get_global_symmetric_memory_buffer_ep + + buf = get_global_symmetric_memory_buffer_ep() + assert buf is not None + + @pytest.mark.parametrize("seed", [42, 123, 7]) + @pytest.mark.parametrize( + "num_local_tokens", + [ + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 40, + 48, + 56, + 64, + 72, + 80, + 88, + 96, + 104, + 112, + 120, + 128, + 136, + 144, + 152, + 160, + 168, + 176, + 184, + 192, + 200, + 208, + 216, + 224, + 232, + 240, + 248, + 256, + 272, + 288, + 304, + 320, + 336, + 352, + 368, + 384, + 400, + 416, + 432, + 448, + 464, + 480, + 496, + 512, + ], + ) + def test_cuda_graph_dispatch_combine(self, num_local_tokens, seed): + """Dispatch+combine can be captured in a 
CUDA graph and replayed.
+        Creates global buffers, shards per rank, and verifies:
+        - NVLS AllGather output matches the full global buffer
+        - NVLS ReduceScatter output matches fp32-accumulated reference
+        All tensor byte sizes are 128-bit aligned for NVLS eligibility.
+        """
+
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
+
+        dispatcher = self._make_dispatcher()
+        ep_size = dispatcher.ep_size
+        hidden_size = NANOV3_BASE["hidden_size"]
+        topk = NANOV3_BASE["moe_router_topk"]
+        num_experts = NANOV3_BASE["num_moe_experts"]
+        rank = torch.distributed.get_rank() if ep_size > 1 else 0
+        num_global_tokens = num_local_tokens * ep_size
+
+        # Create global buffers on rank 0 and broadcast to all ranks
+        global_hidden = torch.randn(
+            num_global_tokens, hidden_size, device="cuda", dtype=torch.bfloat16
+        )
+        global_probs = torch.randn(num_global_tokens, topk, device="cuda", dtype=torch.float32)
+        global_routing_map = torch.randint(0, num_experts, (num_global_tokens, topk), device="cuda")
+        torch.distributed.broadcast(global_hidden, src=0)
+        torch.distributed.broadcast(global_probs, src=0)
+        torch.distributed.broadcast(global_routing_map, src=0)
+
+        # Each rank grabs their own shard
+        start = rank * num_local_tokens
+        end = start + num_local_tokens
+        static_hidden = global_hidden[start:end].contiguous()
+        static_probs = global_probs[start:end].contiguous()
+        static_routing_map = global_routing_map[start:end].contiguous()
+
+        if not are_tensors_nvls_eligible(static_hidden, static_probs, static_routing_map):
+            pytest.skip(
+                "Tensors are not NVLS-eligible (need SM>=9 and each tensor's memory to be a multiple of 16 bytes)"
+            )
+
+        # 3 warmup iterations on a side stream
+        with torch.no_grad():
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                for _ in range(3):
+                    dispatcher.routing_map = static_routing_map
+                    d_hidden, d_probs = dispatcher.token_dispatch(static_hidden, static_probs)
+                    d_hidden = d_hidden.clone()
+                    
d_probs = d_probs.clone() + dispatcher.routing_map = dispatcher.routing_map.clone() + dispatcher.token_combine(d_hidden.clone()) + torch.cuda.current_stream().wait_stream(s) + + # Capture + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + dispatcher.routing_map = static_routing_map + d_hidden, d_probs = dispatcher.token_dispatch(static_hidden, static_probs) + graph_hidden = d_hidden.clone() + graph_probs = d_probs.clone() + graph_routing_map = dispatcher.routing_map.clone() + graph_combined = dispatcher.token_combine(d_hidden.clone()) + + # Verify shapes: dispatch expands by ep_size, combine shrinks back + assert graph_hidden.shape == (num_global_tokens, hidden_size) + assert graph_probs.shape == (num_global_tokens, topk) + assert graph_combined.shape == (num_local_tokens, hidden_size) + + # Replay + graph.replay() + + # Verify AllGather: all gathered tensors should match global buffers + torch.testing.assert_close(graph_hidden, global_hidden, atol=0, rtol=0) + torch.testing.assert_close(graph_probs, global_probs, atol=0, rtol=0) + torch.testing.assert_close(graph_routing_map, global_routing_map, atol=0, rtol=0) + + # Verify ReduceScatter: all ranks have the same all-gathered data, so + # rank r gets ep_size * chunk_r. Compute reference in fp32 then downcast. + # Exact match (atol=0, rtol=0) is possible because the NVLS triton kernels + # accumulate in fp32 before writing bf16 output. 
+ expected_combined = (global_hidden[start:end].float() * ep_size).bfloat16() + torch.testing.assert_close(graph_combined, expected_combined, atol=0, rtol=0) diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py index 93d66be0669..98c6ac63e0e 100644 --- a/tests/unit_tests/models/test_mamba_moe_model.py +++ b/tests/unit_tests/models/test_mamba_moe_model.py @@ -282,6 +282,8 @@ "offload_modules": [], "hybrid_context_parallel": False, "max_seqlen_per_dp_cp_rank": None, + "inference_disable_torch_grouped_mm": False, + "inference_disable_triton_nvls_kernels": False, } # Fields to ignore entirely (ephemeral, environment-specific, very large). SKIP_FIELDS = set() diff --git a/tests/unit_tests/transformer/moe/test_moe_inference_utils.py b/tests/unit_tests/transformer/moe/test_moe_inference_utils.py new file mode 100644 index 00000000000..02304b65018 --- /dev/null +++ b/tests/unit_tests/transformer/moe/test_moe_inference_utils.py @@ -0,0 +1,613 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +""" +Unit tests for megatron.core.transformer.moe.moe_inference_utils. + +Tests the compute_local_tokens_per_expert triton kernel against a PyTorch +reference implementation across various routing configurations. 
"""Unit tests for the MoE inference permute/unpermute Triton kernels."""

import pytest
import torch

from megatron.core.transformer.moe.moe_inference_utils import (
    compute_expert_offsets,
    compute_local_tokens_per_expert,
    permute_tokens,
    unpermute_tokens,
)


def _make_routing_map(num_tokens, topk, num_total_experts, device="cuda"):
    """Create a routing map with topk distinct experts per token (matching real router)."""
    # Vectorized equivalent of a per-row torch.randperm: argsort of i.i.d. uniform
    # noise yields an independent random permutation per row; keep the first topk
    # columns. Avoids a Python loop of kernel launches for the 2048-token sweeps.
    scores = torch.rand(num_tokens, num_total_experts, device=device)
    return torch.argsort(scores, dim=1)[:, :topk].to(torch.int32)


def _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts):
    """PyTorch reference: count (token, k) pairs routed to each local expert."""
    flat = routing_map.flatten()
    local_mask = (flat >= local_expert_start) & (flat < local_expert_start + num_local_experts)
    local_ids = flat[local_mask] - local_expert_start
    ref = torch.zeros(num_local_experts, dtype=torch.int32, device=routing_map.device)
    ref.scatter_add_(0, local_ids.long(), torch.ones_like(local_ids, dtype=torch.int32))
    return ref


class TestComputeLocalTokensPerExpert:
    """Tests for the _count_local_tokens_kernel triton kernel."""

    @pytest.mark.internal
    def test_random_routing(self):
        """Random routing_map should match PyTorch reference."""
        torch.manual_seed(42)
        num_tokens, topk = 128, 8
        num_total_experts, num_local_experts = 64, 8
        local_expert_start = 16

        routing_map = _make_routing_map(num_tokens, topk, num_total_experts)

        result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts)
        ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts)
        assert torch.equal(result, ref)

    @pytest.mark.internal
    def test_no_local_experts(self):
        """All tokens routed to non-local experts should give all zeros."""
        num_tokens, topk = 64, 4
        num_local_experts = 8
        local_expert_start = 16

        # Expert 0 is not in [16..23]
        routing_map = torch.zeros(num_tokens, topk, dtype=torch.int32, device="cuda")

        result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts)
        expected = torch.zeros(num_local_experts, dtype=torch.int32, device="cuda")
        assert torch.equal(result, expected)

    @pytest.mark.internal
    def test_all_to_single_expert(self):
        """All tokens routed to one local expert."""
        num_tokens, topk = 64, 4
        num_local_experts = 8
        local_expert_start = 16
        target_local_idx = 3

        routing_map = torch.full(
            (num_tokens, topk),
            local_expert_start + target_local_idx,
            dtype=torch.int32,
            device="cuda",
        )

        result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts)
        expected = torch.zeros(num_local_experts, dtype=torch.int32, device="cuda")
        expected[target_local_idx] = num_tokens * topk
        assert torch.equal(result, expected)

    @pytest.mark.internal
    def test_single_token(self):
        """Minimal case: 1 token, topk=1."""
        num_local_experts = 8
        local_expert_start = 16

        routing_map = torch.tensor(
            [[local_expert_start]], dtype=torch.int32, device="cuda"
        )

        result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts)
        expected = torch.zeros(num_local_experts, dtype=torch.int32, device="cuda")
        expected[0] = 1
        assert torch.equal(result, expected)

    @pytest.mark.internal
    def test_uniform_distribution(self):
        """Each token routes to all local experts exactly once (topk == num_local_experts)."""
        num_tokens = 32
        num_local_experts = 8
        local_expert_start = 0

        # Each row is [0, 1, 2, ..., 7] — one hit per local expert per token
        routing_map = (
            torch.arange(num_local_experts, device="cuda", dtype=torch.int32)
            .unsqueeze(0)
            .expand(num_tokens, -1)
            .contiguous()
        )

        result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts)
        expected = torch.full(
            (num_local_experts,), num_tokens, dtype=torch.int32, device="cuda"
        )
        assert torch.equal(result, expected)

    @pytest.mark.internal
    def test_local_expert_start_at_zero(self):
        """Local experts starting at global index 0."""
        torch.manual_seed(123)
        num_tokens, topk = 256, 4
        num_total_experts = 32
        num_local_experts = 4
        local_expert_start = 0

        routing_map = _make_routing_map(num_tokens, topk, num_total_experts)

        result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts)
        ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts)
        assert torch.equal(result, ref)

    @pytest.mark.internal
    def test_local_experts_at_end(self):
        """Local experts at the tail end of the global expert range."""
        torch.manual_seed(456)
        num_tokens, topk = 256, 4
        num_total_experts = 32
        num_local_experts = 4
        local_expert_start = 28  # experts 28..31

        routing_map = _make_routing_map(num_tokens, topk, num_total_experts)

        result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts)
        ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts)
        assert torch.equal(result, ref)

    @pytest.mark.internal
    def test_large_batch(self):
        """Larger batch to exercise multi-block histogram kernel."""
        torch.manual_seed(789)
        num_tokens, topk = 2048, 8
        num_total_experts = 128
        num_local_experts = 16
        local_expert_start = 48

        routing_map = _make_routing_map(num_tokens, topk, num_total_experts)

        result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts)
        ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts)
        assert torch.equal(result, ref)

    @pytest.mark.internal
    def test_topk_one(self):
        """topk=1: each token assigned to exactly one expert."""
        torch.manual_seed(101)
        num_tokens = 512
        num_total_experts = 64
        num_local_experts = 8
        local_expert_start = 8

        routing_map = torch.randint(
            0, num_total_experts, (num_tokens, 1), dtype=torch.int32, device="cuda"
        )

        result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts)
        ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts)
        assert torch.equal(result, ref)

    @pytest.mark.internal
    def test_non_power_of_two_tokens(self):
        """Non-power-of-2 num_tokens to check masking at block boundaries."""
        torch.manual_seed(202)
        num_tokens, topk = 137, 5
        num_total_experts = 32
        num_local_experts = 4
        local_expert_start = 12

        routing_map = _make_routing_map(num_tokens, topk, num_total_experts)

        result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts)
        ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts)
        assert torch.equal(result, ref)


class TestComputeExpertOffsets:
    """Tests for the _prefix_sum_kernel triton kernel."""

    @pytest.mark.internal
    def test_basic_prefix_sum(self):
        """Simple known input: [3, 1, 4, 1, 5] -> offsets [0, 3, 4, 8, 9]."""
        tokens_per_expert = torch.tensor([3, 1, 4, 1, 5], dtype=torch.int32, device="cuda")
        offsets, counters = compute_expert_offsets(tokens_per_expert)

        expected = torch.tensor([0, 3, 4, 8, 9], dtype=torch.int32, device="cuda")
        assert torch.equal(offsets, expected)
        assert torch.equal(counters, expected)

    @pytest.mark.internal
    def test_single_expert(self):
        """Single expert: offset is always 0."""
        tokens_per_expert = torch.tensor([42], dtype=torch.int32, device="cuda")
        offsets, counters = compute_expert_offsets(tokens_per_expert)

        expected = torch.tensor([0], dtype=torch.int32, device="cuda")
        assert torch.equal(offsets, expected)
        assert torch.equal(counters, expected)

    @pytest.mark.internal
    def test_all_zeros(self):
        """All experts have zero tokens: offsets are all 0."""
        tokens_per_expert = torch.zeros(8, dtype=torch.int32, device="cuda")
        offsets, counters = compute_expert_offsets(tokens_per_expert)

        expected = torch.zeros(8, dtype=torch.int32, device="cuda")
        assert torch.equal(offsets, expected)
        # counters start out as a copy of offsets; check both like the other tests.
        assert torch.equal(counters, expected)

    @pytest.mark.internal
    def test_one_hot(self):
        """Only one expert has tokens."""
        tokens_per_expert = torch.tensor([0, 0, 10, 0, 0], dtype=torch.int32, device="cuda")
        offsets, counters = compute_expert_offsets(tokens_per_expert)

        expected = torch.tensor([0, 0, 0, 10, 10], dtype=torch.int32, device="cuda")
        assert torch.equal(offsets, expected)
        # counters start out as a copy of offsets; check both like the other tests.
        assert torch.equal(counters, expected)

    @pytest.mark.internal
    def test_counters_are_independent_copy(self):
        """Mutating counters should not affect offsets."""
        tokens_per_expert = torch.tensor([2, 3, 5], dtype=torch.int32, device="cuda")
        offsets, counters = compute_expert_offsets(tokens_per_expert)

        # Simulate what the permute kernel does
        counters[0] += 1
        counters[1] += 2

        expected_offsets = torch.tensor([0, 2, 5], dtype=torch.int32, device="cuda")
        assert torch.equal(offsets, expected_offsets), "offsets should be unmodified"

    @pytest.mark.internal
    def test_end_to_end_with_histogram(self):
        """Chain histogram -> prefix sum and verify consistency."""
        torch.manual_seed(99)
        num_tokens, topk = 128, 8
        num_total_experts = 64
        num_local_experts = 8
        local_expert_start = 16

        routing_map = _make_routing_map(num_tokens, topk, num_total_experts)

        tokens_per_expert = compute_local_tokens_per_expert(
            routing_map, local_expert_start, num_local_experts
        )
        offsets, _ = compute_expert_offsets(tokens_per_expert)

        # Verify: offsets[i] == sum(tokens_per_expert[0:i])
        ref_offsets = torch.zeros_like(tokens_per_expert)
        ref_offsets[1:] = torch.cumsum(tokens_per_expert[:-1], dim=0)
        assert torch.equal(offsets, ref_offsets)

        # Verify: last offset + last count == total
        assert offsets[-1] + tokens_per_expert[-1] == tokens_per_expert.sum()


class TestComputeLocalTokensPerExpertSweep:
    """Exhaustive parametrized sweep for _count_local_tokens_kernel."""

    @pytest.mark.internal
    @pytest.mark.parametrize("num_tokens", [1, 7, 32, 137, 256, 512, 2048])
    @pytest.mark.parametrize("topk", [1, 2, 3, 4, 6, 8])
    @pytest.mark.parametrize("num_local_experts", [1, 3, 4, 8, 16])
    @pytest.mark.parametrize("num_total_experts", [8, 32, 64, 128])
    def test_sweep(self, num_tokens, topk, num_local_experts, num_total_experts):
        """Sweep across token counts, topk, and expert configurations."""
        if num_local_experts > num_total_experts:
            pytest.skip("num_local_experts > num_total_experts")
        if topk > num_total_experts:
            pytest.skip("topk > num_total_experts")

        torch.manual_seed(num_tokens * 1000 + topk * 100 + num_local_experts)
        local_expert_start = (num_total_experts - num_local_experts) // 2

        routing_map = _make_routing_map(num_tokens, topk, num_total_experts)

        result = compute_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts)
        ref = _reference_local_tokens_per_expert(routing_map, local_expert_start, num_local_experts)
        assert torch.equal(result, ref), (
            f"Mismatch for num_tokens={num_tokens}, topk={topk}, "
            f"num_local_experts={num_local_experts}, num_total_experts={num_total_experts}"
        )


class TestComputeExpertOffsetsSweep:
    """Exhaustive parametrized sweep for _prefix_sum_kernel."""

    @pytest.mark.internal
    @pytest.mark.parametrize("num_local_experts", [1, 2, 3, 4, 5, 7, 8, 13, 16, 32, 64])
    @pytest.mark.parametrize("max_count", [0, 1, 10, 100, 1000])
    def test_sweep(self, num_local_experts, max_count):
        """Sweep across expert counts and token magnitudes."""
        torch.manual_seed(num_local_experts * 100 + max_count)

        if max_count == 0:
            tokens_per_expert = torch.zeros(num_local_experts, dtype=torch.int32, device="cuda")
        else:
            tokens_per_expert = torch.randint(
                0, max_count + 1, (num_local_experts,), dtype=torch.int32, device="cuda"
            )

        offsets, counters = compute_expert_offsets(tokens_per_expert)

        # Reference: exclusive prefix sum via torch.cumsum
        ref_offsets = torch.zeros_like(tokens_per_expert)
        if num_local_experts > 1:
            ref_offsets[1:] = torch.cumsum(tokens_per_expert[:-1], dim=0)

        assert torch.equal(offsets, ref_offsets), (
            f"Mismatch for num_local_experts={num_local_experts}, max_count={max_count}\n"
            f"  tokens_per_expert={tokens_per_expert.tolist()}\n"
            f"  expected={ref_offsets.tolist()}\n"
            f"  got={offsets.tolist()}"
        )
        assert torch.equal(counters, ref_offsets), "counters should equal offsets initially"


class TestPermuteTokens:
    """Tests for the _permute_tokens_kernel triton kernel."""

    def _run_permute(
        self, num_tokens, topk, hidden_dim, num_total_experts, num_local_experts, local_expert_start,
        seed=42,
    ):
        """Run the full histogram -> prefix_sum -> permute pipeline and verify correctness.

        Checks:
        1. Each permuted row's hidden state matches hidden_states[source_token_indices[i]]
        2. Each expert group contains exactly the right set of source token indices
        3. Total routed count matches tokens_per_expert.sum()
        """
        if num_local_experts > num_total_experts:
            pytest.skip("num_local_experts > num_total_experts")
        if num_total_experts % num_local_experts != 0:
            pytest.skip("num_total_experts not divisible by num_local_experts")
        if local_expert_start + num_local_experts > num_total_experts:
            pytest.skip("local expert range exceeds num_total_experts")
        if topk > num_total_experts:
            pytest.skip("topk > num_total_experts")

        torch.manual_seed(seed)
        routing_map = _make_routing_map(num_tokens, topk, num_total_experts)
        hidden_states = torch.randn(num_tokens, hidden_dim, dtype=torch.float32, device="cuda")
        probs = torch.randn(num_tokens, topk, dtype=torch.float32, device="cuda")

        # Triton pipeline
        tokens_per_expert = compute_local_tokens_per_expert(
            routing_map, local_expert_start, num_local_experts
        )
        _, atomic_counters = compute_expert_offsets(tokens_per_expert)
        permuted_hidden, permuted_probs, source_token_indices = permute_tokens(
            hidden_states, probs, routing_map, atomic_counters,
            local_expert_start, num_local_experts,
        )

        total_routed = tokens_per_expert.sum().item()

        # Check 1: every permuted row matches its source token
        for i in range(total_routed):
            src = source_token_indices[i].item()
            assert 0 <= src < num_tokens, (
                f"Row {i}: source_token_indices={src} out of bounds [0, {num_tokens})"
            )
            assert torch.equal(permuted_hidden[i], hidden_states[src]), (
                f"Row {i}: permuted_hidden doesn't match hidden_states[{src}]"
            )
            # We don't store k_idx in the output, so we can't look up the exact
            # prob. Instead, verify it's one of the topk probs for this token
            # (exact equality is expected — the kernel copies the value verbatim).
            assert (permuted_probs[i] == probs[src]).any(), (
                f"Row {i}: permuted_probs={permuted_probs[i].item()} not in probs[{src}]"
            )

        # Check 2: each expert group has the right set of source token indices
        offset = 0
        for e in range(num_local_experts):
            count = tokens_per_expert[e].item()
            expert_global_id = local_expert_start + e

            # Expected: all token indices that have this expert in their routing_map
            expected_tokens = set()
            for t in range(num_tokens):
                for k in range(topk):
                    if routing_map[t, k].item() == expert_global_id:
                        expected_tokens.add(t)

            # Actual: source token indices in this group
            actual_tokens = set(source_token_indices[offset:offset + count].tolist())
            assert actual_tokens == expected_tokens, (
                f"Expert {e}: token set mismatch.\n"
                f"  expected={sorted(expected_tokens)}\n"
                f"  actual={sorted(actual_tokens)}"
            )

            offset += count

    @pytest.mark.internal
    @pytest.mark.parametrize("topk", [1, 2, 3, 4, 6, 8])
    @pytest.mark.parametrize("hidden_dim", [64, 100, 128, 256, 1024])
    @pytest.mark.parametrize("num_total_experts", [8, 32, 64, 128])
    def test_basic(self, topk, hidden_dim, num_total_experts):
        """Basic test with moderate sizes, sweep topk, hidden_dim, num_total_experts."""
        self._run_permute(
            num_tokens=64, topk=topk, hidden_dim=hidden_dim,
            num_total_experts=num_total_experts, num_local_experts=4, local_expert_start=8,
        )

    @pytest.mark.internal
    @pytest.mark.parametrize("topk", [1, 2, 3, 4, 6, 8])
    @pytest.mark.parametrize("hidden_dim", [64, 100, 128, 256, 1024])
    @pytest.mark.parametrize("num_total_experts", [8, 32, 64, 128])
    def test_single_token(self, topk, hidden_dim, num_total_experts):
        """Single token, sweep topk, hidden_dim, num_total_experts."""
        self._run_permute(
            num_tokens=1, topk=topk, hidden_dim=hidden_dim,
            num_total_experts=num_total_experts, num_local_experts=8, local_expert_start=0,
        )

    @pytest.mark.internal
    def test_no_local_hits(self):
        """All tokens routed to non-local experts."""
        torch.manual_seed(0)
        num_tokens, topk, hidden_dim = 32, 4, 64
        num_local_experts = 4
        local_expert_start = 100  # way above any expert ID in [0, 16)

        routing_map = _make_routing_map(num_tokens, topk, 16)
        hidden_states = torch.randn(num_tokens, hidden_dim, device="cuda")
        probs = torch.randn(num_tokens, topk, device="cuda")

        tokens_per_expert = compute_local_tokens_per_expert(
            routing_map, local_expert_start, num_local_experts
        )
        _, atomic_counters = compute_expert_offsets(tokens_per_expert)
        # Smoke-check: the permute kernel must handle an all-remote routing map
        # without crashing; its outputs are unused because nothing is routed.
        permute_tokens(
            hidden_states, probs, routing_map, atomic_counters,
            local_expert_start, num_local_experts,
        )

        assert tokens_per_expert.sum().item() == 0

    @pytest.mark.internal
    @pytest.mark.parametrize("topk", [1, 2, 3, 4, 6, 8])
    @pytest.mark.parametrize("hidden_dim", [64, 100, 128, 256, 1024])
    @pytest.mark.parametrize("num_total_experts", [8, 32, 64, 128])
    def test_all_to_one_expert(self, topk, hidden_dim, num_total_experts):
        """All tokens routed to a single local expert — max atomic contention."""
        self._run_permute(
            num_tokens=32, topk=topk, hidden_dim=hidden_dim,
            num_total_experts=num_total_experts, num_local_experts=1, local_expert_start=0,
            seed=99,
        )

    @pytest.mark.internal
    @pytest.mark.parametrize("topk", [1, 2, 3, 4, 6, 8])
    @pytest.mark.parametrize("hidden_dim", [100, 200, 350])
    @pytest.mark.parametrize("num_total_experts", [8, 16, 32, 64, 128])
    def test_non_power_of_two_hidden_dim(self, topk, hidden_dim, num_total_experts):
        """Non-power-of-2 dims, sweep topk, hidden_dim, num_total_experts."""
        self._run_permute(
            num_tokens=32, topk=topk, hidden_dim=hidden_dim,
            num_total_experts=num_total_experts, num_local_experts=4, local_expert_start=4,
        )


class TestUnpermuteTokens:
    """Tests for the _unpermute_tokens_kernel triton kernel."""

    def _run_end_to_end(
        self, num_tokens, topk, hidden_dim, num_total_experts, num_local_experts,
        local_expert_start, seed=42,
    ):
        """Run the full permute -> unpermute pipeline and compare to PyTorch reference.

        The reference computes:
            output[t, :] = sum over local experts e assigned to token t:
                expert_output[permuted_row_for(t, e), :] * prob[t, k_for(e)]
        We simulate expert_output as random data (as if the grouped GEMM ran).
        """
        if num_local_experts > num_total_experts:
            pytest.skip("num_local_experts > num_total_experts")
        if num_total_experts % num_local_experts != 0:
            pytest.skip("num_total_experts not divisible by num_local_experts")
        if local_expert_start + num_local_experts > num_total_experts:
            pytest.skip("local expert range exceeds num_total_experts")
        if topk > num_total_experts:
            pytest.skip("topk > num_total_experts")

        torch.manual_seed(seed)
        routing_map = _make_routing_map(num_tokens, topk, num_total_experts)
        hidden_states = torch.randn(num_tokens, hidden_dim, dtype=torch.float32, device="cuda")
        probs = torch.randn(num_tokens, topk, dtype=torch.float32, device="cuda").abs()

        # Permute
        tokens_per_expert = compute_local_tokens_per_expert(
            routing_map, local_expert_start, num_local_experts
        )
        _, atomic_counters = compute_expert_offsets(tokens_per_expert)
        permuted_hidden, permuted_probs, source_token_indices = permute_tokens(
            hidden_states, probs, routing_map, atomic_counters,
            local_expert_start, num_local_experts,
        )

        total_routed = tokens_per_expert.sum().item()

        # Simulate expert output (random, as if grouped GEMM ran on permuted_hidden)
        expert_output = torch.randn(
            permuted_hidden.shape[0], hidden_dim, dtype=torch.float32, device="cuda"
        )

        # Triton unpermute — pass atomic_counters so kernel reads total_routed from GPU
        result = unpermute_tokens(
            expert_output, permuted_probs, source_token_indices, num_tokens,
            atomic_counters,
        )

        # PyTorch reference: scatter-add expert_output[i] * prob[i] to source token
        ref = torch.zeros(num_tokens, hidden_dim, dtype=torch.float32, device="cuda")
        for i in range(total_routed):
            src = source_token_indices[i].item()
            ref[src] += expert_output[i] * permuted_probs[i]

        # Tolerance is dtype-dependent (bf16 atomic adds would need ~1e-2);
        # expert_output here is float32, so the tight float32 branch applies.
        atol = 1e-2 if expert_output.dtype == torch.bfloat16 else 1e-5
        assert torch.allclose(result, ref, atol=atol), (
            f"Unpermute mismatch: max diff={torch.max(torch.abs(result - ref)).item()}"
        )

    @pytest.mark.internal
    @pytest.mark.parametrize("topk", [1, 2, 4, 6, 8])
    @pytest.mark.parametrize("hidden_dim", [64, 100, 128, 256])
    def test_basic(self, topk, hidden_dim):
        """Basic unpermute test, sweep topk and hidden_dim."""
        self._run_end_to_end(
            num_tokens=64, topk=topk, hidden_dim=hidden_dim,
            num_total_experts=32, num_local_experts=4, local_expert_start=8,
        )

    @pytest.mark.internal
    @pytest.mark.parametrize("topk", [1, 2, 4, 6, 8])
    @pytest.mark.parametrize("hidden_dim", [64, 100, 128, 256])
    def test_single_token(self, topk, hidden_dim):
        """Single token unpermute, sweep topk and hidden_dim."""
        self._run_end_to_end(
            num_tokens=1, topk=topk, hidden_dim=hidden_dim,
            num_total_experts=32, num_local_experts=8, local_expert_start=0,
        )

    @pytest.mark.internal
    @pytest.mark.parametrize("topk", [1, 2, 4, 6, 8])
    @pytest.mark.parametrize("hidden_dim", [64, 100, 128, 256])
    @pytest.mark.parametrize("num_total_experts", [8, 32, 64, 128])
    def test_all_to_one_expert(self, topk, hidden_dim, num_total_experts):
        """All tokens to one local expert — tests single-expert accumulation."""
        self._run_end_to_end(
            num_tokens=32, topk=topk, hidden_dim=hidden_dim,
            num_total_experts=num_total_experts, num_local_experts=1, local_expert_start=0,
            seed=99,
        )

    @pytest.mark.internal
    def test_no_local_hits(self):
        """No local tokens — unpermute should return all zeros."""
        num_tokens, hidden_dim = 32, 64

        # No permuted rows — atomic_counters with a single zero entry
        expert_output = torch.empty(0, hidden_dim, dtype=torch.float32, device="cuda")
        permuted_probs = torch.empty(0, dtype=torch.float32, device="cuda")
        source_token_indices = torch.empty(0, dtype=torch.int32, device="cuda")
        atomic_counters = torch.zeros(1, dtype=torch.int32, device="cuda")

        result = unpermute_tokens(
            expert_output, permuted_probs, source_token_indices, num_tokens, atomic_counters,
        )

        expected = torch.zeros(num_tokens, hidden_dim, dtype=torch.float32, device="cuda")
        assert torch.equal(result, expected)

    @pytest.mark.internal
    @pytest.mark.parametrize("topk", [1, 2, 4, 6, 8])
    @pytest.mark.parametrize("hidden_dim", [64, 100, 128, 256])
    @pytest.mark.parametrize("num_total_experts", [8, 32, 64, 128])
    def test_sweep(self, topk, hidden_dim, num_total_experts):
        """Sweep across topk, hidden_dim, num_total_experts."""
        self._run_end_to_end(
            num_tokens=64, topk=topk, hidden_dim=hidden_dim,
            num_total_experts=num_total_experts, num_local_experts=4, local_expert_start=8,
            seed=topk * 1000 + hidden_dim,
        )