
Mingyyan/dev/hybrid#510

Open
Mingyuyang-1 wants to merge 39 commits into main from mingyyan/dev/hybrid

Conversation

@Mingyuyang-1

Add support for Mamba and hybrid MLA-Mamba models in Primus-LM with Megatron backends

Four models added:

  • Mamba_370M
  • Zebra_Llama_1B
  • Zebra_Llama_3B
  • Zebra_Llama_8B

To support Mamba and Mamba-based hybrid models, we add new layer_specs and model blocks under Primus/primus/backends/megatron/core/models/hybrid.
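To give a concrete picture of how a hybrid MLA-Mamba stack is laid out, here is a minimal, self-contained sketch of spreading attention layers among Mamba layers by ratio. The function name, the pattern convention ('M' for a Mamba mixer layer, '*' for an attention layer), and the spacing heuristic are illustrative assumptions modeled loosely on Megatron-style hybrid layer allocation, not the actual Primus/Megatron API.

```python
def allocate_hybrid_layers(num_layers: int, attention_ratio: float) -> str:
    """Illustrative allocator (hypothetical, not the Megatron API).

    Returns a pattern string such as 'MMM*MMM*', where 'M' marks a Mamba
    mixer layer and '*' marks an (MLA) attention layer, with attention
    layers spread as evenly as possible through the stack.
    """
    num_attention = round(num_layers * attention_ratio)
    if num_attention == 0:
        return "M" * num_layers
    stride = num_layers / num_attention
    # Put each attention layer at the end of its evenly sized segment.
    attn_positions = {int((i + 1) * stride) - 1 for i in range(num_attention)}
    return "".join("*" if i in attn_positions else "M" for i in range(num_layers))
```

For example, 8 layers with a 25% attention ratio yields 'MMM*MMM*', i.e. one attention layer after every three Mamba layers.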

alfuyao1986 and others added 30 commits October 2, 2025 00:40
Copilot AI review requested due to automatic review settings January 23, 2026 21:07

# Ensure that the tensor passed between pipeline parallel stages is
# viewless. See related notes in TransformerBlock and TransformerLayer
output = make_viewless_tensor(
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols as LayerSymbols
from megatron.core.ssm.mamba_hybrid_layer_allocation import allocate_layers
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
Copilot AI (Contributor) left a comment

Pull request overview

This PR adds support for Mamba and hybrid MLA-Mamba models in Primus-LM with both Megatron and TorchTitan backends. The changes enable training of four new model variants: Mamba_370M, Zebra_Llama_1B, Zebra_Llama_3B, and Zebra_Llama_8B.

Changes:

  • Added Mamba and hybrid model support in Megatron backend with new layer specifications and model blocks
  • Enhanced TorchTitan backend with improved attention patching, MoE grouped MM support, and FP8 quantization
  • Updated configuration files to support new quantization structure and model-specific settings

Reviewed changes

Copilot reviewed 67 out of 69 changed files in this pull request and generated 4 comments.

Summary per file:

  • primus/backends/megatron/core/models/hybrid/*: New hybrid stack and layer specs for Mamba+MLA models
  • primus/modules/trainer/megatron/*.py: Model type detection and Mamba-specific forward-pass handling
  • primus/core/utils/import_utils.py: Model provider resolution for GPT and Mamba models
  • primus/configs/models/megatron/*: Configuration files for new Mamba and Zebra models
  • primus/backends/torchtitan/models/*: Attention model updates for llama3, llama4, and deepseek_v3
  • primus/modules/trainer/torchtitan/pre_trainer.py: Enhanced patching logic for MoE, attention, and quantization
  • primus/configs/modules/torchtitan/*: Restructured quantization config and new turbo settings
  • examples/torchtitan/configs/MI300X/*: Updated training configs with new quantization structure
  • examples/run_pretrain.sh: Added Primus Turbo rebuild capability
Comments suppressed due to low confidence (6)

primus/configs/models/megatron/zebra_llama_1B.yaml:1
  • The comment incorrectly states 'Zebra Llama 8B configuration' when this is the 1B model configuration file. It should be 'Zebra Llama 1B configuration'.

primus/configs/models/megatron/zebra_llama_3B.yaml:1
  • The comment incorrectly states 'Zebra Llama 8B configuration' when this is the 3B model configuration file. It should be 'Zebra Llama 3B configuration'.

primus/backends/torchtitan/models/moe/moe.py:1
  • Corrected spelling of 'tyr' to 'try'.

primus/modules/trainer/torchtitan/patch_utils.py:1
  • Corrected spelling of 'PrimusPath' to 'PrimusPatch' to match the prefix used elsewhere in the file.

primus/modules/trainer/torchtitan/patch_utils.py:1
  • Corrected spelling of 'PrimusPath' to 'PrimusPatch' to match the prefix used elsewhere in the file.

primus/modules/trainer/torchtitan/pre_trainer.py:1
  • This commented-out error message provides context for the actual error message below it. Consider removing the commented line or adding a comment explaining why it is kept for reference.

module=MambaMixer,
params={
"expand": 1,
"d_conv": 4,
Copilot AI Jan 23, 2026

The parameter name 'd_conv' is unclear. Consider using a more descriptive name like 'conv_dimension' or add a comment explaining what 'd_conv' represents.

Suggested change
"d_conv": 4,
"d_conv": 4, # Convolution dimension (kernel size) used in the Mamba mixer

fused_padded_mla_attention: false

multi_latent_attention: false
#multi_latent_attention: true
Copilot AI Jan 23, 2026
This commented-out configuration line has inconsistent indentation (7 spaces instead of the standard indentation). If this is meant to be uncommented for use, fix the indentation to match the surrounding code.

Suggested change
#multi_latent_attention: true
#multi_latent_attention: true

Comment on lines +256 to +261
fp8_str = config.fp8.lower()

if fp8_str == "e4m3":
fp8_format = FP8Format.E4M3
elif fp8_str == "hybrid":
fp8_format = FP8Format.HYBRID
Copilot AI Jan 23, 2026

The code calls .lower() on config.fp8, which will fail if config.fp8 is a non-string type (e.g. a boolean standing in for 'hybrid'). Add a type check or ensure config.fp8 is always a string before calling .lower().
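One defensive way to handle this, sketched with a stand-in FP8Format enum (the real enum lives in the backend; the names here are assumptions): treat None/False as "fp8 disabled", reject other non-strings explicitly, and only then lower-case and compare.

```python
from enum import Enum

class FP8Format(Enum):  # stand-in for the backend's real enum
    E4M3 = "e4m3"
    HYBRID = "hybrid"

def parse_fp8_format(fp8_value):
    """Defensively parse an fp8 config field that may be str, bool, or None."""
    if fp8_value is None or fp8_value is False:
        return None  # fp8 disabled
    if not isinstance(fp8_value, str):
        raise TypeError(f"fp8 must be a string, got {type(fp8_value).__name__}")
    fp8_str = fp8_value.lower()
    if fp8_str == "e4m3":
        return FP8Format.E4M3
    if fp8_str == "hybrid":
        return FP8Format.HYBRID
    raise ValueError(f"unsupported fp8 format: {fp8_value!r}")
```

This keeps the happy path identical to the snippet above while turning a confusing AttributeError into an explicit TypeError for misconfigured values.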

sequence_len_offset = torch.tensor(
[inference_context.sequence_len_offset] * current_batch_size,
dtype=torch.int32,
device='cuda',
Copilot AI Jan 23, 2026

Hardcoding 'cuda' as the device may cause issues in multi-device environments. Consider using hidden_states.device or a device parameter from the config to maintain device consistency.

Suggested change
device='cuda',
device=hidden_states.device,

@@ -0,0 +1,67 @@
#!/bin/bash
Contributor

Please don't put this script in the root path of the Primus repo.

FROM ${BASE_IMAGE}
# Base image
# FROM docker.io/rocm/megatron-lm:v25.9_gfx942
FROM docker.io/rocm/pyt-megatron-lm-jax-nightly-private:pytorch_rocm7.0_20251024
Contributor

Primus uses publicly released Docker images as base images. The main branch uses v25.10; please rebase onto or merge the main branch.


9 participants