7 changes: 7 additions & 0 deletions src/megatron/bridge/models/qwen/qwen_provider.py
@@ -382,6 +382,13 @@ class Qwen3MoEModelProvider(GPTModelProvider):
layernorm_epsilon: float = 1e-6
rotary_base: float = 1000000.0
position_embedding_type: str = "rope"
yarn_rotary_scaling_factor: float | None = None
yarn_original_max_position_embeddings: int | None = None
yarn_beta_fast: float | None = None
yarn_beta_slow: float | None = None
yarn_mscale: float | None = None
yarn_mscale_all_dim: float | None = None
yarn_correction_range_round_to_int: bool = False
autocast_dtype: torch.dtype = torch.bfloat16
params_dtype: torch.dtype = torch.bfloat16
bf16: bool = True
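For reference, a minimal sketch of how these new YaRN fields might be populated (the values mirror the unit test added later in this PR and are illustrative, not prescribed defaults):

from megatron.bridge.models.qwen.qwen_provider import Qwen3MoEModelProvider

# Hypothetical long-context configuration; the field names come from this diff.
provider = Qwen3MoEModelProvider(
    num_layers=32,
    hidden_size=4096,
    num_attention_heads=32,
    position_embedding_type="yarn",  # switch from the default "rope"
    yarn_rotary_scaling_factor=4.0,  # stretch RoPE to roughly 4x the original context
    yarn_original_max_position_embeddings=32768,
    yarn_beta_fast=32.0,
    yarn_beta_slow=1.0,
    yarn_mscale=1.0,
    yarn_mscale_all_dim=1.0,
    yarn_correction_range_round_to_int=True,
)

The yarn_* fields all default to None (and the round-to-int flag to False), as the second unit test below asserts, so existing RoPE configurations are unaffected.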
8 changes: 8 additions & 0 deletions src/megatron/bridge/training/model_load_save.py
@@ -396,6 +396,14 @@ def load_megatron_model(
if hasattr(model_cfg, key) and value is not None:
setattr(model_cfg, key, value)

# flex dispatcher requires TP*EP > 1; fall back to alltoall for single-GPU export
if (
getattr(model_cfg, "moe_token_dispatcher_type", None) == "flex"
and getattr(model_cfg, "tensor_model_parallel_size", 1) * getattr(model_cfg, "expert_model_parallel_size", 1)
<= 1
):
model_cfg.moe_token_dispatcher_type = "alltoall"

return build_and_load_model(
checkpoint_path, model_cfg, model_type, mlm_args, return_state_dict, use_cpu_init, skip_temp_dist_context
)
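A usage sketch of the fallback (the checkpoint path is hypothetical; the mp_overrides call shape follows the unit tests below):

from megatron.bridge.training.model_load_save import load_megatron_model

# With TP * EP == 1 (e.g. a single-GPU export), the loader now downgrades
# a "flex" MoE token dispatcher to "alltoall" before building the model.
model = load_megatron_model(
    "/path/to/ckpt",  # hypothetical checkpoint path
    mp_overrides={
        "tensor_model_parallel_size": 1,
        "expert_model_parallel_size": 1,
    },
)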
40 changes: 40 additions & 0 deletions tests/unit_tests/models/qwen/test_qwen3_moe_model_provider.py
@@ -155,6 +155,46 @@ def test_qwen3_moe_model_provider_qk_layernorm(self):
# Qwen3 MoE uses QK layernorm unlike Qwen2
assert provider.qk_layernorm is True

def test_qwen3_moe_model_provider_with_yarn_rope(self):
"""Test Qwen3MoEModelProvider with YaRN RoPE configuration."""
provider = Qwen3MoEModelProvider(
num_layers=32,
hidden_size=4096,
num_attention_heads=32,
position_embedding_type="yarn",
yarn_rotary_scaling_factor=4.0,
yarn_original_max_position_embeddings=32768,
yarn_beta_fast=32.0,
yarn_beta_slow=1.0,
yarn_mscale=1.0,
yarn_mscale_all_dim=1.0,
yarn_correction_range_round_to_int=True,
)

assert provider.yarn_rotary_scaling_factor == 4.0
assert provider.yarn_original_max_position_embeddings == 32768
assert provider.yarn_beta_fast == 32.0
assert provider.yarn_beta_slow == 1.0
assert provider.yarn_mscale == 1.0
assert provider.yarn_mscale_all_dim == 1.0
assert provider.yarn_correction_range_round_to_int is True
Contributor review comment on lines +158 to +180:

⚠️ Potential issue | 🟡 Minor

Use position_embedding_type="yarn" in the YaRN test case.

Line 164 sets "rope", so this test does not actually exercise the YaRN mode it is named for.

Proposed fix
-            position_embedding_type="rope",
+            position_embedding_type="yarn",
             yarn_rotary_scaling_factor=4.0,
@@
         assert provider.yarn_rotary_scaling_factor == 4.0
+        assert provider.position_embedding_type == "yarn"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `tests/unit_tests/models/qwen/test_qwen3_moe_model_provider.py`, around lines
158-180: the test test_qwen3_moe_model_provider_with_yarn_rope is meant to
exercise YaRN mode but constructs Qwen3MoEModelProvider with
position_embedding_type="rope". Change the instantiation to use
position_embedding_type="yarn" so the Qwen3MoEModelProvider(...) call actually
configures the YaRN-specific fields (yarn_rotary_scaling_factor,
yarn_original_max_position_embeddings, etc.) and the subsequent asserts validate
the YaRN behavior.


def test_qwen3_moe_model_provider_yarn_defaults_none(self):
"""Test Qwen3MoEModelProvider YaRN fields default to None/False."""
provider = Qwen3MoEModelProvider(
num_layers=32,
hidden_size=4096,
num_attention_heads=32,
)

assert provider.yarn_rotary_scaling_factor is None
assert provider.yarn_original_max_position_embeddings is None
assert provider.yarn_beta_fast is None
assert provider.yarn_beta_slow is None
assert provider.yarn_mscale is None
assert provider.yarn_mscale_all_dim is None
assert provider.yarn_correction_range_round_to_int is False

def test_qwen3_moe_model_provider_dtype_configuration(self):
"""Test Qwen3MoEModelProvider dtype configuration."""
provider = Qwen3MoEModelProvider(
35 changes: 35 additions & 0 deletions tests/unit_tests/training/test_model_load_save.py
@@ -528,6 +528,7 @@ def test_load_megatron_model_resets_defaults(self, mock_load_model_config, mock_
cfg.sequence_parallel = True
cfg.virtual_pipeline_model_parallel_size = 2
cfg.hierarchical_context_parallel_sizes = [2, 2]
cfg.moe_token_dispatcher_type = "flex"

mock_load_model_config.return_value = (cfg, None)
sentinel = object()
Expand All @@ -547,6 +548,7 @@ def test_load_megatron_model_resets_defaults(self, mock_load_model_config, mock_
assert cfg.sequence_parallel is False
assert cfg.virtual_pipeline_model_parallel_size is None
assert cfg.hierarchical_context_parallel_sizes is None
assert cfg.moe_token_dispatcher_type == "alltoall"

@patch("megatron.bridge.training.model_load_save.build_and_load_model")
@patch("megatron.bridge.training.model_load_save.load_model_config")
@@ -580,6 +582,39 @@ def test_load_megatron_model_applies_overrides(self, mock_load_model_config, moc
assert cfg.sequence_parallel is True
assert cfg.virtual_pipeline_model_parallel_size == 4

@patch("megatron.bridge.training.model_load_save.build_and_load_model")
@patch("megatron.bridge.training.model_load_save.load_model_config")
def test_load_megatron_model_keeps_flex_when_overrides_restore_parallelism(
self, mock_load_model_config, mock_build_and_load
):
"""Verify flex dispatcher is kept when mp_overrides set TP*EP > 1."""
cfg = Mock()
cfg.tensor_model_parallel_size = 2
cfg.pipeline_model_parallel_size = 1
cfg.context_parallel_size = 1
cfg.expert_model_parallel_size = 8
cfg.expert_tensor_parallel_size = 1
cfg.sequence_parallel = False
cfg.virtual_pipeline_model_parallel_size = None
cfg.hierarchical_context_parallel_sizes = None
cfg.moe_token_dispatcher_type = "flex"

mock_load_model_config.return_value = (cfg, None)
sentinel = object()
mock_build_and_load.return_value = sentinel

overrides = {
"tensor_model_parallel_size": 2,
"expert_model_parallel_size": 8,
}

result = load_megatron_model("/ckpt", mp_overrides=overrides)

assert result is sentinel
assert cfg.tensor_model_parallel_size == 2
assert cfg.expert_model_parallel_size == 8
assert cfg.moe_token_dispatcher_type == "flex"


class TestSaveMegatronModel:
"""Test save_megatron_model function.