25 commits
6155a65
use mcore config_converter and model_initializer for more types of models
ISEEKYAN Apr 13, 2025
8869168
remove megatron_config from actor/critic
ISEEKYAN Apr 13, 2025
9216811
reward model use gptmodel api, clean megatron_worker
ISEEKYAN Apr 13, 2025
6c46c2a
mcore model_forward for registry
ISEEKYAN Apr 13, 2025
a9c21cf
Merge branch 'main' into mcore_refactor
ISEEKYAN Apr 14, 2025
e709dc3
(WIP) support qwen2moe
ISEEKYAN Apr 16, 2025
0775d36
qwen2moe config converter and weight converter
ISEEKYAN Apr 17, 2025
6113b10
add scripts to run qwen1.5moe_a2.7b
ISEEKYAN Apr 17, 2025
bbf41b6
Merge branch 'main' into mcore_qwen2moe
ISEEKYAN Apr 17, 2025
5f8d8a0
format
ISEEKYAN Apr 18, 2025
d2376ec
update scripts
ISEEKYAN Apr 18, 2025
39e0658
Merge branch 'main' into mcore_qwen2moe
ISEEKYAN Apr 18, 2025
57d9671
fix for pre-commit
ISEEKYAN Apr 18, 2025
5181a99
Merge branch 'main' into mcore_qwen2moe
ISEEKYAN Apr 19, 2025
7b66d82
fix bug of merge
ISEEKYAN Apr 19, 2025
941ab95
compatible to mcore 0.12
ISEEKYAN Apr 19, 2025
267a119
WIP support moonlight
ISEEKYAN Apr 21, 2025
8801841
fix
ISEEKYAN Apr 28, 2025
e5d6ca0
typo
ISEEKYAN Apr 28, 2025
7f84424
Merge branch 'main' into mcore_moonlight
ISEEKYAN Apr 28, 2025
ae550a8
add scripts
ISEEKYAN Apr 28, 2025
a66c823
Merge branch 'mcore_moonlight' into tmp_merge
spacegoing May 22, 2025
dce2c40
[Fix] config_converter signature
spacegoing May 23, 2025
3e09ccf
[Fix] restore trust remote code arg
spacegoing May 23, 2025
b1df278
[Fix] adapt to use base class in config_converter
spacegoing May 23, 2025
80 changes: 80 additions & 0 deletions examples/ppo_trainer/run_moonlight16b_a3b_gsm8k_megatron.sh
@@ -0,0 +1,80 @@
set -x

# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:
# export VLLM_ATTENTION_BACKEND=XFORMERS
export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping


# 0. download the model
huggingface-cli download moonshotai/Moonlight-16B-A3B-Instruct

# 1. convert the model to mcore format
# change the HF_MODEL_PATH and DIST_CKPT_PATH to your own path
HF_MODEL_PATH=/data/models/moonshotai/Moonlight-16B-A3B-Instruct
DIST_CKPT_PATH=/data/mcore_ckpt/Moonlight-16B-A3B-Instruct
python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH --trust_remote_code


# 2. run the script
gsm8k_train_path=$HOME/data/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/gsm8k/test.parquet
train_files=$gsm8k_train_path
test_files=$gsm8k_test_path

NODES=4
PP=2
TP=4
CP=1
VLLM_TP=4

# RAY_ADDRESS='auto' ray job submit --working-dir . --
python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\
algorithm.adv_estimator=gae \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=1024 \
data.max_prompt_length=1024 \
data.max_response_length=512 \
data.filter_overlong_prompts=True \
data.truncation='error' \
+data.trust_remote_code=True \
    actor_rollout_ref.model.path=$HF_MODEL_PATH \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
critic.optim.lr=1e-5 \
    critic.model.path=$HF_MODEL_PATH \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size_per_gpu=4 \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_megatron_gsm8k_examples' \
trainer.experiment_name='moonlight_freeze_moe_router' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=$NODES \
trainer.save_freq=-1 \
trainer.test_freq=5 \
+actor_rollout_ref.model.trust_remote_code=True \
+critic.model.trust_remote_code=True \
+actor_rollout_ref.megatron.extra.num_layers_in_last_pipeline_stage=13 \
+critic.megatron.extra.num_layers_in_last_pipeline_stage=13 \
actor_rollout_ref.rollout.tensor_model_parallel_size=$VLLM_TP \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \
critic.megatron.pipeline_model_parallel_size=$PP \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \
critic.megatron.tensor_model_parallel_size=$TP \
actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \
actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \
critic.megatron.use_dist_checkpointing=True \
actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
critic.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
trainer.val_before_train=False \
trainer.total_epochs=100 $@
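
As a small, hypothetical pre-flight check (not part of this PR), the Python sketch below can be run before step 2 to confirm that the Hugging Face config actually resolves to the DeepseekV3ForCausalLM architecture handled by the new converter branch, and that the dist checkpoint directory has been populated; the paths simply mirror the placeholders in the script above.

import os
from transformers import AutoConfig

HF_MODEL_PATH = "/data/models/moonshotai/Moonlight-16B-A3B-Instruct"  # same placeholder as in the script
DIST_CKPT_PATH = "/data/mcore_ckpt/Moonlight-16B-A3B-Instruct"

# Moonlight ships custom modeling code, so trust_remote_code is needed here as well.
hf_config = AutoConfig.from_pretrained(HF_MODEL_PATH, trust_remote_code=True)
assert "DeepseekV3ForCausalLM" in hf_config.architectures, hf_config.architectures

# converter_hf_to_mcore.py skips a non-empty output dir, so an empty dir means the conversion has not run yet.
assert os.path.isdir(DIST_CKPT_PATH) and os.listdir(DIST_CKPT_PATH), "run scripts/converter_hf_to_mcore.py first"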
82 changes: 76 additions & 6 deletions scripts/converter_hf_to_mcore.py
@@ -35,6 +35,7 @@ def _init_args():
parser.add_argument("--output_path", type=str, required=True, help="The path for the output mcore model")
parser.add_argument("--use_cpu_initialization", action="store_true", help="Whether to use cpu initialization")
parser.add_argument("--test", action="store_true", help="Whether to test the conversion")
parser.add_argument("--trust_remote_code", action="store_true", help="Whether to trust remote hf code")
args = parser.parse_args()
return args

@@ -120,7 +121,7 @@ def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config)
v_bias = hf_layer.self_attn.v_proj.bias.view([num_key_value_heads, -1])
qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(-1).contiguous()
layer.self_attention.linear_qkv.bias.copy_(qkv_bias)

if hasattr(hf_layer.self_attn, "q_norm"):
layer.self_attention.q_layernorm.weight.copy_(hf_layer.self_attn.q_norm.weight.data)
layer.self_attention.k_layernorm.weight.copy_(hf_layer.self_attn.k_norm.weight.data)
@@ -145,7 +146,72 @@ def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config)
model.output_layer.weight.copy_(hf_model.lm_head.weight)


def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False, test=False):
@torch.no_grad()
def convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model, hf_config, tfconfig):
warnings.warn("MTP model is not supported yet", stacklevel=2)

def safe_copy(
src_tensor: torch.Tensor,
dst_tensor: torch.Tensor,
skip_dtype_assert: bool = False,
):
if not skip_dtype_assert:
if src_tensor.dtype != dst_tensor.dtype:
raise ValueError(f"Get source dtype {src_tensor.dtype}, but target dtype {dst_tensor.dtype}")
assert src_tensor.shape == dst_tensor.shape
dst_tensor.data.copy_(src_tensor.data)
return src_tensor.numel()

model.embedding.word_embeddings.weight.copy_(hf_model.model.embed_tokens.weight)
for layer_idx, (layer, hf_layer) in enumerate(zip(model.decoder.layers, hf_model.model.layers)):
print(f"converting layer {layer_idx}", flush=True)
layer.input_layernorm.weight.copy_(hf_layer.input_layernorm.weight)

if hf_config.q_lora_rank is None:
layer.self_attention.linear_q_proj.weight.copy_(hf_layer.self_attn.q_proj.weight)
else:
layer.self_attention.linear_q_down_proj.weight.copy_(hf_layer.self_attn.q_a_proj.weight)
layer.self_attention.linear_q_up_proj.weight.copy_(hf_layer.self_attn.q_b_proj.weight)
layer.self_attention.linear_q_up_proj.layer_norm_weight.copy_(hf_layer.self_attn.q_a_layernorm.weight)

layer.self_attention.linear_kv_down_proj.weight.copy_(hf_layer.self_attn.kv_a_proj_with_mqa.weight)
layer.self_attention.linear_kv_up_proj.weight.copy_(hf_layer.self_attn.kv_b_proj.weight)
layer.self_attention.linear_kv_up_proj.layer_norm_weight.copy_(hf_layer.self_attn.kv_a_layernorm.weight)
layer.self_attention.linear_proj.weight.copy_(hf_layer.self_attn.o_proj.weight)

if not hasattr(layer.mlp, "router"):
layer.mlp.linear_fc1.layer_norm_weight.copy_(hf_layer.post_attention_layernorm.weight)
layer.mlp.linear_fc1.weight.copy_(torch.cat([hf_layer.mlp.gate_proj.weight, hf_layer.mlp.up_proj.weight]))
layer.mlp.linear_fc2.weight.copy_(hf_layer.mlp.down_proj.weight)
else:
layer.mlp.router.weight.copy_(hf_layer.mlp.gate.weight)
# NOTE: the e_score_correction_bias in the mcore model is initialized in bfloat16 and
# recovered to fp32 during the first forward pass, so there is always a small diff (~0.3%) in this bias between the two models
safe_copy(hf_layer.mlp.gate.e_score_correction_bias, layer.mlp.router.expert_bias, skip_dtype_assert=True)
if tfconfig.moe_grouped_gemm:
for i, hf_expert in enumerate(hf_layer.mlp.experts):
fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])
linear_fc1_weighti = getattr(layer.mlp.experts.linear_fc1, "weight" + str(i))
linear_fc1_weighti.copy_(fc1_weight)
linear_fc2_weighti = getattr(layer.mlp.experts.linear_fc2, "weight" + str(i))
linear_fc2_weighti.copy_(hf_expert.down_proj.weight)
else:
for i, hf_expert in enumerate(hf_layer.mlp.experts):
expert = layer.mlp.experts.local_experts[i]
fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])
expert.linear_fc1.weight.copy_(fc1_weight)
expert.linear_fc2.weight.copy_(hf_expert.down_proj.weight)
layer.pre_mlp_layernorm.weight.copy_(hf_layer.post_attention_layernorm.weight)
shared_fc1_weight = torch.cat([hf_layer.mlp.shared_experts.gate_proj.weight, hf_layer.mlp.shared_experts.up_proj.weight])
layer.mlp.shared_experts.linear_fc1.weight.copy_(shared_fc1_weight)
layer.mlp.shared_experts.linear_fc2.weight.copy_(hf_layer.mlp.shared_experts.down_proj.weight)

model.decoder.final_layernorm.weight.copy_(hf_model.model.norm.weight)
if not hf_config.tie_word_embeddings:
model.output_layer.weight.copy_(hf_model.lm_head.weight)


def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False, test=False, trust_remote_code=False):
os.makedirs(output_path, exist_ok=True)
if len(os.listdir(output_path)) > 0 and not test:
print(f"Output path {output_path} is not empty, skipping conversion")
Expand All @@ -166,7 +232,7 @@ def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False
model_parallel_cuda_manual_seed(0)

# init hf config
hf_config = AutoConfig.from_pretrained(hf_model_path)
hf_config = AutoConfig.from_pretrained(hf_model_path, trust_remote_code=trust_remote_code)
print(hf_config, flush=True)

cfg = Config()
@@ -200,11 +266,15 @@ def megatron_model_provider(pre_process, post_process):
warnings.simplefilter("ignore")

# init hf model
hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.bfloat16)
hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path,
torch_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code)
hf_state_dict = hf_model.state_dict()

# load hf state dict to megatron model
if "Qwen2MoeForCausalLM" in hf_config.architectures:
if "DeepseekV3ForCausalLM" in hf_config.architectures:
convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model[0].module, hf_config, tfconfig=tfconfig)
elif "Qwen2MoeForCausalLM" in hf_config.architectures:
convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config)
elif "Qwen3MoeForCausalLM" in hf_config.architectures:
convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config)
@@ -232,4 +302,4 @@ def megatron_model_provider(pre_process, post_process):

if __name__ == "__main__":
args = _init_args()
convert_hf_to_mcore(args.hf_model_path, args.output_path, args.use_cpu_initialization, args.test)
convert_hf_to_mcore(args.hf_model_path, args.output_path, args.use_cpu_initialization, args.test, args.trust_remote_code)
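
For reference, a minimal, self-contained sketch of the expert weight mapping performed by convert_checkpoint_from_transformers_to_megatron_dpskv3 above, using toy tensor sizes rather than the real models: HF DeepSeek-V3 experts keep gate_proj and up_proj as separate matrices, while the mcore expert's linear_fc1 stores them concatenated along the output dimension, and down_proj maps directly to linear_fc2.

import torch

hidden, moe_ffn = 16, 32                       # toy stand-ins for hidden_size / moe_intermediate_size
gate_proj = torch.randn(moe_ffn, hidden)       # plays the role of hf_expert.gate_proj.weight
up_proj = torch.randn(moe_ffn, hidden)         # hf_expert.up_proj.weight
down_proj = torch.randn(hidden, moe_ffn)       # hf_expert.down_proj.weight

linear_fc1 = torch.empty(2 * moe_ffn, hidden)  # mcore expert.linear_fc1.weight
linear_fc2 = torch.empty(hidden, moe_ffn)      # mcore expert.linear_fc2.weight

# Same ordering as the converter: gate first, then up, concatenated along dim 0.
linear_fc1.copy_(torch.cat([gate_proj, up_proj], dim=0))
linear_fc2.copy_(down_proj)

assert torch.equal(linear_fc1[:moe_ffn], gate_proj)
assert torch.equal(linear_fc1[moe_ffn:], up_proj)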
79 changes: 78 additions & 1 deletion verl/models/mcore/config_converter.py
@@ -17,9 +17,11 @@
# convert huggingface config to mcore transformer config


from dataclasses import asdict
import torch
import torch.nn.functional as F
from megatron.core.transformer import MLATransformerConfig, TransformerConfig
from megatron.core.transformer.enums import AttnBackend
from transformers import PretrainedConfig


@@ -182,7 +184,82 @@ def hf_to_mcore_config_qwen3moe(hf_config: PretrainedConfig, dtype: torch.dtype,

def hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> MLATransformerConfig:
# DeepseekV3ForCausalLM
raise NotImplementedError("DeepseekV3ForCausalLM is not supported yet")
from megatron.core import parallel_state as mpu

moe_layer_freq = [1] * hf_config.num_hidden_layers
for i in range(hf_config.first_k_dense_replace):
moe_layer_freq[i] = 0

base_config = _get_base_transformer_config(
hf_config=hf_config,
dtype=dtype,
activation_func=F.silu,
use_cpu_initialization=False,
add_bias_linear=False,
# attention_backend=AttnBackend.flash,
attention_backend=AttnBackend.fused,
qk_layernorm=True,
# moe specific
moe_ffn_hidden_size=hf_config.moe_intermediate_size,
moe_router_bias_update_rate=0.001,
moe_router_enable_expert_bias=True,
moe_router_topk=hf_config.num_experts_per_tok,
num_moe_experts=hf_config.n_routed_experts,
moe_shared_expert_intermediate_size=hf_config.moe_intermediate_size * hf_config.n_shared_experts,
moe_aux_loss_coeff=getattr(hf_config, "aux_loss_alpha", 0.001),
moe_router_load_balancing_type="seq_aux_loss",
moe_shared_expert_overlap=True,
# moe_permute_fusion=True, # need TE 2.1+
moe_grouped_gemm=True,
moe_router_score_function="sigmoid",
moe_router_pre_softmax=True,
moe_router_topk_scaling_factor=hf_config.routed_scaling_factor,
moe_layer_freq=moe_layer_freq,
persist_layer_norm=True,
bias_activation_fusion=True,
bias_dropout_fusion=True,
**override_transformer_config_kwargs,
)
base_config_dict = asdict(base_config)

# transformer config default multi_latent_attention = False
base_config_dict.update({"multi_latent_attention": True})

mla_rope_config = {
"beta_fast": 32,
"beta_slow": 1,
"factor": 1,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "rope",
}
if "rope_scaling" in hf_config and hf_config.rope_scaling is not None:
mla_rope_config.update(hf_config.rope_scaling)
transformer_config = MLATransformerConfig(
# MLA
q_lora_rank=hf_config.q_lora_rank,
kv_lora_rank=hf_config.kv_lora_rank,
qk_head_dim=hf_config.qk_nope_head_dim,
qk_pos_emb_head_dim=hf_config.qk_rope_head_dim,
v_head_dim=hf_config.v_head_dim,
rotary_base=hf_config.rope_theta,
rotary_scaling_factor=mla_rope_config["factor"],
rope_type=mla_rope_config["type"],
mscale=mla_rope_config["mscale"],
mscale_all_dim=mla_rope_config["mscale_all_dim"],
max_position_embeddings=mla_rope_config["original_max_position_embeddings"],
beta_fast=mla_rope_config["beta_fast"],
beta_slow=mla_rope_config["beta_slow"],
# mcore 0.12 moe
# moe_router_dtype="fp64",
# disable_bf16_reduced_precision_matmul=True,
# other
# deallocate_pipeline_outputs=True,
# gradient_accumulation_fusion=True,
**base_config_dict
)
return transformer_config


def hf_to_mcore_config_qwen2_5_vl(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
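
Two details of hf_to_mcore_config_dpskv3 above are worth spelling out; the sketch below uses made-up toy values, not actual Moonlight numbers. The moe_layer_freq mask keeps the first first_k_dense_replace layers dense, and an HF rope_scaling dict, when present, overrides the MLA rope defaults before MLATransformerConfig is constructed.

num_hidden_layers, first_k_dense_replace = 8, 3  # assumed toy values

moe_layer_freq = [1] * num_hidden_layers
for i in range(first_k_dense_replace):
    moe_layer_freq[i] = 0
print(moe_layer_freq)  # [0, 0, 0, 1, 1, 1, 1, 1]: dense layers first, MoE layers afterwards

mla_rope_config = {
    "beta_fast": 32, "beta_slow": 1, "factor": 1, "mscale": 1.0,
    "mscale_all_dim": 1.0, "original_max_position_embeddings": 4096, "type": "rope",
}
hf_rope_scaling = {"type": "yarn", "factor": 40, "mscale": 1.0}  # hypothetical hf_config.rope_scaling
if hf_rope_scaling is not None:
    mla_rope_config.update(hf_rope_scaling)
print(mla_rope_config["type"], mla_rope_config["factor"])  # yarn 40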
21 changes: 21 additions & 0 deletions verl/models/mcore/model_initializer.py
@@ -156,6 +156,27 @@ def initialize(self, **kwargs):
return model


class Dpskv3Model(BaseModelInitializer):
"""Initializer for Deepseek V3 MOE models."""

def get_transformer_layer_spec(self):
assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)
return transformer_layer_spec

def initialize(self, **kwargs):
freeze_moe_router = kwargs.get("freeze_moe_router", True)
if freeze_moe_router:
self.tfconfig.moe_router_load_balancing_type = "none"

model = super().initialize(**kwargs)
if freeze_moe_router:
for layer in model.decoder.layers:
if hasattr(layer.mlp, "router"):
layer.mlp.router.weight.requires_grad = False
return model


class Qwen25VLModel(BaseModelInitializer):
"""Initializer for Qwen2.5 VL models."""

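
The freeze_moe_router behaviour added in Dpskv3Model.initialize above amounts to two things: disable the gradient on every router weight so the optimizer never updates it, and switch the load-balancing type to "none" so no auxiliary loss pulls on the frozen router. A toy illustration of the gradient part (standalone modules, not verl code):

import torch.nn as nn

class ToyMoELayer(nn.Module):
    def __init__(self, hidden=16, num_experts=4):
        super().__init__()
        self.mlp = nn.Module()
        self.mlp.router = nn.Linear(hidden, num_experts, bias=False)  # mirrors layer.mlp.router

layers = [ToyMoELayer(), ToyMoELayer()]
freeze_moe_router = True
if freeze_moe_router:
    for layer in layers:
        if hasattr(layer.mlp, "router"):
            layer.mlp.router.weight.requires_grad = False

trainable = [n for layer in layers for n, p in layer.named_parameters() if p.requires_grad]
assert trainable == []  # only router weights exist in this toy layer, and they are now frozen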
5 changes: 4 additions & 1 deletion verl/models/mcore/registry.py
@@ -43,12 +43,14 @@
Qwen2MoEModel,
Qwen3MoEModel,
Qwen25VLModel,
Dpskv3Model,
)
from .weight_converter import (
McoreToHFWeightConverterDense,
McoreToHFWeightConverterMixtral,
McoreToHFWeightConverterQwen2Moe,
McoreToHFWeightConverterQwen3Moe,
McoreToHFWeightConverterDpskv3,
)


@@ -83,7 +85,7 @@ class SupportedModel(Enum):
SupportedModel.QWEN2: DenseModel,
SupportedModel.QWEN2_MOE: Qwen2MoEModel,
SupportedModel.MIXTRAL: MixtralModel,
SupportedModel.DEEPSEEK_V3: DenseModel,
SupportedModel.DEEPSEEK_V3: Dpskv3Model,
SupportedModel.QWEN2_5_VL: Qwen25VLModel,
SupportedModel.LLAMA4: DenseModel,
SupportedModel.QWEN3: DenseModel,
@@ -111,6 +113,7 @@ class SupportedModel(Enum):
SupportedModel.MIXTRAL: McoreToHFWeightConverterMixtral,
SupportedModel.QWEN3: McoreToHFWeightConverterDense,
SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe,
SupportedModel.DEEPSEEK_V3: McoreToHFWeightConverterDpskv3,
}


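
The registry changes above follow a simple enum-keyed dispatch pattern; a generic sketch of that pattern (illustrative names, not the actual verl registry API) looks like this:

from enum import Enum

class SupportedModel(Enum):
    QWEN2 = "Qwen2ForCausalLM"
    DEEPSEEK_V3 = "DeepseekV3ForCausalLM"

class DenseModel: ...
class Dpskv3Model: ...

MODEL_INITIALIZER_REGISTRY = {
    SupportedModel.QWEN2: DenseModel,
    SupportedModel.DEEPSEEK_V3: Dpskv3Model,
}

def get_initializer_cls(architectures):
    # Return the initializer class for the first recognised architecture string.
    for arch in architectures:
        try:
            return MODEL_INITIALIZER_REGISTRY[SupportedModel(arch)]
        except ValueError:
            continue
    raise NotImplementedError(f"unsupported architectures: {architectures}")

assert get_initializer_cls(["DeepseekV3ForCausalLM"]) is Dpskv3Model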
4 changes: 4 additions & 0 deletions verl/models/mcore/saver.py
@@ -469,3 +469,7 @@ def merge_megatron_ckpt_gptmodel_qwen_moe(wrapped_models, config, dtype, is_valu

def merge_megatron_ckpt_gptmodel_mixtral(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):
raise NotImplementedError("merge_megatron_ckpt_gptmodel_mixtral is not implemented")


def merge_megatron_ckpt_gptmodel_dpskv3(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):
raise NotImplementedError("merge_megatron_ckpt_gptmodel_dpskv3 is not implemented")