From a44f7f9dd1b8cd148c27e7e24684a4c1533c412e Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Tue, 16 Sep 2025 20:56:48 +0800 Subject: [PATCH 01/13] Support groupgemm for qwen3-next; Support param sync in sglang for groupgemm --- chatlearn/models/fsdp_module.py | 19 +- chatlearn/models/patches/monkey_patch.py | 4 + .../transformers/qwen3_next_moe_patch.py | 345 ++++++++++++++++++ 3 files changed, 364 insertions(+), 4 deletions(-) create mode 100644 chatlearn/models/patches/transformers/qwen3_next_moe_patch.py diff --git a/chatlearn/models/fsdp_module.py b/chatlearn/models/fsdp_module.py index fdbf6a35..c8054216 100644 --- a/chatlearn/models/fsdp_module.py +++ b/chatlearn/models/fsdp_module.py @@ -236,6 +236,7 @@ def model_setup(self): # When meta_init is enabled, we only load checkpoint on rank 0 meta_init = self.module_args.meta_init and local_rank != 0 model = self.create_model(args.load, torch_dtype=torch.bfloat16, meta_init=meta_init) + self.model_config = model.config if self.module_args.groupgemm: apply_group_gemm(model) dist.barrier() @@ -303,7 +304,6 @@ def model_setup(self): torch.cuda.synchronize() for name, buf in model.named_buffers(): dist.broadcast(buf, src=0) - self.model = model self.model.to(torch.float32) @@ -324,7 +324,7 @@ def model_setup(self): del full_state self.offload() - def get_fsdp_param_name(self, block_size=3_000_000_000) -> List[List]: + def get_fsdp_param_name(self, block_size=300_000_000) -> List[List]: name_list = [] param_cnt = 0 current_group = [] @@ -356,8 +356,19 @@ def get_weight_ipc_handles_by_name(self, block_name: List[str]): serialize_func = reduce_tensor if rollout_engine=='vllm' else MultiprocessingSerializer.serialize for name, param in self.model.named_parameters(): if name in block_name: - reduce_tensor_dict[name] = serialize_func(param.full_tensor().detach() \ - if isinstance(param, DTensor) else param.detach()) + if self.module_args.groupgemm and "group_mlp" in name: + # This model is using groupgemm for moe forward + param = param.full_tensor().detach() + num_experts = self.model_config.num_experts + #split_size = param.shape[0] // num_experts + param_per_expert = torch.chunk(param, num_experts, dim=0) + #param_per_expert = torch.split(param, split_size, dim=0) + for i in range(num_experts): + local_name = name.replace('group_mlp', f"experts.{i}") + reduce_tensor_dict[local_name] = serialize_func(param_per_expert[i]) + else: + reduce_tensor_dict[name] = serialize_func(param.full_tensor().detach() \ + if isinstance(param, DTensor) else param.detach()) if self.module_args.use_expandable_segments: torch.cuda.memory._set_allocator_settings("expandable_segments:True") return reduce_tensor_dict diff --git a/chatlearn/models/patches/monkey_patch.py b/chatlearn/models/patches/monkey_patch.py index 1e6bdabf..89d23579 100644 --- a/chatlearn/models/patches/monkey_patch.py +++ b/chatlearn/models/patches/monkey_patch.py @@ -32,6 +32,10 @@ def apply_group_gemm(model): from chatlearn.models.patches.transformers.qwen3_moe_patch import apply_group_gemm_patch \ # pylint: disable=import-outside-toplevel apply_group_gemm_patch(model) + elif model.config.architectures[0] == "Qwen3NextForCausalLM": + from chatlearn.models.patches.transformers.qwen3_next_moe_patch import apply_group_gemm_patch \ + # pylint: disable=import-outside-toplevel + apply_group_gemm_patch(model) else: raise ValueError(f"Unsupported model architecture: {model.config.architectures} for groupgemm patch") diff --git a/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py 
b/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py new file mode 100644 index 00000000..2995a75e --- /dev/null +++ b/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py @@ -0,0 +1,345 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""patches for qwen3-moe model""" +from concurrent.futures import ThreadPoolExecutor + +import torch +from torch import nn +import torch.nn.functional as F + +try: + from transformer_engine.pytorch.cpp_extensions import grouped_gemm +except ImportError: + from transformer_engine.pytorch.cpp_extensions import general_grouped_gemm as grouped_gemm +from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace +from transformer_engine.pytorch.permutation import ( + moe_permute, + moe_unpermute, +) +from transformers.activations import ACT2FN + +from chatlearn.utils import is_te_min_version + +class GroupGemm(torch.autograd.Function): + """ Autograd function for grouped gemm""" + @staticmethod + def forward( + ctx, + inp, + m_splits, + use_bias, + is_grad_enabled, + activation_dtype, + *weights_bias + ) -> torch.Tensor: + n_gemm = len(m_splits) + weights = weights_bias[:n_gemm] + bias = weights_bias[n_gemm:] + in_features = weights[0].shape[-1] + inputmats = torch.split(inp.view(-1, in_features), m_splits) + output_tensor = torch.empty( + [sum(m_splits), weights[0].shape[0]], + dtype=activation_dtype, + device=inputmats[0].device, + ) + grouped_gemm_kwargs = {'dtype': activation_dtype} + if is_te_min_version("2.0.0"): + grouped_gemm_kwargs = {'out_dtype': activation_dtype, 'm_splits': m_splits} + _ = grouped_gemm( + A=weights, + B=inputmats, + out=torch.split(output_tensor, m_splits), + workspaces=get_multi_stream_cublas_workspace(), + bias=bias, + use_bias=use_bias, + **grouped_gemm_kwargs + ) + if is_grad_enabled: + ctx.save_for_backward( + *inputmats, + *weights, + ) + ctx.m_splits = m_splits + ctx.num_gemm = n_gemm + ctx.activation_dtype = activation_dtype + ctx.use_bias = use_bias + ctx.inp_shape = inp.shape + return output_tensor.view(-1, *inp.shape[1:-1], output_tensor.shape[-1]) + + @staticmethod + def backward(ctx, grad_output): + with torch.cuda.nvtx.range("_GroupedLinear_backward"): + saved_tensors = ctx.saved_tensors + inputmats = saved_tensors[:ctx.num_gemm] + weights = saved_tensors[ctx.num_gemm:] + + grad_output = grad_output.contiguous() + grad_output_mats = torch.split( + grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits + ) + #dgrad + dgrad = torch.empty( + (sum(ctx.m_splits), weights[0].size(1)), + dtype=ctx.activation_dtype, + device=grad_output.device, + ) + grouped_gemm_kwargs = {'dtype': ctx.activation_dtype} + if is_te_min_version("2.0.0"): + grouped_gemm_kwargs = {'out_dtype': ctx.activation_dtype, 'm_splits': ctx.m_splits} + grouped_gemm( + A=weights, + B=grad_output_mats, + out=torch.split(dgrad, ctx.m_splits), + 
workspaces=get_multi_stream_cublas_workspace(), + layout="NN", + grad=True, + **grouped_gemm_kwargs + ) + + #wgrad + wgrad_list = [ + torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device) + for w in weights + ] + _, grad_biases, _ = grouped_gemm( + A=inputmats, + B=grad_output_mats, + out=wgrad_list, + workspaces=get_multi_stream_cublas_workspace(), + layout="NT", + grad=True, + use_bias=ctx.use_bias, + **grouped_gemm_kwargs + ) + if not ctx.use_bias: + grad_biases = [None] + return ( + dgrad.view(ctx.inp_shape), + None, + None, + None, + None, + *wgrad_list, + *grad_biases, + ) + +class Qwen3NextMLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + +def grouped_linear(inp, m_splits, use_bias, is_grad_enabled, activation_dtype, weights_bias, num_experts): + weights_bias = torch.chunk(weights_bias, num_experts, dim=0) + output = GroupGemm.apply( + inp, + m_splits, + use_bias, + is_grad_enabled, + activation_dtype, + *weights_bias + ) + return output + + +class Linear(nn.Module): + """used for empty init gate_proj,up_proj,down_proj""" + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter( + torch.empty((out_features, in_features), **factory_kwargs) + ) + if bias: + self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) + else: + self.register_parameter("bias", None) + + +class MoeGroupMLP(nn.Module): + """ Group MLP Layer """ + def __init__(self, config, intermediate_size): + super().__init__() + self.num_experts = config.num_experts + hidden_size = config.hidden_size + + self.gate_proj = Linear(hidden_size, intermediate_size * self.num_experts, bias=False) + self.up_proj = Linear(hidden_size, intermediate_size * self.num_experts, bias=False) + self.down_proj = Linear(intermediate_size, hidden_size * self.num_experts, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states, router_weights, selected_experts, token_per_expert) -> torch.Tensor: + flat_seqlen = hidden_states.shape[0] + topk = router_weights.shape[1] + if is_te_min_version("2.1.0"): + grouped_input, row_id_map = moe_permute(hidden_states, selected_experts, map_type='index') + else: + grouped_input, row_id_map = moe_permute(hidden_states, selected_experts) + probs = router_weights.T.contiguous().view(-1, 1) + + token_per_expert = token_per_expert.tolist() + gate_output = self.act_fn(grouped_linear( + inp = grouped_input, + m_splits = token_per_expert, + use_bias = False, + is_grad_enabled = self.training, + activation_dtype = hidden_states.dtype, + weights_bias = self.gate_proj.weight, + num_experts=self.num_experts + )) + up_output = grouped_linear( + inp = grouped_input, + m_splits = token_per_expert, + use_bias = 
False, + is_grad_enabled = self.training, + activation_dtype = hidden_states.dtype, + weights_bias = self.up_proj.weight, + num_experts=self.num_experts + ) * gate_output + down_output = grouped_linear( + inp = up_output, + m_splits = token_per_expert, + use_bias = False, + is_grad_enabled = self.training, + activation_dtype = hidden_states.dtype, + weights_bias = self.down_proj.weight, + num_experts=self.num_experts + ) + if is_te_min_version("2.1.0"): + final_hidden_states = moe_unpermute(down_output, row_id_map, probs, map_type='index') + else: + final_hidden_states = moe_unpermute(down_output, row_id_map, probs) + final_hidden_states = final_hidden_states.view(topk, flat_seqlen, -1).permute(1,0,2) + final_hidden_states = torch.sum(final_hidden_states, dim=1).squeeze(1) + return final_hidden_states + +class Qwen3MoeSparseMoeBlock_Grouped(nn.Module): + """ MOE Block support grouped linear """ + def __init__(self, config): + super().__init__() + self.num_experts = config.num_experts + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + + # gating + self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) + self.group_mlp = MoeGroupMLP(config, config.moe_intermediate_size) + + self.shared_expert = Qwen3NextMLP(config, intermediate_size=config.shared_expert_intermediate_size) + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False) + + def topk_expert(self, logits): + routing_weights = F.softmax(logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: # only diff with mixtral sparse moe block! + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + topk_map = torch.zeros_like(logits).int().scatter(1, selected_experts, 1).bool() + tokens_per_expert = topk_map.sum(dim=0) + return routing_weights, selected_experts, tokens_per_expert + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + ori_shape = (batch_size, sequence_length, hidden_dim) + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + routing_weights, selected_experts, tokens_per_expert = self.topk_expert(router_logits) + + final_hidden_states = self.group_mlp( + hidden_states, + routing_weights.to(hidden_states.dtype), + selected_experts, + tokens_per_expert + ) + shared_expert_output = self.shared_expert(hidden_states) + shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output + + final_hidden_states = final_hidden_states + shared_expert_output + final_hidden_states = final_hidden_states.view(ori_shape) + + return final_hidden_states, router_logits + +def apply_group_gemm_patch(model): + cnt = 0 + with torch.device('meta'): + dummy_moe_layer = Qwen3MoeSparseMoeBlock_Grouped(model.config) + num_experts = model.config.num_experts + size_0 = dummy_moe_layer.group_mlp.gate_proj.weight.shape[0] // num_experts + size_1 = dummy_moe_layer.group_mlp.down_proj.weight.shape[0] // num_experts + + def copy_expert_weights(i, moe_group_layer, layer, size_0, size_1): + start_idx_0 = i * size_0 + end_idx_0 = (i + 1) * size_0 + start_idx_1 = i * size_1 + end_idx_1 = (i + 1) * size_1 + + moe_group_layer.group_mlp.gate_proj.weight.data[start_idx_0:end_idx_0].copy_( + layer.mlp.experts[i].gate_proj.weight.data + ) + 
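+        # Expert i owns the contiguous row block [i*size_0, (i+1)*size_0) of the fused
+        # gate/up projections and [i*size_1, (i+1)*size_1) of the fused down projection,
+        # mirroring the torch.chunk(..., dim=0) split that get_weight_ipc_handles_by_name
+        # uses when exporting group_mlp weights back to per-expert names for param sync.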
moe_group_layer.group_mlp.up_proj.weight.data[start_idx_0:end_idx_0].copy_( + layer.mlp.experts[i].up_proj.weight.data + ) + moe_group_layer.group_mlp.down_proj.weight.data[start_idx_1:end_idx_1].copy_( + layer.mlp.experts[i].down_proj.weight.data + ) + + for layer in model.model.layers: + cnt += 1 + if model.device.type == 'meta': + with torch.device('meta'): + moe_group_layer = Qwen3MoeSparseMoeBlock_Grouped(model.config).to(model.dtype) + else: + moe_group_layer = Qwen3MoeSparseMoeBlock_Grouped(model.config).to(model.dtype) + # Copy gate weight + moe_group_layer.gate.weight.data.copy_(layer.mlp.gate.weight.data) + # Copy shared expert weights + moe_group_layer.shared_expert.gate_proj.weight.data.copy_(layer.mlp.shared_expert.gate_proj.weight.data) + moe_group_layer.shared_expert.up_proj.weight.data.copy_(layer.mlp.shared_expert.up_proj.weight.data) + moe_group_layer.shared_expert.down_proj.weight.data.copy_(layer.mlp.shared_expert.down_proj.weight.data) + + moe_group_layer.shared_expert_gate.weight.data.copy_(layer.mlp.shared_expert_gate.weight.data) + + with ThreadPoolExecutor(max_workers=16) as executor: + futures = [ + executor.submit( + copy_expert_weights, + i, moe_group_layer, layer, size_0, size_1 + ) + for i in range(num_experts) + ] + for future in futures: + future.result() + old_mlp = layer.mlp + del old_mlp + layer.register_module("mlp", moe_group_layer) From cb0f78f1f23aac15ec5f40135f9594eb2b583684 Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Tue, 16 Sep 2025 21:07:17 +0800 Subject: [PATCH 02/13] fix pylint --- chatlearn/models/patches/transformers/qwen3_next_moe_patch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py b/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py index 2995a75e..2b86d06a 100644 --- a/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py +++ b/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py @@ -136,6 +136,7 @@ def backward(ctx, grad_output): ) class Qwen3NextMLP(nn.Module): + """ Qwen3-Next MLP layer """ def __init__(self, config, intermediate_size=None): super().__init__() self.config = config From 1db278b1994a617aa6e2f2ce9c93857dc5c81d0e Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Wed, 17 Sep 2025 11:37:12 +0800 Subject: [PATCH 03/13] Move GroupMLP layer to layers --- .../patches/transformers/layers/groupgemm.py | 212 +++++++++++++++++ .../patches/transformers/qwen3_moe_patch.py | 213 +---------------- .../transformers/qwen3_next_moe_patch.py | 217 +----------------- 3 files changed, 219 insertions(+), 423 deletions(-) create mode 100644 chatlearn/models/patches/transformers/layers/groupgemm.py diff --git a/chatlearn/models/patches/transformers/layers/groupgemm.py b/chatlearn/models/patches/transformers/layers/groupgemm.py new file mode 100644 index 00000000..75b22821 --- /dev/null +++ b/chatlearn/models/patches/transformers/layers/groupgemm.py @@ -0,0 +1,212 @@ +"""groupgemm layer with transformer_engine ops""" +from transformers.activations import ACT2FN +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from transformer_engine.pytorch.cpp_extensions import grouped_gemm +except ImportError: + from transformer_engine.pytorch.cpp_extensions import general_grouped_gemm as grouped_gemm +from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace +from transformer_engine.pytorch.permutation import ( + moe_permute, + moe_unpermute, +) + +from chatlearn.utils import is_te_min_version + +class 
Linear(nn.Module): + """used for empty init gate_proj,up_proj,down_proj""" + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter( + torch.empty((out_features, in_features), **factory_kwargs) + ) + if bias: + self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) + else: + self.register_parameter("bias", None) + +class GroupGemm(torch.autograd.Function): + """ Autograd function for grouped gemm""" + @staticmethod + def forward( + ctx, + inp, + m_splits, + use_bias, + is_grad_enabled, + activation_dtype, + *weights_bias + ) -> torch.Tensor: + n_gemm = len(m_splits) + weights = weights_bias[:n_gemm] + bias = weights_bias[n_gemm:] + in_features = weights[0].shape[-1] + inputmats = torch.split(inp.view(-1, in_features), m_splits) + output_tensor = torch.empty( + [sum(m_splits), weights[0].shape[0]], + dtype=activation_dtype, + device=inputmats[0].device, + ) + grouped_gemm_kwargs = {'dtype': activation_dtype} + if is_te_min_version("2.0.0"): + grouped_gemm_kwargs = {'out_dtype': activation_dtype, 'm_splits': m_splits} + _ = grouped_gemm( + A=weights, + B=inputmats, + out=torch.split(output_tensor, m_splits), + workspaces=get_multi_stream_cublas_workspace(), + bias=bias, + use_bias=use_bias, + **grouped_gemm_kwargs + ) + if is_grad_enabled: + ctx.save_for_backward( + *inputmats, + *weights, + ) + ctx.m_splits = m_splits + ctx.num_gemm = n_gemm + ctx.activation_dtype = activation_dtype + ctx.use_bias = use_bias + ctx.inp_shape = inp.shape + return output_tensor.view(-1, *inp.shape[1:-1], output_tensor.shape[-1]) + + @staticmethod + def backward(ctx, grad_output): + with torch.cuda.nvtx.range("_GroupedLinear_backward"): + saved_tensors = ctx.saved_tensors + inputmats = saved_tensors[:ctx.num_gemm] + weights = saved_tensors[ctx.num_gemm:] + + grad_output = grad_output.contiguous() + grad_output_mats = torch.split( + grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits + ) + #dgrad + dgrad = torch.empty( + (sum(ctx.m_splits), weights[0].size(1)), + dtype=ctx.activation_dtype, + device=grad_output.device, + ) + grouped_gemm_kwargs = {'dtype': ctx.activation_dtype} + if is_te_min_version("2.0.0"): + grouped_gemm_kwargs = {'out_dtype': ctx.activation_dtype, 'm_splits': ctx.m_splits} + grouped_gemm( + A=weights, + B=grad_output_mats, + out=torch.split(dgrad, ctx.m_splits), + workspaces=get_multi_stream_cublas_workspace(), + layout="NN", + grad=True, + **grouped_gemm_kwargs + ) + + #wgrad + wgrad_list = [ + torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device) + for w in weights + ] + _, grad_biases, _ = grouped_gemm( + A=inputmats, + B=grad_output_mats, + out=wgrad_list, + workspaces=get_multi_stream_cublas_workspace(), + layout="NT", + grad=True, + use_bias=ctx.use_bias, + **grouped_gemm_kwargs + ) + if not ctx.use_bias: + grad_biases = [None] + return ( + dgrad.view(ctx.inp_shape), + None, + None, + None, + None, + *wgrad_list, + *grad_biases, + ) + +def grouped_linear(inp, m_splits, use_bias, is_grad_enabled, activation_dtype, weights_bias, num_experts): + weights_bias = torch.chunk(weights_bias, num_experts, dim=0) + output = GroupGemm.apply( + inp, + m_splits, + use_bias, + is_grad_enabled, + activation_dtype, + *weights_bias + ) + return output + +class MoeGroupMLP(nn.Module): + """ Group MLP Layer """ + def __init__(self, 
config, intermediate_size): + super().__init__() + self.num_experts = config.num_experts + hidden_size = config.hidden_size + + self.gate_proj = Linear(hidden_size, intermediate_size * self.num_experts, bias=False) + self.up_proj = Linear(hidden_size, intermediate_size * self.num_experts, bias=False) + self.down_proj = Linear(intermediate_size, hidden_size * self.num_experts, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states, router_weights, selected_experts, token_per_expert) -> torch.Tensor: + flat_seqlen = hidden_states.shape[0] + topk = router_weights.shape[1] + if is_te_min_version("2.1.0"): + grouped_input, row_id_map = moe_permute(hidden_states, selected_experts, map_type='index') + else: + grouped_input, row_id_map = moe_permute(hidden_states, selected_experts) + probs = router_weights.T.contiguous().view(-1, 1) + + token_per_expert = token_per_expert.tolist() + gate_output = self.act_fn(grouped_linear( + inp = grouped_input, + m_splits = token_per_expert, + use_bias = False, + is_grad_enabled = self.training, + activation_dtype = hidden_states.dtype, + weights_bias = self.gate_proj.weight, + num_experts=self.num_experts + )) + up_output = grouped_linear( + inp = grouped_input, + m_splits = token_per_expert, + use_bias = False, + is_grad_enabled = self.training, + activation_dtype = hidden_states.dtype, + weights_bias = self.up_proj.weight, + num_experts=self.num_experts + ) * gate_output + down_output = grouped_linear( + inp = up_output, + m_splits = token_per_expert, + use_bias = False, + is_grad_enabled = self.training, + activation_dtype = hidden_states.dtype, + weights_bias = self.down_proj.weight, + num_experts=self.num_experts + ) + if is_te_min_version("2.1.0"): + final_hidden_states = moe_unpermute(down_output, row_id_map, probs, map_type='index') + else: + final_hidden_states = moe_unpermute(down_output, row_id_map, probs) + final_hidden_states = final_hidden_states.view(topk, flat_seqlen, -1).permute(1,0,2) + final_hidden_states = torch.sum(final_hidden_states, dim=1).squeeze(1) + return final_hidden_states \ No newline at end of file diff --git a/chatlearn/models/patches/transformers/qwen3_moe_patch.py b/chatlearn/models/patches/transformers/qwen3_moe_patch.py index be2ddbf5..05129de3 100644 --- a/chatlearn/models/patches/transformers/qwen3_moe_patch.py +++ b/chatlearn/models/patches/transformers/qwen3_moe_patch.py @@ -16,219 +16,10 @@ from concurrent.futures import ThreadPoolExecutor import torch -from torch import nn +import torch.nn as nn import torch.nn.functional as F -try: - from transformer_engine.pytorch.cpp_extensions import grouped_gemm -except ImportError: - from transformer_engine.pytorch.cpp_extensions import general_grouped_gemm as grouped_gemm -from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace -from transformer_engine.pytorch.permutation import ( - moe_permute, - moe_unpermute, -) -from transformers.activations import ACT2FN - -from chatlearn.utils import is_te_min_version - - -class GroupGemm(torch.autograd.Function): - """ Autograd function for grouped gemm""" - @staticmethod - def forward( - ctx, - inp, - m_splits, - use_bias, - is_grad_enabled, - activation_dtype, - *weights_bias - ) -> torch.Tensor: - n_gemm = len(m_splits) - weights = weights_bias[:n_gemm] - bias = weights_bias[n_gemm:] - in_features = weights[0].shape[-1] - inputmats = torch.split(inp.view(-1, in_features), m_splits) - output_tensor = torch.empty( - [sum(m_splits), weights[0].shape[0]], - 
dtype=activation_dtype, - device=inputmats[0].device, - ) - grouped_gemm_kwargs = {'dtype': activation_dtype} - if is_te_min_version("2.0.0"): - grouped_gemm_kwargs = {'out_dtype': activation_dtype, 'm_splits': m_splits} - _ = grouped_gemm( - A=weights, - B=inputmats, - out=torch.split(output_tensor, m_splits), - workspaces=get_multi_stream_cublas_workspace(), - bias=bias, - use_bias=use_bias, - **grouped_gemm_kwargs - ) - if is_grad_enabled: - ctx.save_for_backward( - *inputmats, - *weights, - ) - ctx.m_splits = m_splits - ctx.num_gemm = n_gemm - ctx.activation_dtype = activation_dtype - ctx.use_bias = use_bias - ctx.inp_shape = inp.shape - return output_tensor.view(-1, *inp.shape[1:-1], output_tensor.shape[-1]) - - @staticmethod - def backward(ctx, grad_output): - with torch.cuda.nvtx.range("_GroupedLinear_backward"): - saved_tensors = ctx.saved_tensors - inputmats = saved_tensors[:ctx.num_gemm] - weights = saved_tensors[ctx.num_gemm:] - - grad_output = grad_output.contiguous() - grad_output_mats = torch.split( - grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits - ) - #dgrad - dgrad = torch.empty( - (sum(ctx.m_splits), weights[0].size(1)), - dtype=ctx.activation_dtype, - device=grad_output.device, - ) - grouped_gemm_kwargs = {'dtype': ctx.activation_dtype} - if is_te_min_version("2.0.0"): - grouped_gemm_kwargs = {'out_dtype': ctx.activation_dtype, 'm_splits': ctx.m_splits} - grouped_gemm( - A=weights, - B=grad_output_mats, - out=torch.split(dgrad, ctx.m_splits), - workspaces=get_multi_stream_cublas_workspace(), - layout="NN", - grad=True, - **grouped_gemm_kwargs - ) - - #wgrad - wgrad_list = [ - torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device) - for w in weights - ] - _, grad_biases, _ = grouped_gemm( - A=inputmats, - B=grad_output_mats, - out=wgrad_list, - workspaces=get_multi_stream_cublas_workspace(), - layout="NT", - grad=True, - use_bias=ctx.use_bias, - **grouped_gemm_kwargs - ) - if not ctx.use_bias: - grad_biases = [None] - return ( - dgrad.view(ctx.inp_shape), - None, - None, - None, - None, - *wgrad_list, - *grad_biases, - ) - -def grouped_linear(inp, m_splits, use_bias, is_grad_enabled, activation_dtype, weights_bias, num_experts): - weights_bias = torch.chunk(weights_bias, num_experts, dim=0) - output = GroupGemm.apply( - inp, - m_splits, - use_bias, - is_grad_enabled, - activation_dtype, - *weights_bias - ) - return output - - -class Linear(nn.Module): - """used for empty init gate_proj,up_proj,down_proj""" - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.weight = nn.Parameter( - torch.empty((out_features, in_features), **factory_kwargs) - ) - if bias: - self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) - else: - self.register_parameter("bias", None) - - -class MoeGroupMLP(nn.Module): - """ Group MLP Layer """ - def __init__(self, config, intermediate_size): - super().__init__() - self.num_experts = config.num_experts - hidden_size = config.hidden_size - - self.gate_proj = Linear(hidden_size, intermediate_size * self.num_experts, bias=False) - self.up_proj = Linear(hidden_size, intermediate_size * self.num_experts, bias=False) - self.down_proj = Linear(intermediate_size, hidden_size * self.num_experts, bias=False) - - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_states, router_weights, 
selected_experts, token_per_expert) -> torch.Tensor: - flat_seqlen = hidden_states.shape[0] - topk = router_weights.shape[1] - if is_te_min_version("2.1.0"): - grouped_input, row_id_map = moe_permute(hidden_states, selected_experts, map_type='index') - else: - grouped_input, row_id_map = moe_permute(hidden_states, selected_experts) - probs = router_weights.T.contiguous().view(-1, 1) - - token_per_expert = token_per_expert.tolist() - gate_output = self.act_fn(grouped_linear( - inp = grouped_input, - m_splits = token_per_expert, - use_bias = False, - is_grad_enabled = self.training, - activation_dtype = hidden_states.dtype, - weights_bias = self.gate_proj.weight, - num_experts=self.num_experts - )) - up_output = grouped_linear( - inp = grouped_input, - m_splits = token_per_expert, - use_bias = False, - is_grad_enabled = self.training, - activation_dtype = hidden_states.dtype, - weights_bias = self.up_proj.weight, - num_experts=self.num_experts - ) * gate_output - down_output = grouped_linear( - inp = up_output, - m_splits = token_per_expert, - use_bias = False, - is_grad_enabled = self.training, - activation_dtype = hidden_states.dtype, - weights_bias = self.down_proj.weight, - num_experts=self.num_experts - ) - if is_te_min_version("2.1.0"): - final_hidden_states = moe_unpermute(down_output, row_id_map, probs, map_type='index') - else: - final_hidden_states = moe_unpermute(down_output, row_id_map, probs) - final_hidden_states = final_hidden_states.view(topk, flat_seqlen, -1).permute(1,0,2) - final_hidden_states = torch.sum(final_hidden_states, dim=1).squeeze(1) - return final_hidden_states +from chatlearn.models.patches.transformers.layers.groupgemm import MoeGroupMLP class Qwen3MoeSparseMoeBlock_Grouped(nn.Module): """ MOE Block support grouped linear """ diff --git a/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py b/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py index 2b86d06a..d3aa97ae 100644 --- a/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py +++ b/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py @@ -12,128 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""patches for qwen3-moe model""" +"""patches for qwen3-next model""" from concurrent.futures import ThreadPoolExecutor import torch -from torch import nn +import torch.nn as nn import torch.nn.functional as F -try: - from transformer_engine.pytorch.cpp_extensions import grouped_gemm -except ImportError: - from transformer_engine.pytorch.cpp_extensions import general_grouped_gemm as grouped_gemm -from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace -from transformer_engine.pytorch.permutation import ( - moe_permute, - moe_unpermute, -) -from transformers.activations import ACT2FN - -from chatlearn.utils import is_te_min_version - -class GroupGemm(torch.autograd.Function): - """ Autograd function for grouped gemm""" - @staticmethod - def forward( - ctx, - inp, - m_splits, - use_bias, - is_grad_enabled, - activation_dtype, - *weights_bias - ) -> torch.Tensor: - n_gemm = len(m_splits) - weights = weights_bias[:n_gemm] - bias = weights_bias[n_gemm:] - in_features = weights[0].shape[-1] - inputmats = torch.split(inp.view(-1, in_features), m_splits) - output_tensor = torch.empty( - [sum(m_splits), weights[0].shape[0]], - dtype=activation_dtype, - device=inputmats[0].device, - ) - grouped_gemm_kwargs = {'dtype': activation_dtype} - if is_te_min_version("2.0.0"): - grouped_gemm_kwargs = {'out_dtype': activation_dtype, 'm_splits': m_splits} - _ = grouped_gemm( - A=weights, - B=inputmats, - out=torch.split(output_tensor, m_splits), - workspaces=get_multi_stream_cublas_workspace(), - bias=bias, - use_bias=use_bias, - **grouped_gemm_kwargs - ) - if is_grad_enabled: - ctx.save_for_backward( - *inputmats, - *weights, - ) - ctx.m_splits = m_splits - ctx.num_gemm = n_gemm - ctx.activation_dtype = activation_dtype - ctx.use_bias = use_bias - ctx.inp_shape = inp.shape - return output_tensor.view(-1, *inp.shape[1:-1], output_tensor.shape[-1]) - - @staticmethod - def backward(ctx, grad_output): - with torch.cuda.nvtx.range("_GroupedLinear_backward"): - saved_tensors = ctx.saved_tensors - inputmats = saved_tensors[:ctx.num_gemm] - weights = saved_tensors[ctx.num_gemm:] - - grad_output = grad_output.contiguous() - grad_output_mats = torch.split( - grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits - ) - #dgrad - dgrad = torch.empty( - (sum(ctx.m_splits), weights[0].size(1)), - dtype=ctx.activation_dtype, - device=grad_output.device, - ) - grouped_gemm_kwargs = {'dtype': ctx.activation_dtype} - if is_te_min_version("2.0.0"): - grouped_gemm_kwargs = {'out_dtype': ctx.activation_dtype, 'm_splits': ctx.m_splits} - grouped_gemm( - A=weights, - B=grad_output_mats, - out=torch.split(dgrad, ctx.m_splits), - workspaces=get_multi_stream_cublas_workspace(), - layout="NN", - grad=True, - **grouped_gemm_kwargs - ) - - #wgrad - wgrad_list = [ - torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device) - for w in weights - ] - _, grad_biases, _ = grouped_gemm( - A=inputmats, - B=grad_output_mats, - out=wgrad_list, - workspaces=get_multi_stream_cublas_workspace(), - layout="NT", - grad=True, - use_bias=ctx.use_bias, - **grouped_gemm_kwargs - ) - if not ctx.use_bias: - grad_biases = [None] - return ( - dgrad.view(ctx.inp_shape), - None, - None, - None, - None, - *wgrad_list, - *grad_biases, - ) +from chatlearn.models.patches.transformers.layers.groupgemm import MoeGroupMLP class Qwen3NextMLP(nn.Module): """ Qwen3-Next MLP layer """ @@ -151,100 +37,6 @@ def forward(self, x): down_proj = 
self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) return down_proj -def grouped_linear(inp, m_splits, use_bias, is_grad_enabled, activation_dtype, weights_bias, num_experts): - weights_bias = torch.chunk(weights_bias, num_experts, dim=0) - output = GroupGemm.apply( - inp, - m_splits, - use_bias, - is_grad_enabled, - activation_dtype, - *weights_bias - ) - return output - - -class Linear(nn.Module): - """used for empty init gate_proj,up_proj,down_proj""" - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.weight = nn.Parameter( - torch.empty((out_features, in_features), **factory_kwargs) - ) - if bias: - self.bias = Parameter(torch.empty(out_features, **factory_kwargs)) - else: - self.register_parameter("bias", None) - - -class MoeGroupMLP(nn.Module): - """ Group MLP Layer """ - def __init__(self, config, intermediate_size): - super().__init__() - self.num_experts = config.num_experts - hidden_size = config.hidden_size - - self.gate_proj = Linear(hidden_size, intermediate_size * self.num_experts, bias=False) - self.up_proj = Linear(hidden_size, intermediate_size * self.num_experts, bias=False) - self.down_proj = Linear(intermediate_size, hidden_size * self.num_experts, bias=False) - - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_states, router_weights, selected_experts, token_per_expert) -> torch.Tensor: - flat_seqlen = hidden_states.shape[0] - topk = router_weights.shape[1] - if is_te_min_version("2.1.0"): - grouped_input, row_id_map = moe_permute(hidden_states, selected_experts, map_type='index') - else: - grouped_input, row_id_map = moe_permute(hidden_states, selected_experts) - probs = router_weights.T.contiguous().view(-1, 1) - - token_per_expert = token_per_expert.tolist() - gate_output = self.act_fn(grouped_linear( - inp = grouped_input, - m_splits = token_per_expert, - use_bias = False, - is_grad_enabled = self.training, - activation_dtype = hidden_states.dtype, - weights_bias = self.gate_proj.weight, - num_experts=self.num_experts - )) - up_output = grouped_linear( - inp = grouped_input, - m_splits = token_per_expert, - use_bias = False, - is_grad_enabled = self.training, - activation_dtype = hidden_states.dtype, - weights_bias = self.up_proj.weight, - num_experts=self.num_experts - ) * gate_output - down_output = grouped_linear( - inp = up_output, - m_splits = token_per_expert, - use_bias = False, - is_grad_enabled = self.training, - activation_dtype = hidden_states.dtype, - weights_bias = self.down_proj.weight, - num_experts=self.num_experts - ) - if is_te_min_version("2.1.0"): - final_hidden_states = moe_unpermute(down_output, row_id_map, probs, map_type='index') - else: - final_hidden_states = moe_unpermute(down_output, row_id_map, probs) - final_hidden_states = final_hidden_states.view(topk, flat_seqlen, -1).permute(1,0,2) - final_hidden_states = torch.sum(final_hidden_states, dim=1).squeeze(1) - return final_hidden_states - class Qwen3MoeSparseMoeBlock_Grouped(nn.Module): """ MOE Block support grouped linear """ def __init__(self, config): @@ -324,13 +116,14 @@ def copy_expert_weights(i, moe_group_layer, layer, size_0, size_1): moe_group_layer = Qwen3MoeSparseMoeBlock_Grouped(model.config).to(model.dtype) # Copy gate weight moe_group_layer.gate.weight.data.copy_(layer.mlp.gate.weight.data) + # Copy shared expert 
weights moe_group_layer.shared_expert.gate_proj.weight.data.copy_(layer.mlp.shared_expert.gate_proj.weight.data) moe_group_layer.shared_expert.up_proj.weight.data.copy_(layer.mlp.shared_expert.up_proj.weight.data) moe_group_layer.shared_expert.down_proj.weight.data.copy_(layer.mlp.shared_expert.down_proj.weight.data) - moe_group_layer.shared_expert_gate.weight.data.copy_(layer.mlp.shared_expert_gate.weight.data) + # Copy other expert weights with ThreadPoolExecutor(max_workers=16) as executor: futures = [ executor.submit( From 0a14a4a72f51aadbc546bb979dec32839a98ce59 Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Wed, 17 Sep 2025 11:41:40 +0800 Subject: [PATCH 04/13] fix pylint --- chatlearn/models/patches/transformers/layers/groupgemm.py | 5 ++--- chatlearn/models/patches/transformers/qwen3_moe_patch.py | 2 +- .../models/patches/transformers/qwen3_next_moe_patch.py | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/chatlearn/models/patches/transformers/layers/groupgemm.py b/chatlearn/models/patches/transformers/layers/groupgemm.py index 75b22821..1c2a5eef 100644 --- a/chatlearn/models/patches/transformers/layers/groupgemm.py +++ b/chatlearn/models/patches/transformers/layers/groupgemm.py @@ -1,8 +1,7 @@ """groupgemm layer with transformer_engine ops""" from transformers.activations import ACT2FN import torch -import torch.nn as nn -import torch.nn.functional as F +from torch import nn try: from transformer_engine.pytorch.cpp_extensions import grouped_gemm @@ -209,4 +208,4 @@ def forward(self, hidden_states, router_weights, selected_experts, token_per_exp final_hidden_states = moe_unpermute(down_output, row_id_map, probs) final_hidden_states = final_hidden_states.view(topk, flat_seqlen, -1).permute(1,0,2) final_hidden_states = torch.sum(final_hidden_states, dim=1).squeeze(1) - return final_hidden_states \ No newline at end of file + return final_hidden_states diff --git a/chatlearn/models/patches/transformers/qwen3_moe_patch.py b/chatlearn/models/patches/transformers/qwen3_moe_patch.py index 05129de3..c0657a3d 100644 --- a/chatlearn/models/patches/transformers/qwen3_moe_patch.py +++ b/chatlearn/models/patches/transformers/qwen3_moe_patch.py @@ -16,7 +16,7 @@ from concurrent.futures import ThreadPoolExecutor import torch -import torch.nn as nn +from torch import nn import torch.nn.functional as F from chatlearn.models.patches.transformers.layers.groupgemm import MoeGroupMLP diff --git a/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py b/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py index d3aa97ae..11063a3f 100644 --- a/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py +++ b/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py @@ -16,7 +16,7 @@ from concurrent.futures import ThreadPoolExecutor import torch -import torch.nn as nn +from torch import nn import torch.nn.functional as F from chatlearn.models.patches.transformers.layers.groupgemm import MoeGroupMLP From bd6ad35b9340484cd48e7b0bd524cdb9be189f05 Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Thu, 18 Sep 2025 15:48:38 +0800 Subject: [PATCH 05/13] speed up load --- chatlearn/models/fsdp_module.py | 161 ++++++++++++++++++++++++++------ 1 file changed, 132 insertions(+), 29 deletions(-) diff --git a/chatlearn/models/fsdp_module.py b/chatlearn/models/fsdp_module.py index c8054216..44756458 100644 --- a/chatlearn/models/fsdp_module.py +++ b/chatlearn/models/fsdp_module.py @@ -18,11 +18,14 @@ import random import gc from typing import List +import glob +import json +from 
safetensors.torch import load_file import numpy as np import torch import torch.distributed as dist -from torch.distributed.tensor import DTensor +from torch.distributed.tensor import DTensor, distribute_tensor from torch import optim, nn from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict, get_model_state_dict @@ -95,6 +98,99 @@ def fsdp2_clip_grad_norm_(self, parameters, max_norm, norm_type=2.0, error_if_no return total_norm + def split_list(self, lst, n): + """Split list into n roughly equal chunks.""" + k, m = divmod(len(lst), n) + return [lst[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n)] + + def get_dtensor(self, model, hf_dir): + mapping = { + 'gate_weight': 'gate_proj.weight', + 'up_weight': 'up_proj.weight', + 'down_weight': 'down_proj.weight' + } + world_size = dist.get_world_size() + local_rank = dist.get_rank() + safetensor_files = glob.glob(os.path.join(hf_dir, "*.safetensors")) + safetensor_files = self.split_list(safetensor_files, world_size) + local_safetensor_file = safetensor_files[local_rank] + local_tensors = {} + for file in local_safetensor_file: + local_tensors.update(load_file(file, device="cuda")) + + weight_map = json.load(open(os.path.join(hf_dir, "model.safetensors.index.json")))['weight_map'] + meta_sharded_sd = model.state_dict() + sharded_sd = {} + shape_list=[] + for param_name, param in meta_sharded_sd.items(): + if meta_sharded_sd[param_name].shape not in shape_list: + shape_list.append(meta_sharded_sd[param_name].shape) + tensor_buffer = {} + for shape in shape_list: + tensor_buffer[shape] = torch.empty(shape, dtype=torch.bfloat16, device="cuda") + print(tensor_buffer.keys()) + for param_name in meta_sharded_sd.keys(): + # print(param_name) + sharded_meta_param = meta_sharded_sd.get(param_name) + shape_key = meta_sharded_sd[param_name].shape + if param_name.split('.')[-1] in mapping: + sequential_mlp_name_list = [] + for i in range(512): + part = param_name.split('.')[-1] + sequential_mlp_name_list.append(name.replace('group_mlp', f"experts.{i}").replace(part, mapping[part])) + single_expert_shape = (sharded_meta_param.shape[0] // 512, sharded_meta_param.shape[1]) + local_tensor = torch.empty(single_expert_shape, dtype=torch.bfloat16, device='cuda') + group_gemm_tensor = tensor_buffer[shape_key] + for idx, single_mlp_name in enumerate(sequential_mlp_name_list): + if single_mlp_name in local_tensors: + local_tensor.copy_(local_tensors.pop(single_mlp_name)) + #sharded_sd[single_mlp_name] = local_tensor + safe_tensor_file = weight_map[single_mlp_name] + rank_has_data=None + for i, file_list in enumerate(safetensor_files): + for file in file_list: + if safe_tensor_file in file: + rank_has_data = i + break + dist.broadcast(local_tensor, src=rank_has_data) + group_gemm_tensor[idx * single_expert_shape[0]: (idx + 1) * single_expert_shape[0], :].copy_(local_tensor) + local_tensor = torch.chunk(tensor_buffer[shape_key],world_size, dim=0)[local_rank] + sharded_tensor = DTensor.from_local(local_tensor.clone(), sharded_meta_param.device_mesh, sharded_meta_param.placements) + sharded_sd[param_name] = nn.Parameter(sharded_tensor) + else: + if param_name in local_tensors: + tensor_buffer[shape_key].copy_(local_tensors.pop(param_name)) + # else: + # full_tensor = torch.empty(meta_sharded_sd[param_name].shape, dtype=torch.bfloat16, device=local_rank) + # print(full_tensor.shape) + safe_tensor_file = weight_map[param_name] + rank_has_data=None + for i, 
file_list in enumerate(safetensor_files): + for file in file_list: + if safe_tensor_file in file: + rank_has_data = i + break + dist.broadcast(tensor_buffer[shape_key], src=rank_has_data) + if False: + print(f"{local_rank}: {sharded_meta_param}") + # no shard + if local_rank==0: + sharded_sd[param_name] = DTensor.from_local(tensor_buffer[shape_key].clone(), sharded_meta_param.device_mesh, sharded_meta_param.placements) + else: + sharded_sd[param_name] = DTensor.from_local(torch.empty((0, sharded_meta_param.size()[1]), dtype=torch.bfloat16, device='cuda'), sharded_meta_param.device_mesh, sharded_meta_param.placements) + print(f"{local_rank}: {sharded_sd[param_name]}") + else: + #local_tensor = torch.chunk(tensor_buffer[shape_key], world_size, dim=0)[local_rank] + sharded_tensor = distribute_tensor( + tensor_buffer[shape_key].clone(), + sharded_meta_param.device_mesh, + sharded_meta_param.placements, + ) + #DTensor.from_local(local_tensor.clone(), sharded_meta_param.device_mesh, sharded_meta_param.placements) + sharded_sd[param_name] = nn.Parameter(sharded_tensor) + dist.barrier() + return sharded_sd + def create_device_mesh(self, world_size, fsdp_size): if not self.device_mesh: if world_size == fsdp_size: @@ -234,7 +330,7 @@ def model_setup(self): local_rank = dist.get_rank() # When meta_init is enabled, we only load checkpoint on rank 0 - meta_init = self.module_args.meta_init and local_rank != 0 + meta_init = self.module_args.meta_init model = self.create_model(args.load, torch_dtype=torch.bfloat16, meta_init=meta_init) self.model_config = model.config if self.module_args.groupgemm: @@ -276,34 +372,42 @@ def model_setup(self): for module in modules: fully_shard(module, **fsdp_kwargs) fully_shard(model, **fsdp_kwargs) - if self.module_args.meta_init: - # save buffer data - buffer_dict = {} - for name, buf in model.named_buffers(): - buffer_dict[name] = buf - model.to_empty(device="cuda") - # load real state dict - options = StateDictOptions(full_state_dict=True, cpu_offload=False, broadcast_from_rank0=True) - - # module-wise sync avoid OOM while run model like qwen3-moe-235B + if self.module_args.meta_init: + shard_dict = self.get_dtensor(model, args.load) + model.load_state_dict(shard_dict, assign=True) for name, module in model.named_modules(): - has_weights = any(k.startswith(name + ".") for k in full_state.keys()) and len(list(module.children()))==0 - if has_weights: - set_model_state_dict( - module, - {k.replace(name + ".", ""): v for k, v in full_state.items() if k.startswith(name + ".")}, - options=options - ) - # set_model_state_dict(model, full_state, options=options) - - # load buffer data - if dist.get_rank()==0: - for name, buf in model.named_buffers(): - buf.data.copy_(buffer_dict[name]) - torch.cuda.synchronize() - for name, buf in model.named_buffers(): - dist.broadcast(buf, src=0) + if "rotary_emb" in name: + module.__init__(model.config) + + # if self.module_args.meta_init: + # # save buffer data + # buffer_dict = {} + # for name, buf in model.named_buffers(): + # buffer_dict[name] = buf + # model.to_empty(device="cuda") + + # # load real state dict + # options = StateDictOptions(full_state_dict=True, cpu_offload=False, broadcast_from_rank0=True) + + # # module-wise sync avoid OOM while run model like qwen3-moe-235B + # for name, module in model.named_modules(): + # has_weights = any(k.startswith(name + ".") for k in full_state.keys()) and len(list(module.children()))==0 + # if has_weights: + # set_model_state_dict( + # module, + # {k.replace(name + ".", ""): v for k, v in 
full_state.items() if k.startswith(name + ".")}, + # options=options + # ) + # # set_model_state_dict(model, full_state, options=options) + + # # load buffer data + # if dist.get_rank()==0: + # for name, buf in model.named_buffers(): + # buf.data.copy_(buffer_dict[name]) + # torch.cuda.synchronize() + # for name, buf in model.named_buffers(): + # dist.broadcast(buf, src=0) self.model = model self.model.to(torch.float32) @@ -321,7 +425,6 @@ def model_setup(self): # resume model weights if self.resume_training: self.load_checkpoint(self._episode_id) - del full_state self.offload() def get_fsdp_param_name(self, block_size=300_000_000) -> List[List]: From 0d9906b029ad79fe0051b12199cc70b96b17d609 Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Thu, 18 Sep 2025 19:42:31 +0800 Subject: [PATCH 06/13] fix expert --- chatlearn/models/fsdp_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chatlearn/models/fsdp_module.py b/chatlearn/models/fsdp_module.py index aeb229f6..f06a5e65 100644 --- a/chatlearn/models/fsdp_module.py +++ b/chatlearn/models/fsdp_module.py @@ -135,10 +135,10 @@ def get_dtensor(self, model, hf_dir): if 'group_mlp' in param_name: sequential_mlp_name_list = [] num_expert = model.config.num_experts - for i in range(512): + for i in range(num_expert): part = param_name.split('.')[-1] sequential_mlp_name_list.append(param_name.replace('group_mlp', f"experts.{i}")) - single_expert_shape = (sharded_meta_param.shape[0] // 512, sharded_meta_param.shape[1]) + single_expert_shape = (sharded_meta_param.shape[0] // num_expert, sharded_meta_param.shape[1]) local_tensor = torch.empty(single_expert_shape, dtype=torch.bfloat16, device='cuda') group_gemm_tensor = tensor_buffer[shape_key] for idx, single_mlp_name in enumerate(sequential_mlp_name_list): From f817adfc21955397c4d822a51f26a25820b04bf4 Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Mon, 22 Sep 2025 13:43:31 +0800 Subject: [PATCH 07/13] Fix groupgemm dtype; Rerewrite FSDP acceleration --- chatlearn/models/fsdp_module.py | 186 ++++++++---------- .../patches/transformers/layers/groupgemm.py | 6 +- .../transformers/qwen3_next_moe_patch.py | 1 + chatlearn/models/torch_module.py | 2 +- 4 files changed, 83 insertions(+), 112 deletions(-) diff --git a/chatlearn/models/fsdp_module.py b/chatlearn/models/fsdp_module.py index f06a5e65..867ad99a 100644 --- a/chatlearn/models/fsdp_module.py +++ b/chatlearn/models/fsdp_module.py @@ -20,6 +20,9 @@ from typing import List, Dict import glob import json +import math + +from safetensors import safe_open from safetensors.torch import load_file import numpy as np @@ -34,6 +37,7 @@ from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModelForImageTextToText +from accelerate import init_on_device from chatlearn.utils.logger import debug_rank_0 from chatlearn.utils.utils import dict_to_simplenamespace @@ -104,14 +108,17 @@ def split_list(self, lst, n): k, m = divmod(len(lst), n) return [lst[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n)] - def get_dtensor(self, model, hf_dir): - mapping = { - 'gate_weight': 'gate_proj.weight', - 'up_weight': 'up_proj.weight', - 'down_weight': 'down_proj.weight' - } + def get_dtensor(self, model, hf_dir, use_groupgemm): + """ + Accelerate loading huggingface checkpoints. + Split safetensor files to difference ranks and load them into GPU. + Each rank owns a bucket, when tensor is on local_rank, copy into bucket. 
+ When bucket is full, allreduce bucket and update sharded_sd. + """ world_size = dist.get_world_size() local_rank = dist.get_rank() + + # Split safetensor files to difference ranks and load them into GPU safetensor_files = glob.glob(os.path.join(hf_dir, "*.safetensors")) safetensor_files = self.split_list(safetensor_files, world_size) local_safetensor_file = safetensor_files[local_rank] @@ -119,72 +126,70 @@ def get_dtensor(self, model, hf_dir): for file in local_safetensor_file: local_tensors.update(load_file(file, device="cuda")) - weight_map = json.load(open(os.path.join(hf_dir, "model.safetensors.index.json")))['weight_map'] + # Create bucket for all_reduce meta_sharded_sd = model.state_dict() - sharded_sd = {} - shape_list=[] - for param_name, param in meta_sharded_sd.items(): - if meta_sharded_sd[param_name].shape not in shape_list: - shape_list.append(meta_sharded_sd[param_name].shape) - tensor_buffer = {} - for shape in shape_list: - tensor_buffer[shape] = torch.empty(shape, dtype=torch.bfloat16, device="cuda") - for param_name in meta_sharded_sd.keys(): - sharded_meta_param = meta_sharded_sd.get(param_name) - shape_key = meta_sharded_sd[param_name].shape - if 'group_mlp' in param_name: - sequential_mlp_name_list = [] - num_expert = model.config.num_experts - for i in range(num_expert): - part = param_name.split('.')[-1] - sequential_mlp_name_list.append(param_name.replace('group_mlp', f"experts.{i}")) - single_expert_shape = (sharded_meta_param.shape[0] // num_expert, sharded_meta_param.shape[1]) - local_tensor = torch.empty(single_expert_shape, dtype=torch.bfloat16, device='cuda') - group_gemm_tensor = tensor_buffer[shape_key] - for idx, single_mlp_name in enumerate(sequential_mlp_name_list): - if single_mlp_name in local_tensors: - local_tensor.copy_(local_tensors.pop(single_mlp_name)) - #sharded_sd[single_mlp_name] = local_tensor - safe_tensor_file = weight_map[single_mlp_name] - rank_has_data=None - for i, file_list in enumerate(safetensor_files): - for file in file_list: - if safe_tensor_file in file: - rank_has_data = i - break - dist.broadcast(local_tensor, src=rank_has_data) - group_gemm_tensor[idx * single_expert_shape[0]: (idx + 1) * single_expert_shape[0], :].copy_(local_tensor) - sharded_tensor = distribute_tensor( - group_gemm_tensor.clone(), - sharded_meta_param.device_mesh, - sharded_meta_param.placements, - ) - sharded_sd[param_name] = nn.Parameter(sharded_tensor) + bucket_size = 3 * 1024 ** 3 + for _, meta_param in meta_sharded_sd.items(): + bucket_size = max(bucket_size, math.prod(meta_param.shape)) + bucket = torch.zeros(bucket_size, dtype=torch.bfloat16, device="cuda") + + # Bucketizing synchronize params. + # Since all_reduce is used, bucket is zero initialized. 
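+        # Parameters are packed back-to-back at buffer_offset. A rank writes only the
+        # slices it loaded from its own *.safetensors files (for group_mlp weights that
+        # may be a subset of expert row blocks); all other ranks contribute zeros, so the
+        # all_reduce (sum) reassembles the full tensor on every rank and distribute_tensor
+        # can then re-shard each flushed parameter from its offset.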
+ shard_sd = {} + buffer_offset = 0 + param_to_sync = [] + for param_name, meta_param in meta_sharded_sd.items(): + if buffer_offset + math.prod(meta_param.shape) > bucket_size: + dist.all_reduce(bucket) + get_offset = 0 + for param_to_update in param_to_sync: + # Update sharded_sd + meta_info = meta_sharded_sd[param_to_update] + num_params = math.prod(meta_info.shape) + shard_sd[param_to_update] = distribute_tensor( + bucket[get_offset:get_offset + num_params].view(meta_info.shape).clone(), + meta_info.device_mesh, + meta_info.placements, + ) + get_offset += num_params + param_to_sync = [] + buffer_offset = 0 + bucket.fill_(0.0) + if "group_mlp" in param_name: + # If groupgemm is enabled, weights of each expert will be load one by one + num_experts = model.config.num_experts + local_offset = buffer_offset + num_param_per_expert = math.prod(meta_param.shape) // num_experts + for i in range(num_experts): + local_name = param_name.replace('group_mlp', f"experts.{i}") + if local_name in local_tensors: + local_tensor = local_tensors.pop(local_name) + bucket[local_offset: local_offset + num_param_per_expert].copy_(local_tensor.to(torch.bfloat16).view(-1)) + local_offset += num_param_per_expert else: if param_name in local_tensors: - tensor_buffer[shape_key].copy_(local_tensors.pop(param_name)) - # else: - # full_tensor = torch.empty(meta_sharded_sd[param_name].shape, dtype=torch.bfloat16, device=local_rank) - # print(full_tensor.shape) - safe_tensor_file = weight_map[param_name] - rank_has_data=None - for i, file_list in enumerate(safetensor_files): - for file in file_list: - if safe_tensor_file in file: - rank_has_data = i - break - dist.broadcast(tensor_buffer[shape_key], src=rank_has_data) - #local_tensor = torch.chunk(tensor_buffer[shape_key], world_size, dim=0)[local_rank] - sharded_tensor = distribute_tensor( - tensor_buffer[shape_key].clone(), - sharded_meta_param.device_mesh, - sharded_meta_param.placements, - ) - #DTensor.from_local(local_tensor.clone(), sharded_meta_param.device_mesh, sharded_meta_param.placements) - sharded_sd[param_name] = nn.Parameter(sharded_tensor) + local_tensor = local_tensors.pop(param_name) + tensor_numel = local_tensor.numel() + bucket[buffer_offset: buffer_offset + tensor_numel].copy_(local_tensor.to(torch.bfloat16).view(-1)) + buffer_offset += math.prod(meta_param.shape) + param_to_sync.append(param_name) dist.barrier() - del tensor_buffer - return sharded_sd + # Synchronize last bucket + dist.all_reduce(bucket) + get_offset = 0 + for param_to_update in param_to_sync: + meta_info = meta_sharded_sd[param_to_update] + num_params = math.prod(meta_info.shape) + # Update sharded_sd + shard_sd[param_to_update] = distribute_tensor( + bucket[get_offset:get_offset + num_params].view(meta_info.shape).clone(), + meta_info.device_mesh, + meta_info.placements, + ) + get_offset += num_params + dist.barrier() + del bucket + return shard_sd def create_device_mesh(self, world_size, fsdp_size): if not self.device_mesh: @@ -272,7 +277,7 @@ def create_model(self, model_path: str , torch_dtype: torch.dtype, meta_init: bo else: model_config = AutoConfig.from_pretrained(model_path) assert "Qwen2_5_VLForConditionalGeneration" not in model_config.architectures, "VL model not support meta init" - with torch.device('meta'): + with init_on_device('meta', include_buffers=False): model = AutoModelForCausalLM.from_config( model_config, torch_dtype=torch_dtype, @@ -324,7 +329,7 @@ def model_setup(self): self.args = args local_rank = dist.get_rank() - # When meta_init is enabled, we only 
load checkpoint on rank 0 + # When meta_init is enabled, we don't load ckpt here meta_init = self.module_args.meta_init model = self.create_model(args.load, torch_dtype=torch.bfloat16, meta_init=meta_init) self.model_config = model.config @@ -342,11 +347,6 @@ def model_setup(self): ) model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False}) - # get state_dict to init model for meta init - full_state = None - if self.module_args.meta_init: - full_state = model.state_dict() - # fsdp2 warp mix_precision_config = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True) fsdp_kwargs = { @@ -369,40 +369,10 @@ def model_setup(self): fully_shard(model, **fsdp_kwargs) if self.module_args.meta_init: - shard_dict = self.get_dtensor(model, args.load) + shard_dict = self.get_dtensor(model, args.load, self.module_args.groupgemm) model.load_state_dict(shard_dict, assign=True) - for name, module in model.named_modules(): - if "rotary_emb" in name: - module.__init__(model.config) - - # if self.module_args.meta_init: - # # save buffer data - # buffer_dict = {} - # for name, buf in model.named_buffers(): - # buffer_dict[name] = buf - # model.to_empty(device="cuda") - - # # load real state dict - # options = StateDictOptions(full_state_dict=True, cpu_offload=False, broadcast_from_rank0=True) - - # # module-wise sync avoid OOM while run model like qwen3-moe-235B - # for name, module in model.named_modules(): - # has_weights = any(k.startswith(name + ".") for k in full_state.keys()) and len(list(module.children()))==0 - # if has_weights: - # set_model_state_dict( - # module, - # {k.replace(name + ".", ""): v for k, v in full_state.items() if k.startswith(name + ".")}, - # options=options - # ) - # # set_model_state_dict(model, full_state, options=options) - - # # load buffer data - # if dist.get_rank()==0: - # for name, buf in model.named_buffers(): - # buf.data.copy_(buffer_dict[name]) - # torch.cuda.synchronize() - # for name, buf in model.named_buffers(): - # dist.broadcast(buf, src=0) + del shard_dict + self.model = model self.model.to(torch.float32) diff --git a/chatlearn/models/patches/transformers/layers/groupgemm.py b/chatlearn/models/patches/transformers/layers/groupgemm.py index 1c2a5eef..83bed6e9 100644 --- a/chatlearn/models/patches/transformers/layers/groupgemm.py +++ b/chatlearn/models/patches/transformers/layers/groupgemm.py @@ -169,10 +169,10 @@ def forward(self, hidden_states, router_weights, selected_experts, token_per_exp flat_seqlen = hidden_states.shape[0] topk = router_weights.shape[1] if is_te_min_version("2.1.0"): - grouped_input, row_id_map = moe_permute(hidden_states, selected_experts, map_type='index') + grouped_input, row_id_map = moe_permute(hidden_states, selected_experts.to(torch.int32), map_type='index') else: - grouped_input, row_id_map = moe_permute(hidden_states, selected_experts) - probs = router_weights.T.contiguous().view(-1, 1) + grouped_input, row_id_map = moe_permute(hidden_states, selected_experts.to(torch.int32)) + probs = router_weights.T.contiguous().view(-1, 1).to(torch.float32) token_per_expert = token_per_expert.tolist() gate_output = self.act_fn(grouped_linear( diff --git a/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py b/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py index 11063a3f..5633e92e 100644 --- a/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py +++ b/chatlearn/models/patches/transformers/qwen3_next_moe_patch.py @@ -18,6 +18,7 @@ 
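Note on the groupgemm.py hunk above: the two added casts normalize routing dtypes before permutation, `selected_experts` to `torch.int32` and the flattened routing probabilities to `torch.float32`. A standalone sketch in plain torch of where the default `int64` indices come from; the int32/float32 expectation is inferred from the casts in the hunk itself, not from TE documentation:

import torch

# Toy router: 8 tokens, 4 experts, top-2 routing.
router_logits = torch.randn(8, 4)
router_weights, selected_experts = torch.topk(torch.softmax(router_logits, dim=-1), k=2)

print(selected_experts.dtype)  # torch.int64 -- topk indices default to int64
selected_experts = selected_experts.to(torch.int32)                   # cast added by the patch
probs = router_weights.T.contiguous().view(-1, 1).to(torch.float32)   # flattened per-token-slot probs
print(selected_experts.dtype, probs.dtype, probs.shape)  # torch.int32 torch.float32 torch.Size([16, 1])

Only the input dtypes change; the surrounding forward logic in the hunk is untouched.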
import torch from torch import nn import torch.nn.functional as F +from transformers.activations import ACT2FN from chatlearn.models.patches.transformers.layers.groupgemm import MoeGroupMLP diff --git a/chatlearn/models/torch_module.py b/chatlearn/models/torch_module.py index 60e79335..609f4249 100644 --- a/chatlearn/models/torch_module.py +++ b/chatlearn/models/torch_module.py @@ -187,8 +187,8 @@ def offload(self, self.offload_weights() torch.distributed.barrier() torch.cuda.synchronize() - gc.collect() torch.cuda.empty_cache() + gc.collect() torch.cuda.reset_peak_memory_stats() timer.stop() log_rank_0(get_full_proc_memory_info('After offload'), self._logger) From f0e89425de82560d86864848f6c1d49c68a57607 Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Mon, 22 Sep 2025 13:51:15 +0800 Subject: [PATCH 08/13] fix pylint --- chatlearn/models/fsdp_module.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/chatlearn/models/fsdp_module.py b/chatlearn/models/fsdp_module.py index d9e4e229..4de7a411 100644 --- a/chatlearn/models/fsdp_module.py +++ b/chatlearn/models/fsdp_module.py @@ -19,12 +19,8 @@ import gc from typing import List, Dict import glob -import json import math -from safetensors import safe_open -from safetensors.torch import load_file - import numpy as np import torch from torch import Tensor @@ -32,12 +28,13 @@ from torch.distributed.tensor import DTensor, distribute_tensor from torch import optim, nn from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard -from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict, get_model_state_dict +from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict from torch.multiprocessing.reductions import reduce_tensor from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModelForImageTextToText from accelerate import init_on_device +from safetensors.torch import load_file from chatlearn.utils.logger import debug_rank_0 from chatlearn.utils.utils import dict_to_simplenamespace @@ -328,7 +325,6 @@ def model_setup(self): args = dict_to_simplenamespace(self.module_args) self.args = args - local_rank = dist.get_rank() # When meta_init is enabled, we don't load ckpt here meta_init = self.module_args.meta_init model = self.create_model(args.load, torch_dtype=torch.bfloat16, meta_init=meta_init) @@ -347,12 +343,6 @@ def model_setup(self): ) model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False}) - # get state_dict to init model for meta init - full_state = None - update_bucket = None - if self.module_args.meta_init: - full_state = model.state_dict() - # fsdp2 warp mix_precision_config = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32, cast_forward_inputs=True) fsdp_kwargs = { @@ -365,7 +355,7 @@ def model_setup(self): if isinstance(fsdp_transformer_layer_cls_to_wrap, str): fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap] modules = [] - for name, module in model.named_modules(): + for _, module in model.named_modules(): if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or \ (isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings): modules.append(module) From 6d39421f429abfc327cbb456029b261d47edfe1f Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Mon, 22 Sep 2025 14:00:35 +0800 Subject: [PATCH 09/13] fix merge --- 
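A note on this fix: `get_fsdp_param_name` packs parameter names into groups of roughly `block_size` total elements, and the one-line change below restores the default to `3_000_000_000`, so weights are handed over in fewer, larger blocks during parameter sync. A minimal sketch of the grouping idea, with hypothetical names and sizes rather than the module's real parameters:

from typing import Dict, List

def group_params_by_numel(param_numels: Dict[str, int], block_size: int) -> List[List[str]]:
    """Greedily pack parameter names into groups of at most block_size total elements."""
    groups: List[List[str]] = []
    current: List[str] = []
    count = 0
    for name, numel in param_numels.items():
        if count + numel > block_size and current:
            groups.append(current)
            current, count = [], 0
        current.append(name)
        count += numel
    if current:
        groups.append(current)
    return groups

# A larger block_size means fewer, bigger groups per sync round.
print(group_params_by_numel({"a": 4, "b": 4, "c": 4}, block_size=10))   # [['a', 'b'], ['c']]
print(group_params_by_numel({"a": 4, "b": 4, "c": 4}, block_size=100))  # [['a', 'b', 'c']]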
chatlearn/models/fsdp_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatlearn/models/fsdp_module.py b/chatlearn/models/fsdp_module.py index 4de7a411..1c20e34b 100644 --- a/chatlearn/models/fsdp_module.py +++ b/chatlearn/models/fsdp_module.py @@ -388,7 +388,7 @@ def model_setup(self): self.load_checkpoint(self._episode_id) self.offload() - def get_fsdp_param_name(self, block_size=300_000_000) -> List[List]: + def get_fsdp_param_name(self, block_size=3_000_000_000) -> List[List]: name_list = [] param_cnt = 0 current_group = [] From e08f621c704967aea928b29cfe48d25494184697 Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Mon, 22 Sep 2025 15:21:45 +0800 Subject: [PATCH 10/13] fix comments --- chatlearn/models/fsdp_module.py | 71 ++++++++++++++++----------------- 1 file changed, 34 insertions(+), 37 deletions(-) diff --git a/chatlearn/models/fsdp_module.py b/chatlearn/models/fsdp_module.py index 1c20e34b..59eef47b 100644 --- a/chatlearn/models/fsdp_module.py +++ b/chatlearn/models/fsdp_module.py @@ -37,7 +37,7 @@ from safetensors.torch import load_file from chatlearn.utils.logger import debug_rank_0 -from chatlearn.utils.utils import dict_to_simplenamespace +from chatlearn.utils.utils import dict_to_simplenamespace, even_slice from chatlearn.utils.communication_op import set_sp_parallel_group from chatlearn.models.patches.monkey_patch import apply_sp_monkey_patch, apply_group_gemm from chatlearn.runtime.decorator import timeit, monitor_error @@ -100,12 +100,7 @@ def fsdp2_clip_grad_norm_(self, parameters, max_norm, norm_type=2.0, error_if_no return total_norm - def split_list(self, lst, n): - """Split list into n roughly equal chunks.""" - k, m = divmod(len(lst), n) - return [lst[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n)] - - def get_dtensor(self, model, hf_dir, use_groupgemm): + def get_dtensor(self, model, hf_dir): """ Accelerate loading huggingface checkpoints. Split safetensor files to difference ranks and load them into GPU. @@ -117,11 +112,10 @@ def get_dtensor(self, model, hf_dir, use_groupgemm): # Split safetensor files to difference ranks and load them into GPU safetensor_files = glob.glob(os.path.join(hf_dir, "*.safetensors")) - safetensor_files = self.split_list(safetensor_files, world_size) - local_safetensor_file = safetensor_files[local_rank] + slice_index = even_slice(len(safetensor_files), world_size) local_tensors = {} - for file in local_safetensor_file: - local_tensors.update(load_file(file, device="cuda")) + for file_index in range(slice_index[local_rank], slice_index[local_rank + 1]): + local_tensors.update(load_file(safetensor_files[file_index], device="cuda")) # Create bucket for all_reduce meta_sharded_sd = model.state_dict() @@ -135,25 +129,39 @@ def get_dtensor(self, model, hf_dir, use_groupgemm): shard_sd = {} buffer_offset = 0 param_to_sync = [] + + def update_sharded_sd(bucket, param_to_sync, meta_sharded_sd): + """ + Create sharded_state_dict for params in bucket. 
+ """ + dist.all_reduce(bucket) + get_offset = 0 + return_dict = {} + for param_to_update in param_to_sync: + # Update sharded_sd + meta_info = meta_sharded_sd[param_to_update] + num_params = math.prod(meta_info.shape) + return_dict[param_to_update] = distribute_tensor( + bucket[get_offset:get_offset + num_params].view(meta_info.shape).clone(), + meta_info.device_mesh, + meta_info.placements, + ) + get_offset += num_params + return return_dict + for param_name, meta_param in meta_sharded_sd.items(): if buffer_offset + math.prod(meta_param.shape) > bucket_size: - dist.all_reduce(bucket) - get_offset = 0 - for param_to_update in param_to_sync: - # Update sharded_sd - meta_info = meta_sharded_sd[param_to_update] - num_params = math.prod(meta_info.shape) - shard_sd[param_to_update] = distribute_tensor( - bucket[get_offset:get_offset + num_params].view(meta_info.shape).clone(), - meta_info.device_mesh, - meta_info.placements, - ) - get_offset += num_params + shard_sd.update(update_sharded_sd(bucket, param_to_sync, meta_sharded_sd)) param_to_sync = [] buffer_offset = 0 bucket.fill_(0.0) + # TODO: now weight is forced to bfloat16, try to fix mix-precision hf ckpt if "group_mlp" in param_name: # If groupgemm is enabled, weights of each expert will be load one by one + # Before all_reduce: + # rank0: [expert0, 0, expert2]; rank1: [0, expert1, 0] + # After all_reduce: + # rank0: [expert0, expert1, expert2]; rank1: [expert0, expert1, expert2] num_experts = model.config.num_experts local_offset = buffer_offset num_param_per_expert = math.prod(meta_param.shape) // num_experts @@ -172,18 +180,7 @@ def get_dtensor(self, model, hf_dir, use_groupgemm): param_to_sync.append(param_name) dist.barrier() # Synchronize last bucket - dist.all_reduce(bucket) - get_offset = 0 - for param_to_update in param_to_sync: - meta_info = meta_sharded_sd[param_to_update] - num_params = math.prod(meta_info.shape) - # Update sharded_sd - shard_sd[param_to_update] = distribute_tensor( - bucket[get_offset:get_offset + num_params].view(meta_info.shape).clone(), - meta_info.device_mesh, - meta_info.placements, - ) - get_offset += num_params + shard_sd.update(update_sharded_sd(bucket, param_to_sync, meta_sharded_sd)) dist.barrier() del bucket return shard_sd @@ -355,7 +352,7 @@ def model_setup(self): if isinstance(fsdp_transformer_layer_cls_to_wrap, str): fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap] modules = [] - for _, module in model.named_modules(): + for module in model.modules(): if module.__class__.__name__ in fsdp_transformer_layer_cls_to_wrap or \ (isinstance(module, nn.Embedding) and not model.config.tie_word_embeddings): modules.append(module) @@ -365,7 +362,7 @@ def model_setup(self): fully_shard(model, **fsdp_kwargs) if self.module_args.meta_init: - shard_dict = self.get_dtensor(model, args.load, self.module_args.groupgemm) + shard_dict = self.get_dtensor(model, args.load) model.load_state_dict(shard_dict, assign=True) del shard_dict From d51f38ef0ed17f74d2d2cb9ffded9021ef93b586 Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Mon, 22 Sep 2025 19:24:00 +0800 Subject: [PATCH 11/13] balance bins with strict upbound --- .../algorithm/grpo_utils/packing_utils.py | 72 +++++++++++-------- 1 file changed, 43 insertions(+), 29 deletions(-) diff --git a/chatlearn/algorithm/grpo_utils/packing_utils.py b/chatlearn/algorithm/grpo_utils/packing_utils.py index f6b75934..86305af4 100644 --- a/chatlearn/algorithm/grpo_utils/packing_utils.py +++ b/chatlearn/algorithm/grpo_utils/packing_utils.py @@ -19,6 
+19,7 @@ import torch import numpy as np +from sortedcontainers import SortedList def bin_packing(seq_len_list: List[int], max_train_token: int): """ @@ -53,34 +54,42 @@ def bin_packing(seq_len_list: List[int], max_train_token: int): bins_seqlen[best_bin_index].append(value) return list(bins_id), list(bins_seqlen) -def bin_packing_fix_bin(seq_len_list: List[int], bin_size: int): +def adjust_bins(bins_id, bins_num): """ - Implementation of best fit decreasing bin packing algorithm with fix bin size + Adjust bins to balance total sequence length. + First create a sorted list of (total_seq_len, -min_seq_len, sorted_bin_list). + This will make sure last element in sorted list is bin will largest total_seq_len + and smallest single sample sequence length. + For each round, pop the smallest bin from tail and add the smallest sample to head. + The loop will stop if moving sample will no longer increase balance. """ - seqlen_id_mapping = dict(enumerate(seq_len_list)) - sorted_mapping = dict(sorted(seqlen_id_mapping.items(), key=lambda item: item[1], reverse=True)) - bins_id = [[] for i in range(bin_size)] - bins_seqlen = [[] for i in range(bin_size)] - for key, value in sorted_mapping.items(): - min_sum = None - for id_, bin_ in enumerate(bins_seqlen): - bin_sum = value + sum(bin_) - if min_sum is None: - min_sum = bin_sum - best_bin_index = id_ - else: - if bin_sum < min_sum: - min_sum = bin_sum - best_bin_index = id_ - bins_id[best_bin_index].append(key) - bins_seqlen[best_bin_index].append(value) - # sort bins by seqlen in single bin - bins_seqlen_sum = [sum(bin_seqlen) for bin_seqlen in bins_seqlen] - sorted_bin = sorted(zip(bins_seqlen_sum, bins_id), reverse=True) - sorted_binseq = sorted(zip(bins_seqlen_sum, bins_seqlen), reverse=True) - _, bins_id = zip(*sorted_bin) - _, bins_seqlen = zip(*sorted_binseq) - return list(bins_id), list(bins_seqlen) + sorted_list = SortedList() + for i in range(len(bins_id)): + min_seq = bins_num[i][-1] if len(bins_num[i]) > 0 else 0 + sorted_list.add(( + sum(bins_num[i]), + -min_seq, + SortedList([(num, id_) for id_, num in zip(bins_id[i], bins_num[i])]))) + # Balance sorted_list + stop = False + while not stop: + min_sum, _, min_bin = sorted_list.pop(0) + max_sum, _, max_bin = sorted_list.pop(-1) + smallest_num, smallest_id = max_bin.pop(0) + if abs((max_sum - min_sum - 2 * smallest_num)) < max_sum - min_sum: + min_bin.add((smallest_num, smallest_id)) + sorted_list.add((max_sum - smallest_num, -max_bin[0][0], max_bin)) + sorted_list.add((min_sum + smallest_num, -min_bin[0][0], min_bin)) + else: + stop = True + max_bin.add((smallest_num, smallest_id)) + sorted_list.add((max_sum, -max_bin[0][0], max_bin)) + sorted_list.add((min_sum, -min_bin[0][0], min_bin)) + bins_id = [[item[1] for item in list_[2]] for list_ in sorted_list] + bins_seq = [[item[0] for item in list_[2]] for list_ in sorted_list] + bins_id.reverse() + bins_seq.reverse() + return bins_id, bins_seq def prepare_packing_attn_mask(total_seq_len_list: List[int], pad_size: int, dtype): total_seq_length = sum(total_seq_len_list) + pad_size @@ -128,8 +137,9 @@ def regroup_data_packing( for data_b in data_list ] # Get bin_packing result - bins_id, _ = bin_packing(seq_len_list=total_token_length, max_train_token=max_train_token) - bin_size = torch.tensor(len(bins_id)).cuda() + bins_id, bins_seq = bin_packing(seq_len_list=total_token_length, max_train_token=max_train_token) + local_bin_size = len(bins_id) + bin_size = torch.tensor(local_bin_size).cuda() # Get max_bin_size across all rank in same model 
replica # For megatron, all_reduce along mp group first and emp group second # For FSDP, all_reduce along default group @@ -137,7 +147,11 @@ def regroup_data_packing( process_group_list = [None] for pg in process_group_list: torch.distributed.all_reduce(bin_size, op=torch.distributed.ReduceOp.MAX, group=pg) - bins_id, _ = bin_packing_fix_bin(seq_len_list=total_token_length, bin_size=bin_size.cpu().item()) + max_bin_size = bin_size.cpu().item() + for i in range(max_bin_size - local_bin_size): + bins_id.append([]) + bins_seq.append([]) + bins_id, bins_seq = adjust_bins(bins_id, bins_seq) # Prepare train data for each micro batch for micro_batch_id in bins_id: regroup_data_list.append([]) From dba2eb74e1f21fded37d4811fad562243e183a0d Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Mon, 22 Sep 2025 19:43:41 +0800 Subject: [PATCH 12/13] fix pylint --- chatlearn/algorithm/grpo_utils/packing_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chatlearn/algorithm/grpo_utils/packing_utils.py b/chatlearn/algorithm/grpo_utils/packing_utils.py index 86305af4..13d2ef64 100644 --- a/chatlearn/algorithm/grpo_utils/packing_utils.py +++ b/chatlearn/algorithm/grpo_utils/packing_utils.py @@ -67,8 +67,8 @@ def adjust_bins(bins_id, bins_num): for i in range(len(bins_id)): min_seq = bins_num[i][-1] if len(bins_num[i]) > 0 else 0 sorted_list.add(( - sum(bins_num[i]), - -min_seq, + sum(bins_num[i]), + -min_seq, SortedList([(num, id_) for id_, num in zip(bins_id[i], bins_num[i])]))) # Balance sorted_list stop = False @@ -148,7 +148,7 @@ def regroup_data_packing( for pg in process_group_list: torch.distributed.all_reduce(bin_size, op=torch.distributed.ReduceOp.MAX, group=pg) max_bin_size = bin_size.cpu().item() - for i in range(max_bin_size - local_bin_size): + for _ in range(max_bin_size - local_bin_size): bins_id.append([]) bins_seq.append([]) bins_id, bins_seq = adjust_bins(bins_id, bins_seq) From 29b5410f673ce3dc66f4754dd80f05dcf7a740e1 Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Tue, 30 Sep 2025 11:22:29 +0800 Subject: [PATCH 13/13] fix boundary --- chatlearn/algorithm/grpo_utils/packing_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chatlearn/algorithm/grpo_utils/packing_utils.py b/chatlearn/algorithm/grpo_utils/packing_utils.py index 13d2ef64..34ee9d1d 100644 --- a/chatlearn/algorithm/grpo_utils/packing_utils.py +++ b/chatlearn/algorithm/grpo_utils/packing_utils.py @@ -63,6 +63,8 @@ def adjust_bins(bins_id, bins_num): For each round, pop the smallest bin from tail and add the smallest sample to head. The loop will stop if moving sample will no longer increase balance. """ + if len(bins_id) == 1: + return bins_id, bins_num sorted_list = SortedList() for i in range(len(bins_id)): min_seq = bins_num[i][-1] if len(bins_num[i]) > 0 else 0
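A worked example helps when reviewing the rebalancing that `adjust_bins` performs. Moving a sample of length s from the heaviest bin (total max_sum) to the lightest bin (total min_sum) changes the gap from max_sum - min_sum to |max_sum - min_sum - 2*s|, and the loop keeps moving the heaviest bin's smallest sample only while that strictly shrinks the gap. A hypothetical call with invented ids and lengths (requires sortedcontainers; the module path is taken from the diff, and the expected output is hand-traced from the implementation rather than from a test run):

from chatlearn.algorithm.grpo_utils.packing_utils import adjust_bins

bins_id  = [[0], [1, 2]]   # sample indices per bin
bins_num = [[5], [9, 3]]   # matching sequence lengths; bin totals are 5 and 12 (gap 7)

new_ids, new_lens = adjust_bins(bins_id, bins_num)
# Moving the length-3 sample shrinks the gap to |12 - 5 - 2*3| = 1; any further move
# would widen it again, so the loop stops after a single transfer.
print(new_ids)   # [[1], [2, 0]]
print(new_lens)  # [[9], [3, 5]]  -- totals 9 and 8, heaviest bin first

This matches the stopping rule described in the docstring: stop as soon as one more move would no longer improve balance.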