From fb81b6d669b59c9e114039aa358a79a8e61bcb76 Mon Sep 17 00:00:00 2001
From: Shane Moran
Date: Wed, 4 Mar 2026 07:25:55 -0500
Subject: [PATCH] [model, training] fix: MoE checkpoint export with YaRN RoPE
 and flex dispatcher

Add YaRN RoPE configuration fields to Qwen3MoEModelProvider so they are
preserved in run_config.yaml during checkpoint saving.

In load_megatron_model, coerce the flex token dispatcher to alltoall only
after mp_overrides have been applied, and only when TP*EP <= 1. This
avoids incorrectly downgrading the dispatcher when overrides restore
multi-GPU parallelism.

Signed-off-by: Shane Moran
---
 .../bridge/models/qwen/qwen_provider.py       |  7 ++++
 .../bridge/training/model_load_save.py        |  8 ++++
 .../qwen/test_qwen3_moe_model_provider.py     | 40 +++++++++++++++++++
 .../training/test_model_load_save.py          | 35 ++++++++++++++++
 4 files changed, 90 insertions(+)

diff --git a/src/megatron/bridge/models/qwen/qwen_provider.py b/src/megatron/bridge/models/qwen/qwen_provider.py
index 775b976523..60ba625a62 100644
--- a/src/megatron/bridge/models/qwen/qwen_provider.py
+++ b/src/megatron/bridge/models/qwen/qwen_provider.py
@@ -382,6 +382,13 @@ class Qwen3MoEModelProvider(GPTModelProvider):
     layernorm_epsilon: float = 1e-6
     rotary_base: float = 1000000.0
     position_embedding_type: str = "rope"
+    yarn_rotary_scaling_factor: float | None = None
+    yarn_original_max_position_embeddings: int | None = None
+    yarn_beta_fast: float | None = None
+    yarn_beta_slow: float | None = None
+    yarn_mscale: float | None = None
+    yarn_mscale_all_dim: float | None = None
+    yarn_correction_range_round_to_int: bool = False
     autocast_dtype: torch.dtype = torch.bfloat16
     params_dtype: torch.dtype = torch.bfloat16
     bf16: bool = True
diff --git a/src/megatron/bridge/training/model_load_save.py b/src/megatron/bridge/training/model_load_save.py
index 3567cc62a8..73b0afdf41 100644
--- a/src/megatron/bridge/training/model_load_save.py
+++ b/src/megatron/bridge/training/model_load_save.py
@@ -396,6 +396,14 @@ def load_megatron_model(
         if hasattr(model_cfg, key) and value is not None:
             setattr(model_cfg, key, value)
 
+    # flex dispatcher requires TP*EP > 1; fall back to alltoall for single-GPU export
+    if (
+        getattr(model_cfg, "moe_token_dispatcher_type", None) == "flex"
+        and getattr(model_cfg, "tensor_model_parallel_size", 1) * getattr(model_cfg, "expert_model_parallel_size", 1)
+        <= 1
+    ):
+        model_cfg.moe_token_dispatcher_type = "alltoall"
+
     return build_and_load_model(
         checkpoint_path, model_cfg, model_type, mlm_args, return_state_dict, use_cpu_init, skip_temp_dist_context
     )
diff --git a/tests/unit_tests/models/qwen/test_qwen3_moe_model_provider.py b/tests/unit_tests/models/qwen/test_qwen3_moe_model_provider.py
index 10fe3c9614..a649bb4f58 100644
--- a/tests/unit_tests/models/qwen/test_qwen3_moe_model_provider.py
+++ b/tests/unit_tests/models/qwen/test_qwen3_moe_model_provider.py
@@ -155,6 +155,46 @@ def test_qwen3_moe_model_provider_qk_layernorm(self):
         # Qwen3 MoE uses QK layernorm unlike Qwen2
         assert provider.qk_layernorm is True
 
+    def test_qwen3_moe_model_provider_with_yarn_rope(self):
+        """Test Qwen3MoEModelProvider with YaRN RoPE configuration."""
+        provider = Qwen3MoEModelProvider(
+            num_layers=32,
+            hidden_size=4096,
+            num_attention_heads=32,
+            position_embedding_type="yarn",
+            yarn_rotary_scaling_factor=4.0,
+            yarn_original_max_position_embeddings=32768,
+            yarn_beta_fast=32.0,
+            yarn_beta_slow=1.0,
+            yarn_mscale=1.0,
+            yarn_mscale_all_dim=1.0,
+            yarn_correction_range_round_to_int=True,
+        )
+
+        assert provider.yarn_rotary_scaling_factor == 4.0
+        assert provider.yarn_original_max_position_embeddings == 32768
+        assert provider.yarn_beta_fast == 32.0
+        assert provider.yarn_beta_slow == 1.0
+        assert provider.yarn_mscale == 1.0
+        assert provider.yarn_mscale_all_dim == 1.0
+        assert provider.yarn_correction_range_round_to_int is True
+
+    def test_qwen3_moe_model_provider_yarn_defaults_none(self):
+        """Test Qwen3MoEModelProvider YaRN fields default to None/False."""
+        provider = Qwen3MoEModelProvider(
+            num_layers=32,
+            hidden_size=4096,
+            num_attention_heads=32,
+        )
+
+        assert provider.yarn_rotary_scaling_factor is None
+        assert provider.yarn_original_max_position_embeddings is None
+        assert provider.yarn_beta_fast is None
+        assert provider.yarn_beta_slow is None
+        assert provider.yarn_mscale is None
+        assert provider.yarn_mscale_all_dim is None
+        assert provider.yarn_correction_range_round_to_int is False
+
     def test_qwen3_moe_model_provider_dtype_configuration(self):
         """Test Qwen3MoEModelProvider dtype configuration."""
         provider = Qwen3MoEModelProvider(
diff --git a/tests/unit_tests/training/test_model_load_save.py b/tests/unit_tests/training/test_model_load_save.py
index a779e83281..13f930c671 100644
--- a/tests/unit_tests/training/test_model_load_save.py
+++ b/tests/unit_tests/training/test_model_load_save.py
@@ -528,6 +528,7 @@ def test_load_megatron_model_resets_defaults(self, mock_load_model_config, mock_
         cfg.sequence_parallel = True
         cfg.virtual_pipeline_model_parallel_size = 2
         cfg.hierarchical_context_parallel_sizes = [2, 2]
+        cfg.moe_token_dispatcher_type = "flex"
 
         mock_load_model_config.return_value = (cfg, None)
         sentinel = object()
@@ -547,6 +548,7 @@
         assert cfg.sequence_parallel is False
         assert cfg.virtual_pipeline_model_parallel_size is None
         assert cfg.hierarchical_context_parallel_sizes is None
+        assert cfg.moe_token_dispatcher_type == "alltoall"
 
     @patch("megatron.bridge.training.model_load_save.build_and_load_model")
     @patch("megatron.bridge.training.model_load_save.load_model_config")
@@ -580,6 +582,39 @@ def test_load_megatron_model_applies_overrides(self, mock_load_model_config, moc
         assert cfg.sequence_parallel is True
         assert cfg.virtual_pipeline_model_parallel_size == 4
 
+    @patch("megatron.bridge.training.model_load_save.build_and_load_model")
+    @patch("megatron.bridge.training.model_load_save.load_model_config")
+    def test_load_megatron_model_keeps_flex_when_overrides_restore_parallelism(
+        self, mock_load_model_config, mock_build_and_load
+    ):
+        """Verify flex dispatcher is kept when mp_overrides set TP*EP > 1."""
+        cfg = Mock()
+        cfg.tensor_model_parallel_size = 2
+        cfg.pipeline_model_parallel_size = 1
+        cfg.context_parallel_size = 1
+        cfg.expert_model_parallel_size = 8
+        cfg.expert_tensor_parallel_size = 1
+        cfg.sequence_parallel = False
+        cfg.virtual_pipeline_model_parallel_size = None
+        cfg.hierarchical_context_parallel_sizes = None
+        cfg.moe_token_dispatcher_type = "flex"
+
+        mock_load_model_config.return_value = (cfg, None)
+        sentinel = object()
+        mock_build_and_load.return_value = sentinel
+
+        overrides = {
+            "tensor_model_parallel_size": 2,
+            "expert_model_parallel_size": 8,
+        }
+
+        result = load_megatron_model("/ckpt", mp_overrides=overrides)
+
+        assert result is sentinel
+        assert cfg.tensor_model_parallel_size == 2
+        assert cfg.expert_model_parallel_size == 8
+        assert cfg.moe_token_dispatcher_type == "flex"
+
 
 class TestSaveMegatronModel:
     """Test save_megatron_model function."""
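
Note for reviewers (not part of the diff): a minimal usage sketch of the
call pattern this patch targets. The checkpoint path and override values
below are hypothetical; load_megatron_model and its mp_overrides argument
are the real interfaces exercised by the tests above.

    from megatron.bridge.training.model_load_save import load_megatron_model

    # Single-process export, no overrides: the restored config ends up with
    # TP * EP <= 1, so a checkpoint saved with a "flex" token dispatcher is
    # coerced to "alltoall" before the model is built.
    model = load_megatron_model("/path/to/ckpt")  # hypothetical path

    # Multi-GPU load: mp_overrides restore TP * EP = 16 > 1. Because the
    # coercion runs only after the overrides are applied, "flex" is kept.
    model = load_megatron_model(
        "/path/to/ckpt",  # hypothetical path
        mp_overrides={
            "tensor_model_parallel_size": 2,
            "expert_model_parallel_size": 8,
        },
    )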