2121 _GatedDeltaNet = object
2222
2323
24+ def _unpack_sequence (x , cu_seqlens , dim = 1 ):
25+ unpacked_x = []
26+ num_seqs = cu_seqlens .shape [0 ] - 1
27+ for i in range (num_seqs ):
28+ idx_start = cu_seqlens [i ].item ()
29+ idx_end = cu_seqlens [i + 1 ].item ()
30+ chunked_index = [slice (None )] * dim + [slice (idx_start , idx_end )]
31+ unpacked_x .append (x [tuple (chunked_index )])
32+ return unpacked_x
33+
34+
2435class GatedDeltaNet (_GatedDeltaNet ):
2536
2637 def forward (
@@ -60,7 +71,8 @@ def forward(
6071 inference_context = deprecate_inference_params (inference_context , inference_params )
6172
6273 seq_len , batch , _ = hidden_states .shape
63- seq_len = seq_len * self .sp_size
74+ cp_size = self .config .context_parallel_size
75+ seq_len = seq_len * self .sp_size * cp_size
6476
6577 if inference_context is not None :
6678 assert (
@@ -75,12 +87,35 @@ def forward(
7587 qkvzba , _ = self .in_proj (hidden_states )
7688 nvtx_range_pop (suffix = 'in_proj' )
7789
90+ if cp_size > 1 :
91+ from megatron .core .ssm .gated_delta_net import get_parameter_local_cp , tensor_a2a_cp2hp , tensor_a2a_hp2cp
92+ if cu_seqlens is not None :
93+ unpacked_qkvzba = _unpack_sequence (qkvzba , cu_seqlens // self .cp_size , dim = 0 )
94+ outputs = []
95+ for qkvzba_i in unpacked_qkvzba :
96+ qkvzba_i = tensor_a2a_cp2hp (
97+ qkvzba_i ,
98+ seq_dim = 0 ,
99+ head_dim = - 1 ,
100+ cp_group = self .pg_collection .cp ,
101+ )
102+ outputs .append (qkvzba_i )
103+ qkvzba = torch .cat (outputs , dim = 0 )
104+ else :
105+ # CP All to All: CP to HP
106+ qkvzba = tensor_a2a_cp2hp (
107+ qkvzba ,
108+ seq_dim = 0 ,
109+ head_dim = - 1 ,
110+ cp_group = self .pg_collection .cp ,
111+ )
112+
78113 # Transpose: s b x --> b s x
79114 # From sbhd to bshd format
80115 qkvzba = qkvzba .transpose (0 , 1 )
81116
82117 # Split, reorder, and reshape the tensor into q, k, v, gate, beta, alpha
83- num_key_heads_per_device = self .num_key_heads // self .tp_size
118+ num_key_heads_per_device = self .num_key_heads // self .tp_size // cp_size
84119 qkvzba = qkvzba .view (qkvzba .shape [:- 1 ]
85120 + (num_key_heads_per_device , qkvzba .shape [- 1 ] // num_key_heads_per_device ))
86121 qkv , gate , beta , alpha = torch .split (
@@ -100,17 +135,41 @@ def forward(
100135
101136 # Convolution on qkv
102137 nvtx_range_push (suffix = 'conv1d' )
138+ if cp_size > 1 :
139+ conv1d_weight = get_parameter_local_cp (
140+ self .conv1d .weight ,
141+ dim = 0 ,
142+ cp_group = self .pg_collection .cp ,
143+ )
144+ conv1d_bias = (
145+ get_parameter_local_cp (
146+ self .conv1d .bias ,
147+ dim = 0 ,
148+ cp_group = self .pg_collection .cp ,
149+ ) if self .conv_bias else None )
150+ else :
151+ conv1d_weight = self .conv1d .weight
152+ conv1d_bias = self .conv1d .bias
153+
103154 if (causal_conv1d is None ) or self .config .deterministic_mode :
104155 assert cu_seqlens is None , 'Packed sequences are not supported when fla is not available.'
105156 qkv = qkv .transpose (1 , 2 ).contiguous () # b, s, d -> b, d, s
106- qkv = self .act_fn (self .conv1d (qkv )[..., :seq_len ])
157+ conv_out = F .conv1d (
158+ input = qkv ,
159+ weight = conv1d_weight ,
160+ bias = conv1d_bias ,
161+ stride = self .conv1d .stride ,
162+ padding = self .conv1d .padding ,
163+ dilation = self .conv1d .dilation ,
164+ )
165+ qkv = self .act_fn (conv_out [..., :seq_len ])
107166 qkv = qkv .transpose (1 , 2 ) # b, d, s -> b, s, d
108167 else :
109168 assert self .activation in ['silu' , 'swish' ]
110169 qkv = causal_conv1d (
111170 x = qkv ,
112- weight = self . conv1d . weight .squeeze (1 ), # d, 1, w -> d, w
113- bias = self . conv1d . bias ,
171+ weight = conv1d_weight .squeeze (1 ), # d, 1, w -> d, w
172+ bias = conv1d_bias ,
114173 activation = self .activation ,
115174 cu_seqlens = cu_seqlens ,
116175 )[0 ]
@@ -143,7 +202,12 @@ def forward(
143202
144203 # Calculate g and beta
145204 nvtx_range_push (suffix = 'g_and_beta' )
146- g = - self .A_log .exp () * F .softplus (alpha .float () + self .dt_bias ) # In fp32
205+ if cp_size > 1 :
206+ A_log_local_cp = get_parameter_local_cp (self .A_log , dim = 0 , cp_group = self .pg_collection .cp )
207+ dt_bias_local_cp = get_parameter_local_cp (self .dt_bias , dim = 0 , cp_group = self .pg_collection .cp )
208+ else :
209+ A_log_local_cp , dt_bias_local_cp = self .A_log , self .dt_bias
210+ g = - A_log_local_cp .exp () * F .softplus (alpha .float () + dt_bias_local_cp ) # In fp32
147211 beta = beta .sigmoid ()
148212 nvtx_range_pop (suffix = 'g_and_beta' )
149213
@@ -183,6 +247,16 @@ def forward(
183247 # From bshd back to sbhd format
184248 norm_out = norm_out .reshape (batch , seq_len , - 1 )
185249 norm_out = norm_out .transpose (0 , 1 ).contiguous ()
250+ if cp_size > 1 :
251+ if cu_seqlens is not None :
252+ unpacked_norm_out = _unpack_sequence (norm_out , cu_seqlens , dim = 0 )
253+ outputs = []
254+ for norm_out_i in unpacked_norm_out :
255+ norm_out_i = tensor_a2a_hp2cp (norm_out_i , seq_dim = 0 , head_dim = - 1 , cp_group = self .pg_collection .cp )
256+ outputs .append (norm_out_i )
257+ norm_out = torch .cat (outputs , dim = 0 )
258+ else :
259+ norm_out = tensor_a2a_hp2cp (norm_out , seq_dim = 0 , head_dim = - 1 , cp_group = self .pg_collection .cp )
186260
187261 # Output projection
188262 nvtx_range_push (suffix = 'out_proj' )
0 commit comments