33import torch .nn .functional as F
44from megatron .core .inference .contexts import BaseInferenceContext
55from megatron .core .packed_seq_params import PackedSeqParams
6- from typing import Optional
6+ from typing import List , Optional
77
88try :
99 from fla .modules .convolution import causal_conv1d
2121 _GatedDeltaNet = object
2222
2323
24+ # Code borrowed from NVIDIA/Megatron-LM
2425def _unpack_sequence (x , cu_seqlens , dim = 1 ):
2526 unpacked_x = []
2627 num_seqs = cu_seqlens .shape [0 ] - 1
@@ -32,6 +33,53 @@ def _unpack_sequence(x, cu_seqlens, dim=1):
3233 return unpacked_x
3334
3435
# Code borrowed from NVIDIA/Megatron-LM
# Avoid the warning caused by `param[slices]`
def get_parameter_local_cp(
    param: torch.Tensor,
    dim: int,
    cp_group: torch.distributed.ProcessGroup,
    split_sections: Optional[List[int]] = None,
) -> torch.Tensor:
    """Return the shard of ``param`` owned by the current context-parallel rank.

    Args:
        param (torch.Tensor): The full (unsharded) parameter.
        dim (int): Dimension to shard along. Usually the head dimension.
        cp_group (torch.distributed.ProcessGroup): The context-parallel group.
        split_sections (Optional[List[int]]): When given, ``param`` is first
            split into these sections along ``dim``, each section is sharded
            independently for this rank, and the local shards are concatenated
            back together along ``dim``.

    Returns:
        torch.Tensor: The local parameter for the current context parallel rank.
    """

    world = cp_group.size()
    rank = cp_group.rank()

    # Everything is already local when there is only one CP rank.
    if world == 1:
        return param

    # Shard each section independently, then re-join the local pieces.
    if split_sections is not None:
        local_parts = [
            get_parameter_local_cp(section, dim, cp_group)
            for section in torch.split(param, split_sections, dim=dim)
        ]
        return torch.cat(local_parts, dim=dim)

    # Take this rank's contiguous chunk along `dim`. narrow() yields the same
    # view as slicing with slice objects, without building the index tuple.
    total = param.size(dim)
    start = rank * total // world
    stop = (rank + 1) * total // world
    return param.narrow(dim, start, stop - start)
81+
82+
3583class GatedDeltaNet (_GatedDeltaNet ):
3684
3785 def forward (
@@ -88,7 +136,7 @@ def forward(
88136 nvtx_range_pop (suffix = 'in_proj' )
89137
90138 if cp_size > 1 :
91- from megatron .core .ssm .gated_delta_net import get_parameter_local_cp , tensor_a2a_cp2hp , tensor_a2a_hp2cp
139+ from megatron .core .ssm .gated_delta_net import tensor_a2a_cp2hp , tensor_a2a_hp2cp
92140 if cu_seqlens is not None :
93141 unpacked_qkvzba = _unpack_sequence (qkvzba , cu_seqlens // self .cp_size , dim = 0 )
94142 outputs = []
0 commit comments