Description
Describe the bug
I trained a Nemotron3 model and tried to export the Megatron checkpoint following the instructions in the official documentation, but the export failed.
From the logs below, it seems that model_cfg.moe_token_dispatcher_type = "flex" (from DeepEP) is causing the issue.
Is there a way to override this setting during export and temporarily change flex to alltoall only for the export process?
[rank0]: During handling of the above exception, another exception occurred:
[rank0]: Traceback (most recent call last):
[rank0]: File "/home/jaemin/Megatron-Bridge/examples/conversion/convert_checkpoints.py", line 273, in <module>
[rank0]: sys.exit(main())
[rank0]: ^^^^^^
[rank0]: File "/home/jaemin/Megatron-Bridge/examples/conversion/convert_checkpoints.py", line 257, in main
[rank0]: export_megatron_to_hf(
[rank0]: File "/home/jaemin/Megatron-Bridge/examples/conversion/convert_checkpoints.py", line 183, in export_megatron_to_hf
[rank0]: bridge.export_ckpt(
[rank0]: File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/models/conversion/auto_bridge.py", line 806, in export_ckpt
[rank0]: megatron_model = self.load_megatron_model(megatron_path, wrap_with_ddp=False)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/models/conversion/auto_bridge.py", line 679, in load_megatron_model
[rank0]: model = load_megatron_model(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/training/model_load_save.py", line 400, in load_megatron_model
[rank0]: return build_and_load_model(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/training/model_load_save.py", line 338, in build_and_load_model
[rank0]: return _load_checkpoint()
[rank0]: ^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/training/model_load_save.py", line 316, in _load_checkpoint
[rank0]: model = _call_model_provider(model_cfg)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/training/model_load_save.py", line 284, in _call_model_provider
[rank0]: return model_cfg.provide_distributed_model(wrap_with_ddp=False, use_cpu_initialization=use_cpu_init)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/models/model_provider.py", line 196, in provide_distributed_model
[rank0]: model = get_model(
[rank0]: ^^^^^^^^^^
[rank0]: File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/models/model_provider.py", line 546, in get_model
[rank0]: model = _create_model(model_provider, model_type, pg_collection=pg_collection)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/models/model_provider.py", line 653, in _create_model
[rank0]: model = model_provider.provide(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jaemin/Megatron-Bridge/src/megatron/bridge/models/mamba/mamba_provider.py", line 168, in provide
[rank0]: return MCoreMambaModel(
[rank0]: ^^^^^^^^^^^^^^^^
[rank0]: File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/models/mamba/mamba_model.py", line 125, in __init__
[rank0]: self.decoder = build_module(
[rank0]: ^^^^^^^^^^^^^
[rank0]: File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/spec_utils.py", line 126, in build_module
[rank0]: raise type(e)(f"{str(e)} when instantiating {module.__name__}").with_traceback(
[rank0]: File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/spec_utils.py", line 119, in build_module
[rank0]: return module(
[rank0]: ^^^^^^^
[rank0]: File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/ssm/mamba_block.py", line 151, in __init__
[rank0]: layer = build_module(
[rank0]: ^^^^^^^^^^^^^
[rank0]: File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/spec_utils.py", line 126, in build_module
[rank0]: raise type(e)(f"{str(e)} when instantiating {module.__name__}").with_traceback(
[rank0]: File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/spec_utils.py", line 119, in build_module
[rank0]: return module(
[rank0]: ^^^^^^^
[rank0]: File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 1244, in __init__
[rank0]: super().__init__(*args, **kwargs)
[rank0]: File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 370, in __init__
[rank0]: self.mlp = build_module(submodules.mlp, config=self.config, **additional_mlp_kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/spec_utils.py", line 126, in build_module
[rank0]: raise type(e)(f"{str(e)} when instantiating {module.__name__}").with_traceback(
[rank0]: File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/spec_utils.py", line 119, in build_module
[rank0]: return module(
[rank0]: ^^^^^^^
[rank0]: File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/moe/moe_layer.py", line 211, in __init__
[rank0]: self.token_dispatcher = MoEFlexTokenDispatcher(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/jaemin/Megatron-Bridge/3rdparty/Megatron-LM/megatron/core/transformer/moe/token_dispatcher.py", line 1346, in __init__
[rank0]: assert self.tp_size * self.ep_size > 1, "Flex token dispatcher requires TPxEP > 1"
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: AssertionError: Flex token dispatcher requires TPxEP > 1 when instantiating MoELayer when instantiating MoETransformerLayer when instantiating MambaStack
Steps/Code to reproduce bug
commit & branch: 4a8eac1550798373bfcaa721cebf4e1fbaf49710 in r0.3.0
Megatron checkpoint trained with DeepEP and model_cfg.moe_token_dispatcher_type = "flex"
python examples/conversion/convert_checkpoints.py export \
--hf-model $HF_MODEL_ID \
--megatron-path /path/to/trained/megatron/ckpt \
--hf-path /path/to/output/hf/ckpt
A helpful guide on how to craft a minimal bug report: http://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports.
Expected behavior
The export should override moe_token_dispatcher_type, temporarily changing "flex" to "alltoall" for the export process only.
I temporarily added this change and it seems to work. However, I’m not sure if this could cause other issues.
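For reference, a minimal sketch of the kind of override I have in mind. The helper and the `DummyModelCfg` stand-in are hypothetical names for illustration; the real object would be whatever provider config `load_megatron_model` builds from the checkpoint, and DeepEP-specific fields may also need adjusting:

```python
from dataclasses import dataclass


@dataclass
class DummyModelCfg:
    # Stand-in for the real model provider config loaded from the
    # Megatron checkpoint; only the relevant field is modeled here.
    moe_token_dispatcher_type: str = "flex"


def patch_dispatcher_for_export(model_cfg):
    """Force the 'alltoall' dispatcher before model instantiation.

    MoEFlexTokenDispatcher asserts TP x EP > 1, which does not hold
    when the checkpoint is loaded for export, so fall back to
    'alltoall' only for the export path.
    """
    if getattr(model_cfg, "moe_token_dispatcher_type", None) == "flex":
        model_cfg.moe_token_dispatcher_type = "alltoall"
    return model_cfg


cfg = patch_dispatcher_for_export(DummyModelCfg())
```

This only mutates the in-memory config before `provide_distributed_model` runs; the checkpoint on disk is untouched.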
Additional context
If this approach looks reasonable, I’d be happy to open a PR.