Conversation
```python
# Ensure that the tensor passed between pipeline parallel stages is
# viewless. See related notes in TransformerBlock and TransformerLayer
output = make_viewless_tensor(
```
```python
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols as LayerSymbols
from megatron.core.ssm.mamba_hybrid_layer_allocation import allocate_layers
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
```
Pull request overview
This PR adds support for Mamba and hybrid MLA-Mamba models in Primus-LM with both Megatron and TorchTitan backends. The changes enable training of four new model variants: Mamba_370M, Zebra_Llama_1B, Zebra_Llama_3B, and Zebra_Llama_8B.
Changes:
- Added Mamba and hybrid model support in Megatron backend with new layer specifications and model blocks
- Enhanced TorchTitan backend with improved attention patching, MoE grouped MM support, and FP8 quantization
- Updated configuration files to support new quantization structure and model-specific settings
Reviewed changes
Copilot reviewed 67 out of 69 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| primus/backends/megatron/core/models/hybrid/* | New hybrid stack and layer specs for Mamba+MLA models |
| primus/modules/trainer/megatron/*.py | Model type detection and Mamba-specific forward pass handling |
| primus/core/utils/import_utils.py | Model provider resolution for GPT and Mamba models |
| primus/configs/models/megatron/* | Configuration files for new Mamba and Zebra models |
| primus/backends/torchtitan/models/* | Attention model updates for llama3, llama4, and deepseek_v3 |
| primus/modules/trainer/torchtitan/pre_trainer.py | Enhanced patching logic for MoE, attention, and quantization |
| primus/configs/modules/torchtitan/* | Restructured quantization config and new turbo settings |
| examples/torchtitan/configs/MI300X/* | Updated training configs with new quantization structure |
| examples/run_pretrain.sh | Added Primus Turbo rebuild capability |
Comments suppressed due to low confidence (6)

- primus/configs/models/megatron/zebra_llama_1B.yaml:1 - The comment incorrectly states 'Zebra Llama 8B configuration' when this is the 1B model configuration file. This should be 'Zebra Llama 1B configuration'.
- primus/configs/models/megatron/zebra_llama_3B.yaml:1 - The comment incorrectly states 'Zebra Llama 8B configuration' when this is the 3B model configuration file. This should be 'Zebra Llama 3B configuration'.
- primus/backends/torchtitan/models/moe/moe.py:1 - Corrected spelling of 'tyr' to 'try'.
- primus/modules/trainer/torchtitan/patch_utils.py:1 - Corrected spelling of 'PrimusPath' to 'PrimusPatch' to match the prefix used elsewhere in the file.
- primus/modules/trainer/torchtitan/patch_utils.py:1 - Corrected spelling of 'PrimusPath' to 'PrimusPatch' to match the prefix used elsewhere in the file.
- primus/modules/trainer/torchtitan/pre_trainer.py:1 - This commented-out error message provides valuable context for the actual error message below it. Consider either removing the commented line or adding a comment explaining why it is kept for reference.
```python
module=MambaMixer,
params={
    "expand": 1,
    "d_conv": 4,
```
The parameter name 'd_conv' is unclear. Consider using a more descriptive name like 'conv_dimension' or add a comment explaining what 'd_conv' represents.
```diff
-    "d_conv": 4,
+    "d_conv": 4,  # Convolution dimension (kernel size) used in the Mamba mixer
```
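For context, in Mamba-style mixers `d_conv` is the kernel size of the short depthwise causal convolution applied along the sequence dimension. A standalone sketch (shapes and names assumed here, not the actual MambaMixer code):

```python
import torch
import torch.nn as nn

# Sketch (assumed shapes): d_conv controls the kernel size of the short
# causal depthwise Conv1d that Mamba applies along the sequence.
d_model, d_conv = 16, 4
conv = nn.Conv1d(
    in_channels=d_model,
    out_channels=d_model,
    kernel_size=d_conv,   # this is what "d_conv": 4 sets
    groups=d_model,       # depthwise: one filter per channel
    padding=d_conv - 1,   # left-pad so the conv is causal, then trim
)
x = torch.randn(2, d_model, 32)       # (batch, channels, seq_len)
y = conv(x)[..., : x.size(-1)]        # trim back to the causal length
```

The trim after the convolution discards the extra positions introduced by the padding, so each output step only depends on current and past inputs.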
```yaml
fused_padded_mla_attention: false
```

```yaml
multi_latent_attention: false
#multi_latent_attention: true
```
This commented-out configuration line has inconsistent indentation (7 spaces instead of the standard indentation). If this is meant to be uncommented for use, fix the indentation to match the surrounding code.
Suggested change: re-indent `#multi_latent_attention: true` to match the surrounding keys.
```python
fp8_str = config.fp8.lower()

if fp8_str == "e4m3":
    fp8_format = FP8Format.E4M3
elif fp8_str == "hybrid":
    fp8_format = FP8Format.HYBRID
```
The code calls .lower() on config.fp8 but then compares with lowercase strings. If config.fp8 could be a non-string type (e.g., boolean for 'hybrid'), this will fail. Add a type check or ensure config.fp8 is always a string before calling .lower().
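A defensive normalization along the lines this comment suggests could look like the following (a minimal sketch; `normalize_fp8` and the accepted values are illustrative, not the actual Primus API):

```python
# Hypothetical sketch: validate config.fp8 before calling .lower(), so a
# non-string value (e.g. a YAML boolean) raises a clear error instead of
# an AttributeError deep inside the trainer.
def normalize_fp8(value):
    if not isinstance(value, str):
        raise TypeError(
            f"config.fp8 must be a string, got {type(value).__name__}: {value!r}"
        )
    fp8_str = value.lower()
    if fp8_str not in ("e4m3", "hybrid"):
        raise ValueError(f"unsupported fp8 format: {value!r}")
    return fp8_str
```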
```python
sequence_len_offset = torch.tensor(
    [inference_context.sequence_len_offset] * current_batch_size,
    dtype=torch.int32,
    device='cuda',
```
Hardcoding 'cuda' as the device may cause issues in multi-device environments. Consider using hidden_states.device or a device parameter from the config to maintain device consistency.
```diff
-    device='cuda',
+    device=hidden_states.device,
```
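The suggested fix can be illustrated in isolation (a minimal sketch with stand-in tensors; `hidden_states` here is a CPU placeholder for the real activations):

```python
import torch

# Stand-in for the real activations; in training this would live on the
# accelerator, so deriving the device from it keeps everything consistent.
hidden_states = torch.zeros(2, 4)

sequence_len_offset = torch.tensor(
    [0] * hidden_states.size(0),
    dtype=torch.int32,
    device=hidden_states.device,  # follows hidden_states instead of 'cuda'
)
```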
```diff
@@ -0,0 +1,67 @@
+#!/bin/bash
```
Please don't put this script in the root path of the Primus repo.
```dockerfile
FROM ${BASE_IMAGE}
# Base image
# FROM docker.io/rocm/megatron-lm:v25.9_gfx942
FROM docker.io/rocm/pyt-megatron-lm-jax-nightly-private:pytorch_rocm7.0_20251024
```
Primus uses a publicly released Docker image as the base image. The main branch uses v25.10; please rebase onto or merge the main branch.
Add support for Mamba and hybrid MLA-Mamba models in Primus-LM with the Megatron backend

Four models are added: Mamba_370M, Zebra_Llama_1B, Zebra_Llama_3B, and Zebra_Llama_8B.

To support Mamba and Mamba-based hybrid models, we add new layer specs and model blocks under
Primus/primus/backends/megatron/core/models/hybrid
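As a rough illustration of how a hybrid stack interleaves layer types from a repeating pattern (a purely hypothetical helper for intuition, not the actual `allocate_layers` API from `megatron.core.ssm.mamba_hybrid_layer_allocation`):

```python
# Hypothetical sketch: build a per-layer type list for a hybrid stack by
# repeating a pattern string, e.g. 'M' for a Mamba layer and '*' for an
# attention layer. The symbols and function name are illustrative only.
def build_layer_types(pattern: str, num_layers: int) -> list:
    reps = -(-num_layers // len(pattern))  # ceiling division
    return list((pattern * reps)[:num_layers])

# e.g. three Mamba layers per attention layer, eight layers deep
layers = build_layer_types("MMM*", 8)
```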