Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions src/mcore_bridge/model/modules/gated_delta_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.nn.functional as F
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams
from typing import Optional
from typing import List, Optional

try:
from fla.modules.convolution import causal_conv1d
Expand All @@ -21,6 +21,7 @@
_GatedDeltaNet = object


# Code borrowed from NVIDIA/Megatron-LM
def _unpack_sequence(x, cu_seqlens, dim=1):
unpacked_x = []
num_seqs = cu_seqlens.shape[0] - 1
Expand All @@ -32,6 +33,53 @@ def _unpack_sequence(x, cu_seqlens, dim=1):
return unpacked_x


# Code borrowed from NVIDIA/Megatron-LM
# Avoid the warning caused by `param[slices]`
def get_parameter_local_cp(
    param: torch.Tensor,
    dim: int,
    cp_group: torch.distributed.ProcessGroup,
    split_sections: Optional[List[int]] = None,
) -> torch.Tensor:
    """Extract this context-parallel rank's shard of a full parameter tensor.

    Args:
        param (torch.Tensor): The full (unsharded) parameter tensor.
        dim (int): Dimension along which the parameter is sharded
            (usually the head dimension).
        cp_group (torch.distributed.ProcessGroup): The context parallel group.
        split_sections (Optional[List[int]]): When given, ``param`` is first
            split into these sections along ``dim``, each section is sharded
            independently for the current rank, and the per-section shards
            are concatenated back along ``dim``.

    Returns:
        torch.Tensor: The shard of ``param`` owned by the current CP rank.
    """

    world = cp_group.size()
    rank = cp_group.rank()

    # With a single CP rank the whole parameter is already local.
    if world == 1:
        return param

    # Sectioned case: shard each section independently, then stitch the
    # per-section shards back together along the same dimension.
    if split_sections is not None:
        shards = [
            get_parameter_local_cp(section, dim, cp_group)
            for section in torch.split(param, split_sections, dim=dim)
        ]
        return torch.cat(shards, dim=dim)

    # Plain case: take this rank's contiguous slice along `dim`. The
    # floor-division bounds keep the shards balanced even when the size
    # is not evenly divisible by the CP world size.
    total = param.size(dim=dim)
    start = rank * total // world
    stop = (rank + 1) * total // world
    index = [slice(None)] * param.dim()
    index[dim] = slice(start, stop)
    return param[tuple(index)]


class GatedDeltaNet(_GatedDeltaNet):

def forward(
Expand Down Expand Up @@ -88,7 +136,7 @@ def forward(
nvtx_range_pop(suffix='in_proj')

if cp_size > 1:
from megatron.core.ssm.gated_delta_net import get_parameter_local_cp, tensor_a2a_cp2hp, tensor_a2a_hp2cp
from megatron.core.ssm.gated_delta_net import tensor_a2a_cp2hp, tensor_a2a_hp2cp
if cu_seqlens is not None:
unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens // self.cp_size, dim=0)
outputs = []
Expand Down
Loading