@sudhu2k commented on Feb 4, 2026

Motivation

This PR enhances the Mixture of Experts (MoE) implementation in Megatron-LM to support GPU-based token processing for grouped GEMM operations. The changes improve performance when using Transformer Engine's grouped GEMM functionality with the Triton backend, while maintaining backward compatibility with existing configurations and older Transformer Engine versions. Keeping m_splits as a GPU tensor avoids redundant device-to-host-to-device memcpys, which in turn improves throughput.

Performance difference:

https://github.com/ROCm/frameworks-internal/issues/13792#issuecomment-3849739037

Technical Details

The implementation introduces GPU token handling across three main components:

  1. Token Dispatcher Enhancement (commit 218891725)

    • Modified MoEAlltoAllTokenDispatcher in megatron/core/transformer/moe/token_dispatcher.py to track tokens_per_expert_gpu alongside the existing CPU-based tokens_per_expert
    • Updated TEGroupedMLP in megatron/core/transformer/moe/experts.py to accept tuple inputs containing both CPU and GPU token tensors
    • Enhanced forward methods in megatron/core/extensions/transformer_engine.py to process GPU token inputs
  2. Compatibility Checking (commit 33485e7bc)

    • Added a class_has_method_param() utility function to dynamically verify whether Transformer Engine's GroupedLinear.forward supports the m_splits_tensor parameter (see the second sketch after this list)
    • Implemented version-aware parameter passing with user warnings when using older Transformer Engine versions that lack the m_splits_tensor parameter
    • Ensures compatibility with Transformer Engine commit 2776c33, which introduced GPU tensor support
  3. Conditional GPU Processing (commit e782b8cf2)

    • Added an environment variable check for NVTE_USE_GROUPED_GEMM_TRITON to conditionally enable GPU token processing (see the first sketch after this list)
    • Updated return signatures in MoEAlltoAllTokenDispatcher.forward() to return tuples containing both CPU and GPU tensors when the environment variable is set
    • Ensures GPU tensors are converted to torch.int32 for compatibility with Transformer Engine
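
The split handling described in items 1 and 3 can be illustrated with a minimal sketch. Only `NVTE_USE_GROUPED_GEMM_TRITON`, `tokens_per_expert`, `tokens_per_expert_gpu`, and the int32 cast come from this PR; the function and variable names below (`split_token_counts`, `num_global_tokens_per_expert`) are illustrative stand-ins, not the dispatcher's actual code.

```python
import os

import torch

# Gate the new behaviour on the same environment variable the PR checks.
USE_TRITON_GROUPED_GEMM = os.environ.get("NVTE_USE_GROUPED_GEMM_TRITON", "0") == "1"


def split_token_counts(num_global_tokens_per_expert: torch.Tensor):
    """Illustrative helper: `num_global_tokens_per_expert` stands in for the
    per-expert token counts the dispatcher already computes on the GPU."""
    # Existing behaviour: a CPU copy of the counts for grouped GEMM's m_splits.
    tokens_per_expert = num_global_tokens_per_expert.cpu()

    if not USE_TRITON_GROUPED_GEMM:
        # Legacy return shape: a single CPU tensor, as before this PR.
        return tokens_per_expert

    # New behaviour: also keep the counts on the GPU, cast to int32 for
    # Transformer Engine, so the grouped GEMM never copies them back to device.
    tokens_per_expert_gpu = num_global_tokens_per_expert.to(torch.int32)
    return tokens_per_expert, tokens_per_expert_gpu
```

Downstream, TEGroupedMLP unpacks this (CPU, GPU) tuple and forwards the GPU tensor only when the installed Transformer Engine can accept it, which is what the compatibility check in item 2 handles.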
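
Item 2's introspection helper can be implemented with `inspect.signature`; the sketch below is a plausible version rather than the PR's exact code, and the call site at the bottom uses placeholder names (`run_grouped_linear`, `grouped_linear`, `inp`, `m_splits`) apart from `m_splits_tensor`.

```python
import inspect
import warnings


def class_has_method_param(cls, method_name: str, param_name: str) -> bool:
    """Return True if `cls.<method_name>` declares a parameter named `param_name`."""
    method = getattr(cls, method_name, None)
    if not callable(method):
        return False
    try:
        return param_name in inspect.signature(method).parameters
    except (TypeError, ValueError):
        # Some C-implemented callables do not expose a Python signature.
        return False


# Illustrative call site: pass the GPU splits only when the installed
# Transformer Engine's GroupedLinear.forward accepts `m_splits_tensor`.
def run_grouped_linear(grouped_linear, inp, m_splits, tokens_per_expert_gpu):
    if class_has_method_param(type(grouped_linear), "forward", "m_splits_tensor"):
        return grouped_linear(inp, m_splits, m_splits_tensor=tokens_per_expert_gpu)
    warnings.warn(
        "Installed Transformer Engine lacks m_splits_tensor support; "
        "falling back to CPU m_splits. Upgrade to a build including commit 2776c33."
    )
    return grouped_linear(inp, m_splits)
```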

Test Plan

  • Tested with and without the NVTE_USE_GROUPED_GEMM_TRITON=1 environment variable
  • Verified backward compatibility with existing MoE training scripts
  • Validated functionality with both older and newer versions of Transformer Engine
  • Ensured proper tensor type conversions and parameter passing through the MoE pipeline

Test Result

The changes enable GPU-based token processing when the feature flag is set, while maintaining full backward compatibility. Warning messages are properly displayed when using older Transformer Engine versions, guiding users to upgrade for optimal performance.

Submission Checklist

sugovind added 3 commits January 13, 2026 21:07
…ken handling

- Updated forward methods in TEGroupedMLP and transformer_engine.py to accept and process GPU token inputs.
- Modified MoEAlltoAllTokenDispatcher to manage tokens_per_expert for GPU, ensuring compatibility with new tensor structures.
- Introduced `class_has_method_param` to verify if a class method contains a specified parameter.
- Updated the `forward` method in `TEGroupedMLP` to check for the `m_splits_tensor` parameter in the parent class's `forward` method, providing a warning if it's missing.
- Ensured compatibility with Transformer Engine versioning for improved performance in MoE GroupedGEMM.
… processing

- Added an environment variable check for `NVTE_USE_GROUPED_GEMM_TRITON` to manage `tokens_per_expert_gpu` conversion.
- Updated return values in the `forward` method to ensure compatibility with new GPU handling logic.
@sudhu2k self-assigned this on Feb 4, 2026