# Code borrowed from NVIDIA/Megatron-LM
# Avoid the warning caused by `param[slices]`
def get_parameter_local_cp(
    param: torch.Tensor,
    dim: int,
    cp_group: torch.distributed.ProcessGroup,
    split_sections: Optional[List[int]] = None,
) -> torch.Tensor:
    """Get the local parameter for the current context parallel rank.

    Args:
        param (torch.Tensor): The entire parameter to get the local parameter for.
        dim (int): The dimension to split the parameter along. Usually the dimension of head.
        cp_group (torch.distributed.ProcessGroup): The context parallel group.
        split_sections (Optional[List[int]]): If not None, first split the parameter
            along dimension ``dim`` into these sections, take each section's local
            CP shard separately, then concatenate the local shards along ``dim``.

    Returns:
        torch.Tensor: The local parameter for the current context parallel rank.
    """
    cp_size = cp_group.size()
    cp_rank = cp_group.rank()

    # No need to split if CP size is 1.
    if cp_size == 1:
        return param

    # Heterogeneous layout: shard each section independently, then re-join so the
    # result keeps the section order along `dim`.
    if split_sections is not None:
        sections = torch.split(param, split_sections, dim=dim)
        return torch.cat(
            [get_parameter_local_cp(section, dim, cp_group) for section in sections],
            dim=dim,
        )

    # Take this rank's contiguous shard along `dim`. Boundaries use floor division,
    # so when dim_size is not divisible by cp_size the later ranks absorb the
    # remainder. Indexing with a *tuple* of slices (not a list) avoids the
    # advanced-indexing deprecation warning.
    dim_size = param.size(dim=dim)
    slices = [slice(None)] * param.dim()
    slices[dim] = slice(cp_rank * dim_size // cp_size, (cp_rank + 1) * dim_size // cp_size)
    return param[tuple(slices)]