
[model, training] fix: MoE checkpoint export with YaRN RoPE and flex dispatcher #2641

Open
shanecmoran wants to merge 1 commit into NVIDIA-NeMo:main from shanecmoran:fix/moe-checkpoint-export

Conversation


@shanecmoran shanecmoran commented Mar 4, 2026

What does this PR do ?

Fix two independent bugs that prevent exporting MoE checkpoints (e.g. Qwen3-30B-A3B) trained with YaRN RoPE scaling and the flex token dispatcher to HuggingFace format via load_megatron_model.

Changelog

  • Add 7 yarn_* dataclass fields to Qwen3MoEModelProvider (None defaults) so YaRN RoPE parameters from run_config.yaml are preserved during instantiate() — matches existing pattern in Ministral3ModelProvider
  • Fall back flex dispatcher to alltoall in load_megatron_model when parallelism is reset to TP=1/EP=1, avoiding the flex assertion that requires TP*EP > 1
  • Add tests for YaRN field defaults and custom values on Qwen3MoEModelProvider
  • Extend test_load_megatron_model_resets_defaults to verify flex→alltoall fallback
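A minimal sketch of what the seven new provider fields could look like. The field names come from this PR's changelog; the dataclass below is a simplified stand-in for illustration, not the actual Qwen3MoEModelProvider:

```python
from dataclasses import dataclass
from typing import Optional

# Simplified stand-in for Qwen3MoEModelProvider; only the seven new YaRN
# fields are shown. They default to None (False for the rounding flag, per
# the new tests) so that values read from run_config.yaml survive
# instantiate() instead of being dropped.
@dataclass
class YarnRopeFieldsSketch:
    yarn_rotary_scaling_factor: Optional[float] = None
    yarn_original_max_position_embeddings: Optional[int] = None
    yarn_beta_fast: Optional[float] = None
    yarn_beta_slow: Optional[float] = None
    yarn_mscale: Optional[float] = None
    yarn_mscale_all_dim: Optional[float] = None
    yarn_correction_range_round_to_int: bool = False

# Values from a checkpoint's run_config.yaml can now round-trip through
# the provider instead of being silently discarded.
cfg = YarnRopeFieldsSketch(yarn_rotary_scaling_factor=4.0)
```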

GitHub Actions CI

/ok to test — external contributor, CI approval needed.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • N/A — no new dependencies

Additional Information

Summary by CodeRabbit

Release Notes

  • New Features

    • Added YaRN RoPE (rotary position embedding) configuration options for Qwen3 MoE models, enabling fine-tuning of scaling factors, position embeddings, and correction parameters.
  • Improvements

    • Enhanced model loading to ensure compatibility with flexible dispatcher type configurations.
  • Tests

    • Added test coverage validating YaRN RoPE configuration fields and default behaviors.


copy-pr-bot bot commented Mar 4, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@shanecmoran shanecmoran marked this pull request as ready for review March 4, 2026 17:48

coderabbitai bot commented Mar 4, 2026

📝 Walkthrough


This PR adds YaRN RoPE configuration fields to Qwen3MoEModelProvider, implements logic to coerce moe_token_dispatcher_type from "flex" to "alltoall" during model loading, and adds corresponding unit tests to validate these configurations and their default values.

Changes

  • Qwen3MoE YaRN RoPE Configuration (src/megatron/bridge/models/qwen/qwen_provider.py):
    Added seven new optional configuration fields to support YaRN RoPE: yarn_rotary_scaling_factor, yarn_original_max_position_embeddings, yarn_beta_fast, yarn_beta_slow, yarn_mscale, yarn_mscale_all_dim, and yarn_correction_range_round_to_int.
  • Model Loading Behavior (src/megatron/bridge/training/model_load_save.py):
    Added logic to coerce moe_token_dispatcher_type from "flex" to "alltoall" when loading the model configuration.
  • Qwen3MoE Unit Tests (tests/unit_tests/models/qwen/test_qwen3_moe_model_provider.py):
    Added two new tests: test_qwen3_moe_model_provider_with_yarn_rope validates YaRN RoPE fields and their values; test_qwen3_moe_model_provider_yarn_defaults_none validates the default None/False values for the YaRN fields.
  • Model Load/Save Unit Tests (tests/unit_tests/training/test_model_load_save.py):
    Updated test_load_megatron_model_resets_defaults to verify that a moe_token_dispatcher_type of "flex" is coerced to "alltoall" during the configuration reset.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~15 minutes


Suggested reviewers

  • erhoo82
  • yaoyu-33
  • malay-nagda
🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Test Results For Major Changes (⚠️ Warning): The PR adds 40 lines of unit tests, but CI has not been executed (external contributor, /ok to test approval needed), and review comments identify implementation issues with the flex dispatcher coercion ordering and an incorrect YaRN test configuration. Resolution: execute the full test suite and document the results in the PR description, address the identified implementation issues, and verify the fixes work correctly before merging.

✅ Passed checks (3 passed)

  • Description Check (✅ Passed): Check skipped; CodeRabbit's high-level summary is enabled.
  • Title check (✅ Passed): The title clearly and concisely summarizes the main changes: fixing MoE checkpoint export by adding YaRN RoPE support and handling the flex dispatcher.
  • Docstring Coverage (✅ Passed): Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@shanecmoran shanecmoran force-pushed the fix/moe-checkpoint-export branch from 20236c1 to 2aa6bc5 Compare March 4, 2026 19:00

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🧹 Nitpick comments (2)
tests/unit_tests/models/qwen/test_qwen3_moe_model_provider.py (1)

158-197: Tag the new tests with a pytest category marker.

Please add a unit-test marker (for example @pytest.mark.unit) on the newly added test methods.

As per coding guidelines, "Use 'pytest.mark' to categorize tests (unit, integration, system)".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/models/qwen/test_qwen3_moe_model_provider.py` around lines
158 - 197, Add the pytest unit marker to both new tests by decorating
test_qwen3_moe_model_provider_with_yarn_rope and
test_qwen3_moe_model_provider_yarn_defaults_none with `@pytest.mark.unit`; ensure
pytest is imported at the top of the file (add "import pytest" if missing) so
the markers resolve correctly.
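Applied to the file, the suggested marker would look roughly like this (the test body is a placeholder; only the import and the decorator matter):

```python
import pytest

# Suggested change from the review: categorize the new tests with a unit
# marker so they can be selected with `pytest -m unit`. The body below is a
# placeholder; the real assertions live in the PR's test file.
@pytest.mark.unit
def test_qwen3_moe_model_provider_yarn_defaults_none():
    assert True  # placeholder body
```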
tests/unit_tests/training/test_model_load_save.py (1)

519-552: Add an override-path regression test for flex dispatcher behavior.

This test only covers the default reset path. Please add a companion case where mp_overrides makes TP×EP > 1 and verify moe_token_dispatcher_type behavior stays consistent with intended semantics.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/training/test_model_load_save.py` around lines 519 - 552,
Add a companion unit test to exercise the override path of load_megatron_model:
copy the existing test_load_megatron_model_resets_defaults structure but pass
mp_overrides that set tensor_model_parallel_size and expert_model_parallel_size
such that TP×EP > 1 (e.g., tensor_model_parallel_size=2 and
expert_model_parallel_size=2) and set the loaded cfg.moe_token_dispatcher_type
to "flex"; call load_megatron_model with those mp_overrides (mocking
load_model_config to return the cfg and mock_build_and_load to return a
sentinel) and assert the function still returns the sentinel and that
cfg.moe_token_dispatcher_type is set/normalized to the expected final value
according to the intended semantics (e.g., remains "flex" or becomes "alltoall"
per spec) while also verifying other parallel-size resets/overrides were applied
as appropriate; reference load_megatron_model, mp_overrides,
mock_load_model_config, and mock_build_and_load to locate where to wire the
test.
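The intended ordering (apply mp_overrides first, then fall back only when the final TP*EP is still <= 1) can be illustrated with a self-contained sketch. Here apply_overrides_then_coerce and the SimpleNamespace config are illustrative stand-ins, not the real load_megatron_model API:

```python
from types import SimpleNamespace

def apply_overrides_then_coerce(cfg, mp_overrides=None):
    """Illustrative stand-in for the relevant slice of load_megatron_model."""
    # 1) Apply model-parallel overrides first.
    for key, value in (mp_overrides or {}).items():
        if hasattr(cfg, key) and value is not None:
            setattr(cfg, key, value)
    # 2) Only then fall back flex -> alltoall, and only when TP*EP <= 1.
    tp = getattr(cfg, "tensor_model_parallel_size", 1)
    ep = getattr(cfg, "expert_model_parallel_size", 1)
    if getattr(cfg, "moe_token_dispatcher_type", None) == "flex" and tp * ep <= 1:
        cfg.moe_token_dispatcher_type = "alltoall"
    return cfg

def make_cfg():
    return SimpleNamespace(
        tensor_model_parallel_size=1,
        expert_model_parallel_size=1,
        moe_token_dispatcher_type="flex",
    )

# Default reset path: TP*EP == 1, so "flex" is downgraded to "alltoall".
reset_cfg = apply_overrides_then_coerce(make_cfg())

# Override path: overrides restore TP*EP > 1, so "flex" is preserved.
override_cfg = apply_overrides_then_coerce(
    make_cfg(),
    {"tensor_model_parallel_size": 2, "expert_model_parallel_size": 2},
)
```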

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: e60e10cc-089f-4459-a9d0-5b7f22eb0e75

📥 Commits

Reviewing files that changed from the base of the PR and between dd0a1d9 and 2aa6bc5.

📒 Files selected for processing (4)
  • src/megatron/bridge/models/qwen/qwen_provider.py
  • src/megatron/bridge/training/model_load_save.py
  • tests/unit_tests/models/qwen/test_qwen3_moe_model_provider.py
  • tests/unit_tests/training/test_model_load_save.py

Comment on lines +389 to +390
if getattr(model_cfg, "moe_token_dispatcher_type", None) == "flex":
    model_cfg.moe_token_dispatcher_type = "alltoall"

⚠️ Potential issue | 🟠 Major

Conditionally fallback flex only after overrides are applied.

Line 389 currently coerces "flex" before mp_overrides (Lines 396-399). This can incorrectly downgrade valid override scenarios where TP/EP are set back above 1.

Proposed fix
-    if getattr(model_cfg, "moe_token_dispatcher_type", None) == "flex":
-        model_cfg.moe_token_dispatcher_type = "alltoall"
     if use_cpu_init:
         model_cfg.fp8 = None
         model_cfg.fp8_param = False

     # Apply model-parallel overrides if provided
     if mp_overrides:
         for key, value in mp_overrides.items():
             if hasattr(model_cfg, key) and value is not None:
                 setattr(model_cfg, key, value)
+
+    if (
+        getattr(model_cfg, "moe_token_dispatcher_type", None) == "flex"
+        and getattr(model_cfg, "tensor_model_parallel_size", 1) * getattr(model_cfg, "expert_model_parallel_size", 1)
+        <= 1
+    ):
+        logger.warning("Falling back moe_token_dispatcher_type from 'flex' to 'alltoall' for TP*EP <= 1.")
+        model_cfg.moe_token_dispatcher_type = "alltoall"

Also applies to: 396-399

🧰 Tools
🪛 Ruff (0.15.2)

[error] 390-390: Possible hardcoded password assigned to: "moe_token_dispatcher_type"

(S105)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/training/model_load_save.py` around lines 389 - 390, The
current unconditional coercion of model_cfg.moe_token_dispatcher_type from
"flex" to "alltoall" runs before mp_overrides are applied and can incorrectly
override valid user overrides; move the conditional replacement so it happens
after mp_overrides processing (the block that applies tensor/row/col/ep/tp
overrides) so that if overrides set TP/EP > 1 the dispatcher remains as
specified; update both occurrences that coerce "flex" (the existing standalone
if using getattr and the similar logic around the mp_overrides application) to
run only after the override application code that modifies model_cfg.{tp,ep,...}
and ensure the check uses the final model_cfg values.

Comment on lines +158 to +180
def test_qwen3_moe_model_provider_with_yarn_rope(self):
    """Test Qwen3MoEModelProvider with YaRN RoPE configuration."""
    provider = Qwen3MoEModelProvider(
        num_layers=32,
        hidden_size=4096,
        num_attention_heads=32,
        position_embedding_type="rope",
        yarn_rotary_scaling_factor=4.0,
        yarn_original_max_position_embeddings=32768,
        yarn_beta_fast=32.0,
        yarn_beta_slow=1.0,
        yarn_mscale=1.0,
        yarn_mscale_all_dim=1.0,
        yarn_correction_range_round_to_int=True,
    )

    assert provider.yarn_rotary_scaling_factor == 4.0
    assert provider.yarn_original_max_position_embeddings == 32768
    assert provider.yarn_beta_fast == 32.0
    assert provider.yarn_beta_slow == 1.0
    assert provider.yarn_mscale == 1.0
    assert provider.yarn_mscale_all_dim == 1.0
    assert provider.yarn_correction_range_round_to_int is True

⚠️ Potential issue | 🟡 Minor

Use position_embedding_type="yarn" in the YaRN test case.

Line 164 sets "rope", so this test does not actually exercise the YaRN mode it is named for.

Proposed fix
-            position_embedding_type="rope",
+            position_embedding_type="yarn",
             yarn_rotary_scaling_factor=4.0,
@@
         assert provider.yarn_rotary_scaling_factor == 4.0
+        assert provider.position_embedding_type == "yarn"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/models/qwen/test_qwen3_moe_model_provider.py` around lines
158 - 180, The test test_qwen3_moe_model_provider_with_yarn_rope is meant to
exercise YaRN mode but currently constructs Qwen3MoEModelProvider with
position_embedding_type="rope"; change the instantiation to use
position_embedding_type="yarn" so the Qwen3MoEModelProvider(...) call actually
configures YaRN-specific fields (yarn_rotary_scaling_factor,
yarn_original_max_position_embeddings, etc.) and the subsequent asserts validate
the YaRN behavior.

@shanecmoran shanecmoran force-pushed the fix/moe-checkpoint-export branch from 2aa6bc5 to c64232f Compare March 4, 2026 20:15
…dispatcher

Add YaRN RoPE configuration fields to Qwen3MoEModelProvider so they
are preserved in run_config.yaml during checkpoint saving.

Move the flex -> alltoall dispatcher coercion in load_megatron_model
to run after mp_overrides are applied, and only when TP*EP <= 1.
This avoids incorrectly downgrading the dispatcher when overrides
restore multi-GPU parallelism.

Signed-off-by: Shane Moran <shane.moran@shopify.com>
@shanecmoran shanecmoran force-pushed the fix/moe-checkpoint-export branch from c64232f to eb696b8 Compare March 5, 2026 11:57