[model, training] fix: MoE checkpoint export with YaRN RoPE and flex dispatcher #2641
shanecmoran wants to merge 1 commit into NVIDIA-NeMo:main
Conversation
📝 Walkthrough

This PR adds YaRN RoPE configuration fields to Qwen3MoEModelProvider, implements logic to coerce moe_token_dispatcher_type from "flex" to "alltoall" during model loading, and adds corresponding unit tests to validate these configurations and their default values.

Changes
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~15 minutes

Possibly related PRs

Suggested reviewers
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
Force-pushed 20236c1 to 2aa6bc5
Actionable comments posted: 2
🧹 Nitpick comments (2)
tests/unit_tests/models/qwen/test_qwen3_moe_model_provider.py (1)

158-197: Tag the new tests with a pytest category marker.

Please add a unit-test marker (for example @pytest.mark.unit) on the newly added test methods. As per coding guidelines, "Use 'pytest.mark' to categorize tests (unit, integration, system)".
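The requested marker can be sketched as follows; the test body here is a hypothetical stand-in, since the real tests construct Qwen3MoEModelProvider:

```python
import pytest

# Hypothetical stand-in for the real test body, which constructs
# Qwen3MoEModelProvider and checks its yarn_* defaults.
@pytest.mark.unit
def test_qwen3_moe_model_provider_yarn_defaults_none():
    yarn_rotary_scaling_factor = None  # provider default under test
    assert yarn_rotary_scaling_factor is None
```

Applying the decorator attaches the mark to the function, so `pytest -m unit` can select it.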
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit_tests/models/qwen/test_qwen3_moe_model_provider.py` around lines 158 - 197: Add the pytest unit marker to both new tests by decorating test_qwen3_moe_model_provider_with_yarn_rope and test_qwen3_moe_model_provider_yarn_defaults_none with `@pytest.mark.unit`; ensure pytest is imported at the top of the file (add "import pytest" if missing) so the markers resolve correctly.

tests/unit_tests/training/test_model_load_save.py (1)
519-552: Add an override-path regression test for flex dispatcher behavior.

This test only covers the default reset path. Please add a companion case where mp_overrides makes TP×EP > 1 and verify moe_token_dispatcher_type behavior stays consistent with intended semantics.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit_tests/training/test_model_load_save.py` around lines 519 - 552, Add a companion unit test to exercise the override path of load_megatron_model: copy the existing test_load_megatron_model_resets_defaults structure but pass mp_overrides that set tensor_model_parallel_size and expert_model_parallel_size such that TP×EP > 1 (e.g., tensor_model_parallel_size=2 and expert_model_parallel_size=2) and set the loaded cfg.moe_token_dispatcher_type to "flex"; call load_megatron_model with those mp_overrides (mocking load_model_config to return the cfg and mock_build_and_load to return a sentinel) and assert the function still returns the sentinel and that cfg.moe_token_dispatcher_type is set/normalized to the expected final value according to the intended semantics (e.g., remains "flex" or becomes "alltoall" per spec) while also verifying other parallel-size resets/overrides were applied as appropriate; reference load_megatron_model, mp_overrides, mock_load_model_config, and mock_build_and_load to locate where to wire the test.
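The semantics the companion test should pin down can be sketched outside the test harness. The helper below is not the repo's actual implementation; it only mirrors the ordering the review asks for (overrides first, then the conditional fallback):

```python
from types import SimpleNamespace

def apply_overrides_then_coerce(cfg, mp_overrides):
    """Sketch of the intended ordering: apply mp_overrides first, then fall back."""
    for key, value in (mp_overrides or {}).items():
        if hasattr(cfg, key) and value is not None:
            setattr(cfg, key, value)
    tp = getattr(cfg, "tensor_model_parallel_size", 1)
    ep = getattr(cfg, "expert_model_parallel_size", 1)
    if getattr(cfg, "moe_token_dispatcher_type", None) == "flex" and tp * ep <= 1:
        cfg.moe_token_dispatcher_type = "alltoall"
    return cfg

# Default reset path: TP=EP=1, flex is downgraded.
cfg = SimpleNamespace(tensor_model_parallel_size=1, expert_model_parallel_size=1,
                      moe_token_dispatcher_type="flex")
assert apply_overrides_then_coerce(cfg, None).moe_token_dispatcher_type == "alltoall"

# Override path: TP*EP > 1, flex survives.
cfg = SimpleNamespace(tensor_model_parallel_size=1, expert_model_parallel_size=1,
                      moe_token_dispatcher_type="flex")
overrides = {"tensor_model_parallel_size": 2, "expert_model_parallel_size": 2}
assert apply_overrides_then_coerce(cfg, overrides).moe_token_dispatcher_type == "flex"
```

The second scenario is exactly the override path the review says is untested.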
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: e60e10cc-089f-4459-a9d0-5b7f22eb0e75
📒 Files selected for processing (4)
src/megatron/bridge/models/qwen/qwen_provider.py
src/megatron/bridge/training/model_load_save.py
tests/unit_tests/models/qwen/test_qwen3_moe_model_provider.py
tests/unit_tests/training/test_model_load_save.py
```python
if getattr(model_cfg, "moe_token_dispatcher_type", None) == "flex":
    model_cfg.moe_token_dispatcher_type = "alltoall"
```
Conditionally fallback flex only after overrides are applied.
Line 389 currently coerces "flex" before mp_overrides (Lines 396-399). This can incorrectly downgrade valid override scenarios where TP/EP are set back above 1.
Proposed fix
```diff
-    if getattr(model_cfg, "moe_token_dispatcher_type", None) == "flex":
-        model_cfg.moe_token_dispatcher_type = "alltoall"
     if use_cpu_init:
         model_cfg.fp8 = None
         model_cfg.fp8_param = False
     # Apply model-parallel overrides if provided
     if mp_overrides:
         for key, value in mp_overrides.items():
             if hasattr(model_cfg, key) and value is not None:
                 setattr(model_cfg, key, value)
+
+    if (
+        getattr(model_cfg, "moe_token_dispatcher_type", None) == "flex"
+        and getattr(model_cfg, "tensor_model_parallel_size", 1)
+        * getattr(model_cfg, "expert_model_parallel_size", 1)
+        <= 1
+    ):
+        logger.warning("Falling back moe_token_dispatcher_type from 'flex' to 'alltoall' for TP*EP <= 1.")
+        model_cfg.moe_token_dispatcher_type = "alltoall"
```

Also applies to: 396-399
🧰 Tools
🪛 Ruff (0.15.2)
[error] 390-390: Possible hardcoded password assigned to: "moe_token_dispatcher_type"
(S105)
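The S105 hit above is a false positive: Ruff's "hardcoded password" rule is pattern-matching on a string assignment, not on an actual secret. Assuming the project allows inline Ruff suppressions, one way to silence it is a targeted `noqa`; the config class here is a hypothetical stand-in for model_cfg:

```python
class Cfg:
    """Hypothetical stand-in for model_cfg; not the repo's real config class."""
    moe_token_dispatcher_type = "flex"

model_cfg = Cfg()
# S105 flags this assignment because the attribute name ends in "_type" patterns
# Ruff associates with secrets; the value is a dispatcher name, not a password.
model_cfg.moe_token_dispatcher_type = "alltoall"  # noqa: S105
assert model_cfg.moe_token_dispatcher_type == "alltoall"
```

A per-file or per-rule ignore in the Ruff config would work equally well if the team prefers not to scatter `noqa` comments.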
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/megatron/bridge/training/model_load_save.py` around lines 389 - 390, The
current unconditional coercion of model_cfg.moe_token_dispatcher_type from
"flex" to "alltoall" runs before mp_overrides are applied and can incorrectly
override valid user overrides; move the conditional replacement so it happens
after mp_overrides processing (the block that applies tensor/row/col/ep/tp
overrides) so that if overrides set TP/EP > 1 the dispatcher remains as
specified; update both occurrences that coerce "flex" (the existing standalone
if using getattr and the similar logic around the mp_overrides application) to
run only after the override application code that modifies model_cfg.{tp,ep,...}
and ensure the check uses the final model_cfg values.
```python
def test_qwen3_moe_model_provider_with_yarn_rope(self):
    """Test Qwen3MoEModelProvider with YaRN RoPE configuration."""
    provider = Qwen3MoEModelProvider(
        num_layers=32,
        hidden_size=4096,
        num_attention_heads=32,
        position_embedding_type="rope",
        yarn_rotary_scaling_factor=4.0,
        yarn_original_max_position_embeddings=32768,
        yarn_beta_fast=32.0,
        yarn_beta_slow=1.0,
        yarn_mscale=1.0,
        yarn_mscale_all_dim=1.0,
        yarn_correction_range_round_to_int=True,
    )

    assert provider.yarn_rotary_scaling_factor == 4.0
    assert provider.yarn_original_max_position_embeddings == 32768
    assert provider.yarn_beta_fast == 32.0
    assert provider.yarn_beta_slow == 1.0
    assert provider.yarn_mscale == 1.0
    assert provider.yarn_mscale_all_dim == 1.0
    assert provider.yarn_correction_range_round_to_int is True
```
Use position_embedding_type="yarn" in the YaRN test case.
Line 164 sets "rope", so this test does not actually exercise the YaRN mode it is named for.
Proposed fix

```diff
-        position_embedding_type="rope",
+        position_embedding_type="yarn",
         yarn_rotary_scaling_factor=4.0,
@@
         assert provider.yarn_rotary_scaling_factor == 4.0
+        assert provider.position_embedding_type == "yarn"
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/unit_tests/models/qwen/test_qwen3_moe_model_provider.py` around lines
158 - 180, The test test_qwen3_moe_model_provider_with_yarn_rope is meant to
exercise YaRN mode but currently constructs Qwen3MoEModelProvider with
position_embedding_type="rope"; change the instantiation to use
position_embedding_type="yarn" so the Qwen3MoEModelProvider(...) call actually
configures YaRN-specific fields (yarn_rotary_scaling_factor,
yarn_original_max_position_embeddings, etc.) and the subsequent asserts validate
the YaRN behavior.
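Why the PR's other change (declaring yarn_* as dataclass fields with None defaults) matters for checkpoint round-trips can be sketched with a hypothetical stand-in; only two of the fields are shown, and ProviderSketch is not the real Qwen3MoEModelProvider:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ProviderSketch:
    # Hypothetical stand-in: declared fields with None defaults survive a
    # save/restore cycle through a run_config.yaml-style mapping.
    yarn_rotary_scaling_factor: Optional[float] = None
    yarn_original_max_position_embeddings: Optional[int] = None

# Values read back from a saved config mapping.
saved = {"yarn_rotary_scaling_factor": 4.0,
         "yarn_original_max_position_embeddings": 32768}
restored = ProviderSketch(**saved)
assert restored.yarn_rotary_scaling_factor == 4.0
# If the yarn_* fields were not declared on the dataclass,
# ProviderSketch(**saved) would raise TypeError on the unknown keys.
```

This mirrors the failure mode the PR description attributes to checkpoint export: without declared fields, saved YaRN parameters cannot be re-instantiated.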
Force-pushed 2aa6bc5 to c64232f
…dispatcher

Add YaRN RoPE configuration fields to Qwen3MoEModelProvider so they are preserved in run_config.yaml during checkpoint saving. Move the flex -> alltoall dispatcher coercion in load_megatron_model to run after mp_overrides are applied, and only when TP*EP <= 1. This avoids incorrectly downgrading the dispatcher when overrides restore multi-GPU parallelism.

Signed-off-by: Shane Moran <shane.moran@shopify.com>
Force-pushed c64232f to eb696b8
What does this PR do ?
Fix two independent bugs that prevent exporting MoE checkpoints (e.g. Qwen3-30B-A3B) trained with YaRN RoPE scaling and the flex token dispatcher to HuggingFace format via load_megatron_model.

Changelog

- Add yarn_* dataclass fields to Qwen3MoEModelProvider (None defaults) so YaRN RoPE parameters from run_config.yaml are preserved during instantiate(), matching the existing pattern in Ministral3ModelProvider
- Coerce the flex dispatcher to alltoall in load_megatron_model when parallelism is reset to TP=1/EP=1, avoiding the flex assertion that requires TP*EP > 1
- Add YaRN configuration tests for Qwen3MoEModelProvider and extend test_load_megatron_model_resets_defaults to verify the flex→alltoall fallback

GitHub Actions CI
/ok to test (external contributor, CI approval needed).

Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
Release Notes
New Features
Improvements
Tests