Exporting Megatron checkpoint to HF checkpoint fails #2655

@jaeminh

Description

Describe the bug

I trained a Nemotron3 model and tried to export the Megatron checkpoint following the instructions in the official documentation, but the export failed.
From the logs below, it seems that model_cfg.moe_token_dispatcher_type = "flex" (from DeepEP) is causing the issue.
Is there a way to override this setting during export and temporarily change flex to alltoall only for the export process?

[rank0]: During handling of the above exception, another exception occurred:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/jaemin/Megatron-Bridge/examples/conversion/convert_checkpoints.py", line 273, in <module>
[rank0]:     sys.exit(main())
[rank0]:              ^^^^^^
[rank0]:   File "/home/jaemin/Megatron-Bridge/examples/conversion/convert_checkpoints.py", line 257, in main
[rank0]:     export_megatron_to_hf(
[rank0]:   File "/home/jaemin/Megatron-Bridge/examples/conversion/convert_checkpoints.py", line 183, in export_megatron_to_hf
[rank0]:     bridge.export_ckpt(
[rank0]:   File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/models/conversion/auto_bridge.py", line 806, in export_ckpt
[rank0]:     megatron_model = self.load_megatron_model(megatron_path, wrap_with_ddp=False)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/models/conversion/auto_bridge.py", line 679, in load_megatron_model
[rank0]:     model = load_megatron_model(
[rank0]:             ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/training/model_load_save.py", line 400, in load_megatron_model
[rank0]:     return build_and_load_model(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/training/model_load_save.py", line 338, in build_and_load_model
[rank0]:     return _load_checkpoint()
[rank0]:            ^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/training/model_load_save.py", line 316, in _load_checkpoint
[rank0]:     model = _call_model_provider(model_cfg)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/training/model_load_save.py", line 284, in _call_model_provider
[rank0]:     return model_cfg.provide_distributed_model(wrap_with_ddp=False, use_cpu_initialization=use_cpu_init)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/models/model_provider.py", line 196, in provide_distributed_model
[rank0]:     model = get_model(
[rank0]:             ^^^^^^^^^^
[rank0]:   File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/models/model_provider.py", line 546, in get_model
[rank0]:     model = _create_model(model_provider, model_type, pg_collection=pg_collection)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/models/model_provider.py", line 653, in _create_model
[rank0]:     model = model_provider.provide(
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/models/mamba/mamba_provider.py", line 168, in provide
[rank0]:     return MCoreMambaModel(
[rank0]:            ^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/models/mamba/mamba_model.py", line 125, in __init__
[rank0]:     self.decoder = build_module(
[rank0]:                    ^^^^^^^^^^^^^
[rank0]:   File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/spec_utils.py", line 126, in build_module
[rank0]:     raise type(e)(f"{str(e)} when instantiating {module.__name__}").with_traceback(
[rank0]:   File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/spec_utils.py", line 119, in build_module
[rank0]:     return module(
[rank0]:            ^^^^^^^
[rank0]:   File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/ssm/mamba_block.py", line 151, in __init__
[rank0]:     layer = build_module(
[rank0]:             ^^^^^^^^^^^^^
[rank0]:   File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/spec_utils.py", line 126, in build_module
[rank0]:     raise type(e)(f"{str(e)} when instantiating {module.__name__}").with_traceback(
[rank0]:   File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/spec_utils.py", line 119, in build_module
[rank0]:     return module(
[rank0]:            ^^^^^^^
[rank0]:   File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 1244, in __init__
[rank0]:     super().__init__(*args, **kwargs)
[rank0]:   File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 370, in __init__
[rank0]:     self.mlp = build_module(submodules.mlp, config=self.config, **additional_mlp_kwargs)
[rank0]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/spec_utils.py", line 126, in build_module
[rank0]:     raise type(e)(f"{str(e)} when instantiating {module.__name__}").with_traceback(
[rank0]:   File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/spec_utils.py", line 119, in build_module
[rank0]:     return module(
[rank0]:            ^^^^^^^
[rank0]:   File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/moe/moe_layer.py", line 211, in __init__
[rank0]:     self.token_dispatcher = MoEFlexTokenDispatcher(
[rank0]:                             ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/moe/token_dispatcher.py", line 1346, in __init__
[rank0]:     assert self.tp_size * self.ep_size > 1, "Flex token dispatcher requires TPxEP > 1"
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: AssertionError: Flex token dispatcher requires TPxEP > 1 when instantiating MoELayer when instantiating MoETransformerLayer when instantiating MambaStack

Steps/Code to reproduce bug

commit & branch: 4a8eac1550798373bfcaa721cebf4e1fbaf49710 in r0.3.0
Megatron checkpoint trained with DeepEP and model_cfg.moe_token_dispatcher_type = "flex"

python examples/conversion/convert_checkpoints.py export  \
--hf-model $HF_MODEL_ID  \
--megatron-path /path/to/trained/megatron/ckpt \
--hf-path /path/to/output/hf/ckpt

A helpful guide on how to craft a minimal bug report: http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports.

Expected behavior

Export should succeed, ideally by overriding moe_token_dispatcher_type during export and temporarily switching flex to alltoall only for the export process.
I temporarily added this change locally and the export completed. However, I'm not sure whether forcing alltoall could cause other issues.
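The workaround can be sketched as follows. Note this is an illustrative stand-in, not the actual Megatron-Bridge code: `ModelCfg` and `with_export_safe_dispatcher` are hypothetical names; the real change would override `moe_token_dispatcher_type` on the provider config before `load_megatron_model` instantiates the model, since the flex dispatcher asserts `tp_size * ep_size > 1` while export runs without TP/EP parallelism.

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class ModelCfg:
    """Illustrative stand-in for the Megatron model provider config."""
    moe_token_dispatcher_type: str = "flex"

def with_export_safe_dispatcher(cfg: ModelCfg) -> ModelCfg:
    # "flex" (DeepEP) requires TP*EP > 1, but checkpoint export typically
    # rebuilds the model without tensor/expert parallelism, so swap to
    # "alltoall" only for the duration of the export.
    if cfg.moe_token_dispatcher_type == "flex":
        return replace(cfg, moe_token_dispatcher_type="alltoall")
    return cfg
```

The key point is that the override is scoped to the export path only; the on-disk checkpoint and its training config are left untouched.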

Additional context

If this approach looks reasonable, I’d be happy to open a PR.
