Commit 54ba962

Merge branch 'main' into release/1.0

2 parents e4152f4 + c94a2ad

10 files changed: 507 additions, 32 deletions

README.md

Lines changed: 6 additions & 6 deletions

@@ -7,11 +7,11 @@
 <p> -->

 <p align="center">
-    <b>Providing Megatron-Core model definitions for state-of-the-art large language models</b>
+    <b>Providing Megatron-Core model definitions for state-of-the-art large models</b>
 </p>

 <p align="center">
-<a href="https://modelscope.cn/home">ModelScope Community Website</a>
+<a href="https://modelscope.cn">ModelScope</a>
 <br>
 <a href="README_zh.md">中文</a> &nbsp | &nbsp English &nbsp
 </p>
@@ -21,17 +21,17 @@
 <img src="https://img.shields.io/badge/python-3.11-5be.svg">
 <img src="https://img.shields.io/badge/pytorch-%E2%89%A52.0-orange.svg">
 <a href="https://github.com/NVIDIA/Megatron-LM/"><img src="https://img.shields.io/badge/megatron--core-%E2%89%A50.12-76B900.svg"></a>
-<a href="https://mcore-bridge.readthedocs.io/en/latest/"><img src="https://img.shields.io/badge/docs-latest-blue.svg"></a>
+<!-- <a href="https://mcore-bridge.readthedocs.io/en/latest/"><img src="https://img.shields.io/badge/docs-latest-blue.svg"></a> -->
 <a href="https://pypi.org/project/mcore-bridge/"><img src="https://badge.fury.io/py/mcore-bridge.svg"></a>
 <a href="https://github.com/modelscope/mcore-bridge/blob/main/LICENSE"><img src="https://img.shields.io/github/license/modelscope/mcore-bridge"></a>
 <a href="https://pepy.tech/project/mcore-bridge"><img src="https://pepy.tech/badge/mcore-bridge"></a>
 <a href="https://github.com/modelscope/mcore-bridge/pulls"><img src="https://img.shields.io/badge/PR-welcome-55EB99.svg"></a>
 </p>


-<p align="center">
+<!-- <p align="center">
 <a href="https://mcore-bridge.readthedocs.io/en/latest/">English Documentation</a> &nbsp | &nbsp <a href="https://mcore-bridge.readthedocs.io/zh-cn/latest/">中文文档</a> &nbsp
-</p>
+</p> -->

 ## 📖 Table of Contents
 - [Groups](#-Groups)
@@ -54,7 +54,7 @@ You can contact us and communicate with us by adding our group:
 ## 📝 Introduction

 ## 🎉 News
-- 🎉 2025.04.01: MCore-Bridge is released! Providing Megatron-Core model definitions for state-of-the-art large language models and making Megatron training as simple as Transformers.
+- 🎉 2026.03.30: MCore-Bridge is released! Providing Megatron-Core model definitions for state-of-the-art large models and making Megatron training as simple as Transformers.

 ## 🛠️ Installation
 To install using pip:

README_zh.md

Lines changed: 5 additions & 5 deletions

@@ -7,7 +7,7 @@
 <p> -->

 <p align="center">
-    <b>为最先进的大语言模型提供 Megatron-Core 模型定义</b>
+    <b>为最先进的大模型提供 Megatron-Core 模型定义</b>
 </p>

 <p align="center">
@@ -21,17 +21,17 @@
 <img src="https://img.shields.io/badge/python-3.11-5be.svg">
 <img src="https://img.shields.io/badge/pytorch-%E2%89%A52.0-orange.svg">
 <a href="https://github.com/NVIDIA/Megatron-LM/"><img src="https://img.shields.io/badge/megatron--core-%E2%89%A50.12-76B900.svg"></a>
-<a href="https://mcore-bridge.readthedocs.io/en/latest/"><img src="https://img.shields.io/badge/docs-latest-blue.svg"></a>
+<!-- <a href="https://mcore-bridge.readthedocs.io/en/latest/"><img src="https://img.shields.io/badge/docs-latest-blue.svg"></a> -->
 <a href="https://pypi.org/project/mcore-bridge/"><img src="https://badge.fury.io/py/mcore-bridge.svg"></a>
 <a href="https://github.com/modelscope/mcore-bridge/blob/main/LICENSE"><img src="https://img.shields.io/github/license/modelscope/mcore-bridge"></a>
 <a href="https://pepy.tech/project/mcore-bridge"><img src="https://pepy.tech/badge/mcore-bridge"></a>
 <a href="https://github.com/modelscope/mcore-bridge/pulls"><img src="https://img.shields.io/badge/PR-welcome-55EB99.svg"></a>
 </p>


-<p align="center">
+<!-- <p align="center">
 <a href="https://mcore-bridge.readthedocs.io/en/latest/">English Documentation</a> &nbsp | &nbsp <a href="https://mcore-bridge.readthedocs.io/zh-cn/latest/">中文文档</a> &nbsp
-</p>
+</p> -->

 ## 📖 目录
 - [用户群](#-用户群)
@@ -53,7 +53,7 @@
 ## 📝 简介

 ## 🎉 新闻
-- 🎉 2026.04.01: MCore-Bridge 正式发布!为最先进的大语言模型提供 Megatron-Core 模型定义,让 Megatron 训练像 Transformers 一样简单。
+- 🎉 2026.03.30: MCore-Bridge 正式发布!为最先进的大模型提供 Megatron-Core 模型定义,让 Megatron 训练像 Transformers 一样简单。

 ## 🛠️ 安装
 使用pip进行安装:

src/mcore_bridge/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -12,15 +12,15 @@
     from .config import ModelConfig, hf_to_mcore_config
     from .model import get_mcore_model
     from .tuners import LoraParallelLinear
-    from .utils import get_logger, set_random_seed
+    from .utils import get_logger, set_random_seed, split_cp_inputs, unwrap_model
     from .version import __release_datetime__, __version__
 else:
     _import_structure = {
         'bridge': ['GPTBridge'],
         'config': ['ModelConfig', 'hf_to_mcore_config'],
         'model': ['get_mcore_model'],
         'tuners': ['LoraParallelLinear'],
-        'utils': ['get_logger', 'set_random_seed'],
+        'utils': ['get_logger', 'set_random_seed', 'split_cp_inputs', 'unwrap_model'],
         'version': ['__release_datetime__', '__version__'],
     }
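The `_import_structure` dict above drives a lazy-import scheme: names listed under each submodule are only imported when first accessed, instead of at package import time. A minimal sketch of the idiom, using a stdlib module so it runs anywhere (the real package likely delegates to a transformers-style `_LazyModule`; this stand-in is illustrative):

```python
import importlib

# Hypothetical stand-in for the package's submodule map; stdlib 'math'
# is used here so the sketch is self-contained.
_import_structure = {'math': ['sqrt', 'floor']}


def lazy_getattr(name):
    """Resolve `name` via the structure map on first access, mimicking a
    module-level __getattr__ (PEP 562)."""
    for module_name, exported in _import_structure.items():
        if name in exported:
            return getattr(importlib.import_module(module_name), name)
    raise AttributeError(name)
```

In a real `__init__.py` this lookup backs the module-level `__getattr__`, so `from mcore_bridge import GPTBridge` defers the heavy Megatron imports until the name is actually used.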

src/mcore_bridge/bridge/gpt_bridge.py

Lines changed: 86 additions & 13 deletions

@@ -5,13 +5,15 @@
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
-import transformers
+from contextlib import contextmanager
 from megatron.core import mpu
 from packaging import version
 from peft import PeftModel
 from peft.utils import ModulesToSaveWrapper
 from tqdm import tqdm
-from typing import List, Optional, Union
+from transformers import PreTrainedModel
+from transformers.utils import ContextManagers
+from typing import Callable, List, Optional, Union

 from mcore_bridge.tuners import LoraParallelLinear
 from mcore_bridge.utils import (MxFp4Dequantizer, SafetensorLazyLoader, StreamingSafetensorSaver, deep_getattr,
@@ -66,7 +68,6 @@ def __init__(self, config):
         self.pp_group = mpu.get_pipeline_model_parallel_group()
         self.etp_group = mpu.get_expert_tensor_parallel_group()
         self.ep_group = mpu.get_expert_model_parallel_group()
-        self.is_transformers_5 = version.parse(transformers.__version__) >= version.parse('5.0.0.dev')
         self.tp_rank = mpu.get_tensor_model_parallel_rank()
         self.pp_rank = mpu.get_pipeline_model_parallel_rank()
         self.etp_rank = mpu.get_expert_tensor_parallel_rank()
@@ -1615,7 +1616,14 @@ def _convert_mtp_layer(self, lm_model, hf_state_dict, hf_prefix: str, layer_idx:
         hf_state_dict.update(origin_hf_state_dict)
         return hf_state_dict

-    def load_weights(self, mg_models, hf_model_dir: str, peft_format: bool = False, adapter_name: str = 'default'):
+    def load_weights(
+        self,
+        mg_models,
+        hf_model_dir: str,
+        peft_format: bool = False,
+        adapter_name: str = 'default',
+        converter: Optional[Callable] = None,
+    ):
         """Load weights from safetensors (HuggingFace) format into Megatron model.

         Args:
@@ -1624,24 +1632,38 @@ def load_weights(self, mg_models, hf_model_dir: str, peft_format: bool = False,
             peft_format: Whether the weights are in PEFT (LoRA, etc.) format. Defaults to False.
                 If True, loads LoRA delta weights. If False, loads the full model weights.
             adapter_name: Name of the adapter for PEFT models. Defaults to 'default'.
+            converter: Used to perform key-value conversion on the newly loaded state_dict.
         """
         self._peft_format = peft_format
         self._adapter_name = adapter_name
         mg_models = unwrap_model(mg_models)
         self._disable_tqdm = False
         with torch.no_grad(), SafetensorLazyLoader(hf_model_dir, peft_format=peft_format) as loader:
             state_dict = loader.get_state_dict()
+            if converter:
+                new_state_dict = {}
+                for k, v in state_dict.items():
+                    kv = converter(k, v)
+                    if kv is None:
+                        continue
+                    k, v = kv
+                    new_state_dict[k] = v
+                state_dict = new_state_dict
             hf_prefix = 'base_model.model.' if peft_format else ''
             for mg_model in mg_models:
                 list(self._convert([mg_model], state_dict, hf_prefix, True, 'Loading: '))

-    def export_weights(self,
-                       mg_models,
-                       target_device=None,
-                       only_master_rank: bool = False,
-                       peft_format: bool = False,
-                       tqdm_desc: str = 'Exporting: ',
-                       disable_tqdm: bool = True):
+    def export_weights(
+        self,
+        mg_models,
+        target_device=None,
+        only_master_rank: bool = False,
+        peft_format: bool = False,
+        adapter_name: str = 'default',
+        converter: Optional[Callable] = None,
+        tqdm_desc: str = 'Exporting: ',
+        disable_tqdm: bool = True,
+    ):
         """Export Megatron model weights to safetensors (HuggingFace) format as a generator.

         This method yields weight tensors one by one for streaming save operations or RL weight synchronization,
@@ -1654,6 +1676,8 @@ def export_weights(self,
             peft_format: Whether to export in PEFT (LoRA, etc.) format. Defaults to False.
                 - If True, exports only LoRA delta weights. If False, exports the complete model weights
                   (e.g., after merge-lora or full-parameter fine-tuning).
+            adapter_name: Name of the adapter for PEFT models. Defaults to 'default'.
+            converter: Used to perform key-value conversion on the newly exported state_dict.
             tqdm_desc: Description text for the progress bar. Defaults to 'Exporting: '.
             disable_tqdm: Whether to disable the tqdm progress bar. Defaults to True.
@@ -1663,8 +1687,8 @@ def export_weights(self,
         self._target_device = target_device
         self._only_master_rank = only_master_rank
         self._peft_format = peft_format
+        self._adapter_name = adapter_name
         self._disable_tqdm = disable_tqdm
-        self._adapter_name = 'default'
         self._peft_target_modules = set()
         self._peft_modules_to_save = set()
         hf_prefix = 'base_model.model.' if peft_format else ''
@@ -1674,13 +1698,21 @@ def export_weights(self,
             mg_models[i] = mg_model.model
         self.config = mg_models[0].config
         with torch.no_grad():
-            yield from self._convert(mg_models, {}, hf_prefix, False, tqdm_desc=tqdm_desc)
+            for k, v in self._convert(mg_models, {}, hf_prefix, False, tqdm_desc=tqdm_desc):
+                if converter:
+                    kv = converter(k, v)
+                    if kv is None:
+                        continue
+                    k, v = kv
+                yield k, v

     def save_weights(
         self,
         mg_models,
         output_dir: str,
         peft_format: bool = False,
+        adapter_name: str = 'default',
+        converter: Optional[Callable] = None,
         max_shard_size: str = '5GB',
     ) -> None:
         """Save Megatron model checkpoint in safetensors (HuggingFace) format.
@@ -1695,6 +1727,8 @@ def save_weights(
             peft_format: Whether to save in PEFT (LoRA, etc.) format. Defaults to False.
                 If True, saves LoRA delta weights. If False, saves the complete model weights
                 (e.g., after merge-lora or full-parameter fine-tuning).
+            adapter_name: Name of the adapter for PEFT models. Defaults to 'default'.
+            converter: Used to perform key-value conversion on the newly exported state_dict.
             max_shard_size: Maximum size of a single storage file, default is '5GB'.
         """
         gc_collect()
@@ -1705,12 +1739,51 @@ def save_weights(
                 target_device='cpu',
                 only_master_rank=True,
                 peft_format=peft_format,
+                adapter_name=adapter_name,
+                converter=converter,
                 tqdm_desc='Saving: ',
                 disable_tqdm=False):
             saver.add_tensor(k, v)
         saver.finalize()
         dist.barrier()  # Ensure all weights are saved completely

+    @contextmanager
+    def _patch_hf_initialize_weight(self):
+
+        _origin_initialize_weight = PreTrainedModel._initialize_weights
+
+        def _initialize_weight(self, *args, **kwargs):
+            return
+
+        PreTrainedModel._initialize_weights = _initialize_weight
+        try:
+            yield
+        finally:
+            PreTrainedModel._initialize_weights = _origin_initialize_weight
+
+    @contextmanager
+    def _patch_device_meta(self, model_cls):
+        __origin_init__ = model_cls.__init__
+
+        def __init__(self, *args, **kwargs):
+            with torch.device('meta'):
+                __origin_init__(self, *args, **kwargs)
+
+        model_cls.__init__ = __init__
+
+        try:
+            yield
+        finally:
+            model_cls.__init__ = __origin_init__
+
+    def _get_meta_model_context(self, ignore_init_model_cls=None):
+        ignore_init_model_cls = ignore_init_model_cls or []
+        if not isinstance(ignore_init_model_cls, list):
+            ignore_init_model_cls = [ignore_init_model_cls]
+        context_list = [self._patch_device_meta(model_cls) for model_cls in ignore_init_model_cls]
+        context_list.append(self._patch_hf_initialize_weight())
+        return ContextManagers(context_list)
+

 class MultimodalGPTBridge(GPTBridge):
     hf_layers_prefix = 'model.language_model.layers'
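The new `converter` hook threaded through `load_weights`, `export_weights`, and `save_weights` receives each `(key, value)` pair and may rename it, transform the value, or drop the entry by returning `None`. A hedged sketch of what such a callback might look like (the buffer name and legacy prefix here are illustrative, not from the repo; in real use `value` is a `torch.Tensor`):

```python
from typing import Any, Optional, Tuple


def example_converter(key: str, value: Any) -> Optional[Tuple[str, Any]]:
    """Illustrative converter: drop rotary-frequency buffers and rename
    a hypothetical legacy key prefix."""
    if key.endswith('rotary_emb.inv_freq'):
        return None  # returning None skips this entry entirely
    # returning a (new_key, new_value) tuple rewrites the entry
    return key.replace('transformer.', 'model.'), value
```

It would then be passed as, e.g., `bridge.load_weights(mg_models, hf_model_dir, converter=example_converter)`.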

src/mcore_bridge/config/model_config.py

Lines changed: 1 addition & 0 deletions

@@ -285,6 +285,7 @@ def __post_init__(self):

         if self.add_bias_linear:
             self.add_qkv_bias = True
+        self.batch_p2p_comm = not self.overlap_p2p_comm
         if self.swiglu:
             self.activation_func = F.silu
             self.gated_linear_unit = True
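The added line derives `batch_p2p_comm` from `overlap_p2p_comm`: in Megatron, batched point-to-point communication and overlapping P2P with compute are mutually exclusive, so enabling overlap disables batching. A minimal sketch of that `__post_init__` derivation (the field names mirror the diff; the class itself is illustrative):

```python
from dataclasses import dataclass


@dataclass
class P2PCommConfig:
    """Illustrative subset of the model config's pipeline-parallel flags."""
    overlap_p2p_comm: bool = False
    batch_p2p_comm: bool = True

    def __post_init__(self):
        # Overlapping P2P with compute requires issuing sends/recvs
        # individually, so batching is on only when overlap is off.
        self.batch_p2p_comm = not self.overlap_p2p_comm
```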

src/mcore_bridge/model/mm_gpts/internvl.py

Lines changed: 21 additions & 4 deletions

@@ -1,7 +1,8 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
+import importlib
 import torch
 from torch import nn
-from transformers import AutoModel, PretrainedConfig
+from transformers import AutoModel, AutoTokenizer, PretrainedConfig
 from transformers.dynamic_module_utils import get_class_from_dynamic_module

 from mcore_bridge.bridge import GPTBridge, MultimodalGPTBridge
@@ -18,6 +19,23 @@ class InternvlBridge(GPTBridge):
     hf_lm_head_key = 'language_model.lm_head.weight'
     hf_score_key = 'language_model.score.weight'

+    def get_hf_meta_model(self):
+        model_cls = []
+        class_names = ['Qwen2ForCausalLM', 'Qwen3ForCausalLM', 'Qwen3MoeForCausalLM', 'GptOssForCausalLM']
+        module = importlib.import_module('transformers')
+        for cls_name in class_names:
+            try:
+                model_cls.append(getattr(module, cls_name))
+            except (ImportError, AttributeError):
+                pass
+        contexts = self._get_meta_model_context(model_cls)
+        hf_config = self.config.hf_config
+        model_cls = get_class_from_dynamic_module('modeling_internvl_chat.InternVLChatModel', hf_config.name_or_path)
+        with contexts:
+            model = model_cls(hf_config)
+        model._auto_class = 'AutoModelForCausalLM'
+        return model
+

 class InternvlVit(HuggingFaceVit):
     module_mapping = {'vision_model': 'vision_model', 'mlp1': 'mlp1'}
@@ -33,7 +51,6 @@ def prepare_attn_impl(self):
         self.hf_config.vision_config.use_flash_attn = use_flash_attn

     def prepare_model(self, hf_config: PretrainedConfig):
-        from transformers import AutoProcessor
         llm_model_type = self.config.llm_model_type
         if llm_model_type not in ['qwen2', 'qwen3', 'qwen3_moe', 'gpt_oss']:
             raise ValueError(f'{llm_model_type} is not supported for internvl_chat model')
@@ -52,7 +69,7 @@ def prepare_model(self, hf_config: PretrainedConfig):
         self.select_layer = hf_config.select_layer
         self.downsample_ratio = hf_config.downsample_ratio
         self.ps_version = hf_config.ps_version
-        self.processor = AutoProcessor.from_pretrained(hf_config.name_or_path, trust_remote_code=True)
+        self.tokenizer = AutoTokenizer.from_pretrained(hf_config.name_or_path, trust_remote_code=True)

     def get_inputs_embeds(self, inputs_embeds, **kwargs):
         input_ids = kwargs['input_ids']
@@ -63,7 +80,7 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs):
             inputs_embeds = inputs_embeds + vit_embeds.mean() * 0.
         else:
             vit_embeds = self.extract_feature(pixel_values.to(self.vision_model.dtype))
-            selected = (input_ids == self.processor.encode('<IMG_CONTEXT>', add_special_tokens=False)[0])
+            selected = (input_ids == self.tokenizer.encode('<IMG_CONTEXT>', add_special_tokens=False)[0])
             inputs_embeds = inputs_embeds.clone()
             inputs_embeds[selected] = vit_embeds.reshape(-1, vit_embeds.shape[-1]).to(dtype=inputs_embeds.dtype)
         return inputs_embeds
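The new `get_hf_meta_model` relies on `_get_meta_model_context` from `GPTBridge`, which temporarily monkey-patches each model class's `__init__` to run under `torch.device('meta')` and replaces `PreTrainedModel._initialize_weights` with a no-op, so the HF model skeleton is built without allocating or initializing real weights. The restore-in-`finally` pattern behind those patches can be sketched without torch (the `FakeModel` class and the `device` attribute are illustrative stand-ins):

```python
from contextlib import contextmanager


class FakeModel:
    def __init__(self):
        self.device = 'cpu'  # stands in for real weight allocation


@contextmanager
def patch_init(model_cls, device='meta'):
    """Temporarily replace model_cls.__init__; the original is always
    restored by the finally block, even if construction raises."""
    origin_init = model_cls.__init__

    def patched(self, *args, **kwargs):
        origin_init(self, *args, **kwargs)
        self.device = device  # the real code wraps the call in torch.device('meta')

    model_cls.__init__ = patched
    try:
        yield
    finally:
        model_cls.__init__ = origin_init
```

Instances built inside the context land on the substitute device; once it exits, the class constructs normally again.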
