Skip to content

Commit 2dc426f

Browse files
authored
[module] clean mtp code (#11)
1 parent c94a2ad commit 2dc426f

File tree

6 files changed

+67
-54
lines changed

6 files changed

+67
-54
lines changed

README.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535

3636
## 📖 Table of Contents
3737
- [Groups](#-Groups)
38-
- [Introduction](#-introduction)
3938
- [News](#-news)
4039
- [Installation](#%EF%B8%8F-installation)
4140
- [Quick Start](#-quick-Start)
@@ -51,8 +50,6 @@ You can contact us and communicate with us by adding our group:
5150
|:-------------------------:|
5251
| <img src="https://raw.githubusercontent.com/modelscope/ms-swift/main/docs/resources/wechat/megatron.png" width="200" height="200"> |
5352

54-
## 📝 Introduction
55-
5653
## 🎉 News
5754
- 🎉 2026.03.30: MCore-Bridge is released! Providing Megatron-Core model definitions for state-of-the-art large models and making Megatron training as simple as Transformers.
5855

@@ -80,6 +77,8 @@ uv pip install -e . --torch-backend=auto
8077

8178
## 🚀 Quick Start
8279

80+
How to use MCore-Bridge for training can be referred to the [ms-swift project](https://swift.readthedocs.io/en/latest/Megatron-SWIFT/Mcore-Bridge.html). Here we introduce how to use MCore-Bridge programmatically.
81+
8382
You need to create the following file (test.py), then run `CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 test.py`. Below is sample code demonstrating how to use Mcore-Bridge for model creation, weight loading, export, and saving.
8483

8584
The saved model can be used for inference by referring to the [example code in the model card](https://modelscope.cn/models/Qwen/Qwen3.5-35B-A3B).

README_zh.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535

3636
## 📖 目录
3737
- [用户群](#-用户群)
38-
- [简介](#-简介)
3938
- [新闻](#-新闻)
4039
- [安装](#%EF%B8%8F-安装)
4140
- [快速开始](#-快速开始)
@@ -50,8 +49,6 @@
5049
|:-------------------------:|
5150
| <img src="https://raw.githubusercontent.com/modelscope/ms-swift/main/docs/resources/wechat/megatron.png" width="200" height="200"> |
5251

53-
## 📝 简介
54-
5552
## 🎉 新闻
5653
- 🎉 2026.03.30: MCore-Bridge 正式发布!为最先进的大模型提供 Megatron-Core 模型定义,让 Megatron 训练像 Transformers 一样简单。
5754

@@ -79,6 +76,8 @@ uv pip install -e . --torch-backend=auto
7976

8077
## 🚀 快速开始
8178

79+
如何使用MCore-Bridge进行训练可以参考[ms-swift项目](https://swift.readthedocs.io/zh-cn/latest/Megatron-SWIFT/Mcore-Bridge.html)。这里介绍如何使用代码方式使用Mcore-Bridge。
80+
8281
你需要创建以下文件(test.py),然后运行`CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 test.py`。以下为使用Mcore-Bridge进行创建模型、权重加载、导出、保存的示例代码。
8382

8483
保存的模型,可以参考[模型卡片的示例代码](https://modelscope.cn/models/Qwen/Qwen3.5-35B-A3B)进行推理。

src/mcore_bridge/config/model_config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
2+
import copy
23
import os
34
import re
45
import torch.nn.functional as F
@@ -346,3 +347,14 @@ def _check_npu(self):
346347
f'expert_model_parallel_size={expert_model_parallel_size}. '
347348
f'Please set expert_model_parallel_size (EP) to {required_ep} '
348349
f'(num_experts / {MAX_NPU_EXPERTS_PER_EP}) or higher.')
350+
351+
def __deepcopy__(self, memo):
352+
cls = self.__class__
353+
new_obj = cls.__new__(cls)
354+
memo[id(self)] = new_obj
355+
for k, v in self.__dict__.items():
356+
if k == 'bridge':
357+
setattr(new_obj, k, v)
358+
else:
359+
setattr(new_obj, k, copy.deepcopy(v, memo))
360+
return new_obj

src/mcore_bridge/model/gpt_model.py

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def _apply_rotary_pos_emb_bshd(
158158
rotary_interleaved: bool = False,
159159
multi_latent_attention: bool = False, # not use
160160
mscale: float = 1.0,
161+
**kwargs,
161162
) -> torch.Tensor:
162163
"""Apply rotary positional embedding to input tensor T.
163164
@@ -390,6 +391,8 @@ def _postprocess(
390391
output_weight = None
391392
if self.share_embeddings_and_output_weights:
392393
output_weight = self.shared_embedding_or_output_weight()
394+
if self.config.is_multimodal and self.config.context_parallel_size > 1:
395+
input_ids = split_cp_inputs(input_ids, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)
393396

394397
if self.mtp_process:
395398
hidden_states = self.mtp(
@@ -406,55 +409,52 @@ def _postprocess(
406409
embedding=self.embedding,
407410
**(extra_block_kwargs or {}),
408411
)
412+
mtp_labels = labels.clone()
409413
hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
410414
hidden_states = hidden_states_list[0]
411-
412-
if labels is not None:
413-
mtp_labels = labels.clone()
414-
if loss_mask is None:
415-
# if loss_mask is not provided, use all ones as loss_mask
416-
if packed_seq_params is None:
417-
loss_mask = torch.ones_like(mtp_labels)
418-
else:
419-
loss_mask = mtp_labels.new_ones((1, packed_seq_params.cu_seqlens_q[-1]))
420-
cu_seqlens = packed_seq_params.cu_seqlens_q if packed_seq_params is not None else None
421-
for mtp_layer_number in range(self.config.mtp_num_layers):
422-
# output
423-
mtp_logits, _ = self.output_layer(
424-
hidden_states_list[mtp_layer_number + 1],
425-
weight=output_weight,
426-
runtime_gather_output=runtime_gather_output,
415+
if loss_mask is None:
416+
# if loss_mask is not provided, use all ones as loss_mask
417+
loss_mask = torch.ones_like(mtp_labels)
418+
for mtp_layer_number in range(self.config.mtp_num_layers):
419+
# output
420+
mtp_logits, _ = self.output_layer(
421+
hidden_states_list[mtp_layer_number + 1],
422+
weight=output_weight,
423+
runtime_gather_output=runtime_gather_output,
424+
)
425+
# Calc loss for the current Multi-Token Prediction (MTP) layers.
426+
mtp_labels, _ = roll_tensor(
427+
mtp_labels,
428+
shifts=-1,
429+
dims=-1,
430+
cp_group=self.cp_group,
431+
packed_seq_params=packed_seq_params,
432+
)
433+
loss_mask, _ = roll_tensor(
434+
loss_mask,
435+
shifts=-1,
436+
dims=-1,
437+
cp_group=self.cp_group,
438+
packed_seq_params=packed_seq_params,
439+
)
440+
mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
441+
loss_mask_ = (loss_mask & (mtp_labels != -100))
442+
num_tokens = loss_mask_.sum()
443+
mtp_loss = loss_mask_ * mtp_loss
444+
if self.training:
445+
mtp_loss_for_log = (
446+
torch.sum(mtp_loss) / num_tokens if num_tokens > 0 else mtp_loss.new_tensor(0.0))
447+
MTPLossLoggingHelper.save_loss_to_tracker(
448+
mtp_loss_for_log,
449+
mtp_layer_number,
450+
self.config.mtp_num_layers,
451+
avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
427452
)
428-
# Calc loss for the current Multi-Token Prediction (MTP) layers.
429-
mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)
430-
if cu_seqlens is None:
431-
loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group)
432-
loss_mask_ = loss_mask
433-
else:
434-
loss_mask[:, cu_seqlens[:-1]] = 0
435-
loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1)
436-
if self.config.context_parallel_size > 1:
437-
loss_mask_ = split_cp_inputs(loss_mask, cu_seqlens, dim=1)
438-
else:
439-
loss_mask_ = loss_mask.clone()
440-
mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
441-
loss_mask_ = loss_mask_ & (mtp_labels != -100)
442-
mtp_loss = loss_mask_ * mtp_loss
443-
num_tokens = loss_mask_.sum()
444-
if self.training:
445-
mtp_loss_for_log = (
446-
torch.sum(mtp_loss) / num_tokens if num_tokens > 0 else mtp_loss.new_tensor(0.0))
447-
MTPLossLoggingHelper.save_loss_to_tracker(
448-
mtp_loss_for_log,
449-
mtp_layer_number,
450-
self.config.mtp_num_layers,
451-
avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
452-
)
453-
mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
454-
if self.config.calculate_per_token_loss:
455-
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
456-
else:
457-
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
453+
mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
454+
if self.config.calculate_per_token_loss:
455+
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
456+
else:
457+
hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
458458
sequence_parallel_override = False
459459
if in_inference_mode and inference_context.materialize_only_last_token_logits:
460460
if inference_context.is_static_batching():

src/mcore_bridge/patcher.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,7 @@ def _apply_rotary_pos_emb_thd(
608608
multi_latent_attention: bool = False,
609609
mscale: float = 1.0,
610610
cp_group: torch.distributed.ProcessGroup = None,
611+
**kwargs,
611612
) -> torch.Tensor:
612613
"""A baseline implementation of applying RoPE for `thd` format.
613614
@@ -629,7 +630,8 @@ def _apply_rotary_pos_emb_thd(
629630
use_batched_rope = (freqs.dim() >= 1 and freqs.shape[0] == cu_seqlens_for_batched[-1]).item()
630631
if not use_batched_rope:
631632
logger.warning_once('Using non-batched RoPE, which may affect performance.')
632-
kwargs = {'cp_group': cp_group} if mcore_013 else {}
633+
if mcore_013:
634+
kwargs['cp_group'] = cp_group
633635
return _origin_apply_rotary_pos_emb_thd(
634636
t,
635637
cu_seqlens,
@@ -646,6 +648,7 @@ def _apply_rotary_pos_emb_thd(
646648
rotary_interleaved=rotary_interleaved,
647649
multi_latent_attention=multi_latent_attention,
648650
mscale=mscale,
651+
**kwargs,
649652
).squeeze(1)
650653

651654
rope_utils._apply_rotary_pos_emb_thd = _apply_rotary_pos_emb_thd

src/mcore_bridge/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Make sure to modify __release_datetime__ to release time when making official release.
2-
__version__ = '1.0.1.dev0'
2+
__version__ = '1.1.0.dev0'
33
# default release datetime for branches under active development is set
44
# to be a time far-far-away-into-the-future
55
__release_datetime__ = '2099-12-31 23:59:59'

0 commit comments

Comments
 (0)