[model, training] MoE checkpoint export fails with YaRN RoPE and flex dispatcher #2640

@shanecmoran

Description

Exporting MoE checkpoints (e.g. Qwen3-30B-A3B) trained with YaRN RoPE scaling and the flex token dispatcher to HuggingFace format via load_megatron_model fails with two independent bugs.

Bug 1: YaRN RoPE fields missing from Qwen3MoEModelProvider

Qwen3MoEModelProvider has position_embedding_type: str = "rope" but no yarn_* dataclass fields. When a checkpoint's run_config.yaml contains YaRN parameters, OmegaConf silently drops them during instantiate() because the dataclass has no matching fields.

Ministral3ModelProvider already declares all yarn fields — the same fields need to be added to Qwen3MoEModelProvider.
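The silent-drop failure mode can be reproduced without OmegaConf at all. A minimal sketch (the dataclass below is a stand-in for the real provider, and the filtering step mimics the effect of structured-config instantiation, not OmegaConf's actual internals):

```python
from dataclasses import dataclass, fields

# Stand-in for the real provider: declares rope, but no yarn_* fields.
@dataclass
class ProviderStub:
    position_embedding_type: str = "rope"

# Values as they might appear in a checkpoint's run_config.yaml.
run_config = {
    "position_embedding_type": "yarn",
    "yarn_rotary_scaling_factor": 16.0,  # no matching field -> dropped
}

# Keep only keys the dataclass declares -- the same net effect the issue
# describes: unknown yarn_* keys never reach the instantiated provider.
declared = {f.name for f in fields(ProviderStub)}
kept = {k: v for k, v in run_config.items() if k in declared}
provider = ProviderStub(**kept)
```

The instantiation succeeds with no error, which is what makes the bug easy to miss: the exported model simply has wrong RoPE behavior.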

Current Qwen3MoEModelProvider (missing yarn fields):

# src/megatron/bridge/models/qwen/qwen_provider.py
@dataclass
class Qwen3MoEModelProvider(GPTModelProvider):
    ...
    position_embedding_type: str = "rope"
    # No yarn_* fields — OmegaConf silently drops them from run_config.yaml
    autocast_dtype: torch.dtype = torch.bfloat16
    ...

Existing Ministral3ModelProvider (has all yarn fields):

# src/megatron/bridge/models/ministral3/ministral3_provider.py
@dataclass
class Ministral3ModelProvider(GPTModelProvider):
    ...
    yarn_rotary_scaling_factor: float = 16.0
    yarn_original_max_position_embeddings: int = 16384
    yarn_beta_fast: float = 32.0
    yarn_beta_slow: float = 1.0
    yarn_correction_range_round_to_int: bool = False
    yarn_mscale: Optional[float] = 1.0
    yarn_mscale_all_dim: Optional[float] = 1.0
    ...
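A possible fix is to mirror those fields onto the Qwen3 provider. A minimal sketch (GPTModelProvider is stubbed out here to stay self-contained, and the defaults are copied from Ministral3ModelProvider; the correct Qwen3 defaults may differ):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class GPTModelProvider:  # stub standing in for the real base class
    position_embedding_type: str = "rope"

# Sketch: declare the same yarn_* fields Ministral3ModelProvider has,
# so they can be populated from run_config.yaml instead of being
# silently dropped. Defaults copied from Ministral3 -- verify for Qwen3.
@dataclass
class Qwen3MoEModelProvider(GPTModelProvider):
    yarn_rotary_scaling_factor: float = 16.0
    yarn_original_max_position_embeddings: int = 16384
    yarn_beta_fast: float = 32.0
    yarn_beta_slow: float = 1.0
    yarn_correction_range_round_to_int: bool = False
    yarn_mscale: Optional[float] = 1.0
    yarn_mscale_all_dim: Optional[float] = 1.0
```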

Bug 2: flex dispatcher assertion at TP=1, EP=1

load_megatron_model forces tensor_model_parallel_size and expert_model_parallel_size to 1 for single-GPU export, but does not adjust moe_token_dispatcher_type. The flex dispatcher in Megatron-Core (MoEFlexTokenDispatcher.__init__) has an explicit assertion:

# megatron/core/transformer/moe/token_dispatcher.py (Megatron-Core)
assert self.tp_size * self.ep_size > 1, "Flex token dispatcher requires TPxEP > 1"

Any checkpoint trained with moe_token_dispatcher_type="flex" — regardless of the original TP/EP values used during training — will hit this assertion when exported via load_megatron_model, because it always forces TP=1 and EP=1.

Current load_megatron_model (resets parallelism but not dispatcher):

# src/megatron/bridge/training/model_load_save.py
def load_megatron_model(...):
    model_cfg, mlm_args = load_model_config(checkpoint_path)
    model_cfg.tensor_model_parallel_size = 1       # forced to 1
    model_cfg.pipeline_model_parallel_size = 1
    ...
    model_cfg.expert_model_parallel_size = 1        # forced to 1
    model_cfg.expert_tensor_parallel_size = 1       # forced to 1
    ...
    model_cfg.hierarchical_context_parallel_sizes = None
    # moe_token_dispatcher_type is NOT reset — "flex" is carried through from
    # the training config, triggering the MCore assertion above
    if use_cpu_init:
        ...
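One way to fix Bug 2 is to downgrade the dispatcher alongside the parallelism reset. A hypothetical helper (the name and the "alltoall" fallback are assumptions; the rationale is that the dispatcher only affects runtime token routing, not the weight layout being exported):

```python
from types import SimpleNamespace

def downgrade_flex_dispatcher(model_cfg):
    """Hypothetical helper: once load_megatron_model forces TP=1 and EP=1,
    the flex dispatcher's `tp_size * ep_size > 1` assertion can no longer
    hold, so fall back to "alltoall" for single-GPU export.
    """
    if getattr(model_cfg, "moe_token_dispatcher_type", None) == "flex":
        model_cfg.moe_token_dispatcher_type = "alltoall"
    return model_cfg

# Example: a config as it looks after load_megatron_model's resets.
cfg = SimpleNamespace(
    tensor_model_parallel_size=1,
    expert_model_parallel_size=1,
    moe_token_dispatcher_type="flex",
)
downgrade_flex_dispatcher(cfg)
```

Calling a helper like this right after the parallelism overrides would keep the export path from ever constructing MoEFlexTokenDispatcher with TPxEP == 1.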

Steps to Reproduce

  1. Fine-tune Qwen3-30B-A3B with YaRN RoPE scaling and moe_token_dispatcher_type="flex" (any TP/EP combination)
  2. Attempt to export the checkpoint: load_megatron_model(checkpoint_path, return_state_dict=True, use_cpu_init=True)
  3. Bug 1: YaRN parameters are silently dropped, producing incorrect RoPE embeddings
  4. Bug 2: flex dispatcher assertion fails because load_megatron_model forces TP=1, EP=1 but leaves the dispatcher type unchanged

Expected Behavior

load_megatron_model should successfully export MoE checkpoints that were trained with YaRN RoPE and/or the flex dispatcher.

Environment

  • Megatron-Bridge: main branch
  • Model: Qwen3-30B-A3B (or any Qwen3 MoE variant with YaRN + flex)
