TE AITER groupedgemm optimizations in Megatron #109
Motivation
This PR enhances the Mixture of Experts (MoE) implementation in Megatron-LM to support GPU-based token processing for grouped GEMM operations. The changes improve performance when using Transformer Engine's grouped GEMM functionality with the Triton backend, while maintaining backward compatibility with existing configurations and older Transformer Engine versions. Keeping `m_splits` as a GPU tensor avoids redundant device-to-host-to-device memcpys and in turn improves throughput.
Performance difference:
https://github.com/ROCm/frameworks-internal/issues/13792#issuecomment-3849739037
Technical Details
The implementation introduces GPU token handling across three main components:
Token Dispatcher Enhancement (commit 218891725)
- Updated `MoEAlltoAllTokenDispatcher` in `megatron/core/transformer/moe/token_dispatcher.py` to track `tokens_per_expert_gpu` alongside the existing CPU-based `tokens_per_expert` (sketched below)
- Updated `TEGroupedMLP` in `megatron/core/transformer/moe/experts.py` to accept tuple inputs containing both CPU and GPU token tensors
- Updated `megatron/core/extensions/transformer_engine.py` to process GPU token inputs
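The dispatcher-side change can be pictured roughly as follows. This is a hedged sketch, not the PR's actual code; the helper name `count_tokens_per_expert` and its signature are illustrative. The point is that the per-expert split sizes stay resident on the GPU, so the grouped GEMM path does not force an extra device-to-host-to-device round trip.

```python
import torch

def count_tokens_per_expert(expert_indices: torch.Tensor, num_experts: int):
    """Illustrative only: count routed tokens per expert, keeping a GPU copy.

    expert_indices: 1-D CUDA tensor of expert ids, one per routed token.
    Returns (tokens_per_expert, tokens_per_expert_gpu); the GPU tensor can be
    handed directly to a grouped GEMM that accepts device-side splits.
    """
    tokens_per_expert_gpu = torch.bincount(expert_indices, minlength=num_experts)
    # A single host copy is still kept for code paths (and older Transformer
    # Engine versions) that require CPU-side split sizes.
    tokens_per_expert = tokens_per_expert_gpu.cpu()
    return tokens_per_expert, tokens_per_expert_gpu
```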
Compatibility Checking (commit 33485e7bc)
- Added a `class_has_method_param()` utility function to dynamically verify whether Transformer Engine's `GroupedLinear.forward` supports the `m_splits_tensor` parameter (see the sketch below)
- Warns and keeps the existing CPU-based path when the installed Transformer Engine does not accept the `m_splits_tensor` parameter
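A compatibility probe of this kind can be built on `inspect`. The sketch below assumes a signature of `class_has_method_param(cls, method_name, param_name)`, which may differ from the PR's actual helper; the commented usage line is likewise an assumption about how the check is applied.

```python
import inspect

def class_has_method_param(cls, method_name: str, param_name: str) -> bool:
    """Return True if cls.<method_name> declares a parameter named param_name."""
    method = getattr(cls, method_name, None)
    if method is None:
        return False
    try:
        return param_name in inspect.signature(method).parameters
    except (TypeError, ValueError):
        # Some builtin / C-extension callables do not expose a signature.
        return False

# Hedged usage example: gate the GPU-splits path on Transformer Engine support.
# use_m_splits_tensor = class_has_method_param(
#     te.pytorch.GroupedLinear, "forward", "m_splits_tensor")
```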
Conditional GPU Processing (commit e782b8cf2)
- Added the `NVTE_USE_GROUPED_GEMM_TRITON` environment variable to conditionally enable GPU token processing (illustrated below)
- Updated `MoEAlltoAllTokenDispatcher.forward()` to return tuples containing both CPU and GPU tensors when the environment variable is set
- Cast the GPU token tensor to `torch.int32` for compatibility with Transformer Engine
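Conceptually, the environment-variable gating and the `torch.int32` cast look something like the sketch below. This is a hedged illustration, not the PR's exact code; the helper name `maybe_pack_token_counts` is hypothetical.

```python
import os
import torch

# Read the feature flag once; GPU token processing is only used when the
# Triton grouped-GEMM backend is explicitly requested.
_USE_TRITON_GROUPED_GEMM = os.getenv("NVTE_USE_GROUPED_GEMM_TRITON", "0") == "1"

def maybe_pack_token_counts(tokens_per_expert, tokens_per_expert_gpu):
    """Hypothetical helper: choose what the dispatcher hands to the experts."""
    if _USE_TRITON_GROUPED_GEMM:
        # Transformer Engine expects the device-side splits as int32.
        return (tokens_per_expert, tokens_per_expert_gpu.to(torch.int32))
    # Legacy behaviour: CPU-only token counts.
    return tokens_per_expert
```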
Test Plan
- Tested with the `NVTE_USE_GROUPED_GEMM_TRITON=1` environment variable set
Test Result
The changes successfully enable GPU-based token processing when the feature is enabled, while maintaining full backward compatibility. Warning messages are properly displayed when using older Transformer Engine versions, guiding users to upgrade for optimal performance.
Submission Checklist