Skip to content

Commit 4636e98

Browse files
authored
update get_parameter_local_cp (#22)
1 parent 6756621 commit 4636e98

1 file changed

Lines changed: 50 additions & 2 deletions

File tree

src/mcore_bridge/model/modules/gated_delta_net.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch.nn.functional as F
44
from megatron.core.inference.contexts import BaseInferenceContext
55
from megatron.core.packed_seq_params import PackedSeqParams
6-
from typing import Optional
6+
from typing import List, Optional
77

88
try:
99
from fla.modules.convolution import causal_conv1d
@@ -21,6 +21,7 @@
2121
_GatedDeltaNet = object
2222

2323

24+
# Code borrowed from NVIDIA/Megatron-LM
2425
def _unpack_sequence(x, cu_seqlens, dim=1):
2526
unpacked_x = []
2627
num_seqs = cu_seqlens.shape[0] - 1
@@ -32,6 +33,53 @@ def _unpack_sequence(x, cu_seqlens, dim=1):
3233
return unpacked_x
3334

3435

36+
# Code borrowed from NVIDIA/Megatron-LM
# Avoid the warning caused by `param[slices]`
def get_parameter_local_cp(
    param: torch.Tensor,
    dim: int,
    cp_group: torch.distributed.ProcessGroup,
    split_sections: Optional[List[int]] = None,
) -> torch.Tensor:
    """Return this context-parallel rank's shard of ``param``.

    Args:
        param (torch.Tensor): The full parameter tensor to shard.
        dim (int): Dimension along which the parameter is sharded,
            usually the head dimension.
        cp_group (torch.distributed.ProcessGroup): The context parallel group.
        split_sections (Optional[List[int]]): When given, ``param`` is first
            split into sections of these sizes along ``dim``; each section is
            sharded independently and the per-rank pieces are concatenated
            back along ``dim``.

    Returns:
        torch.Tensor: The local shard of ``param`` for the current CP rank.
    """

    world = cp_group.size()
    rank = cp_group.rank()

    # A single CP rank already owns the whole parameter.
    if world == 1:
        return param

    # Shard each requested section separately, then stitch the local
    # pieces back together along `dim`.
    if split_sections is not None:
        shards = [
            get_parameter_local_cp(section, dim, cp_group)
            for section in torch.split(param, split_sections, dim=dim)
        ]
        return torch.cat(shards, dim=dim)

    # Narrow to this rank's contiguous range along `dim` via an explicit
    # slice tuple (plain indexing, no advanced-indexing warning).
    total = param.size(dim=dim)
    start = rank * total // world
    stop = (rank + 1) * total // world
    index = [slice(None)] * param.dim()
    index[dim] = slice(start, stop)
    return param[tuple(index)]
81+
82+
3583
class GatedDeltaNet(_GatedDeltaNet):
3684

3785
def forward(
@@ -88,7 +136,7 @@ def forward(
88136
nvtx_range_pop(suffix='in_proj')
89137

90138
if cp_size > 1:
91-
from megatron.core.ssm.gated_delta_net import get_parameter_local_cp, tensor_a2a_cp2hp, tensor_a2a_hp2cp
139+
from megatron.core.ssm.gated_delta_net import tensor_a2a_cp2hp, tensor_a2a_hp2cp
92140
if cu_seqlens is not None:
93141
unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens // self.cp_size, dim=0)
94142
outputs = []

0 commit comments

Comments
 (0)