Skip to content

Commit 63e9036

Browse files
authored
[gdn] support GDN CP (#16)
1 parent 27066a2 commit 63e9036

File tree

1 file changed

+80
-6
lines changed

1 file changed

+80
-6
lines changed

src/mcore_bridge/model/modules/gated_delta_net.py

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,17 @@
2121
_GatedDeltaNet = object
2222

2323

24+
def _unpack_sequence(x, cu_seqlens, dim=1):
25+
unpacked_x = []
26+
num_seqs = cu_seqlens.shape[0] - 1
27+
for i in range(num_seqs):
28+
idx_start = cu_seqlens[i].item()
29+
idx_end = cu_seqlens[i + 1].item()
30+
chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)]
31+
unpacked_x.append(x[tuple(chunked_index)])
32+
return unpacked_x
33+
34+
2435
class GatedDeltaNet(_GatedDeltaNet):
2536

2637
def forward(
@@ -60,7 +71,8 @@ def forward(
6071
inference_context = deprecate_inference_params(inference_context, inference_params)
6172

6273
seq_len, batch, _ = hidden_states.shape
63-
seq_len = seq_len * self.sp_size
74+
cp_size = self.config.context_parallel_size
75+
seq_len = seq_len * self.sp_size * cp_size
6476

6577
if inference_context is not None:
6678
assert (
@@ -75,12 +87,35 @@ def forward(
7587
qkvzba, _ = self.in_proj(hidden_states)
7688
nvtx_range_pop(suffix='in_proj')
7789

90+
if cp_size > 1:
91+
from megatron.core.ssm.gated_delta_net import get_parameter_local_cp, tensor_a2a_cp2hp, tensor_a2a_hp2cp
92+
if cu_seqlens is not None:
93+
unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens // self.cp_size, dim=0)
94+
outputs = []
95+
for qkvzba_i in unpacked_qkvzba:
96+
qkvzba_i = tensor_a2a_cp2hp(
97+
qkvzba_i,
98+
seq_dim=0,
99+
head_dim=-1,
100+
cp_group=self.pg_collection.cp,
101+
)
102+
outputs.append(qkvzba_i)
103+
qkvzba = torch.cat(outputs, dim=0)
104+
else:
105+
# CP All to All: CP to HP
106+
qkvzba = tensor_a2a_cp2hp(
107+
qkvzba,
108+
seq_dim=0,
109+
head_dim=-1,
110+
cp_group=self.pg_collection.cp,
111+
)
112+
78113
# Transpose: s b x --> b s x
79114
# From sbhd to bshd format
80115
qkvzba = qkvzba.transpose(0, 1)
81116

82117
# Split, reorder, and reshape the tensor into q, k, v, gate, beta, alpha
83-
num_key_heads_per_device = self.num_key_heads // self.tp_size
118+
num_key_heads_per_device = self.num_key_heads // self.tp_size // cp_size
84119
qkvzba = qkvzba.view(qkvzba.shape[:-1]
85120
+ (num_key_heads_per_device, qkvzba.shape[-1] // num_key_heads_per_device))
86121
qkv, gate, beta, alpha = torch.split(
@@ -100,17 +135,41 @@ def forward(
100135

101136
# Convolution on qkv
102137
nvtx_range_push(suffix='conv1d')
138+
if cp_size > 1:
139+
conv1d_weight = get_parameter_local_cp(
140+
self.conv1d.weight,
141+
dim=0,
142+
cp_group=self.pg_collection.cp,
143+
)
144+
conv1d_bias = (
145+
get_parameter_local_cp(
146+
self.conv1d.bias,
147+
dim=0,
148+
cp_group=self.pg_collection.cp,
149+
) if self.conv_bias else None)
150+
else:
151+
conv1d_weight = self.conv1d.weight
152+
conv1d_bias = self.conv1d.bias
153+
103154
if (causal_conv1d is None) or self.config.deterministic_mode:
104155
assert cu_seqlens is None, 'Packed sequences are not supported when fla is not available.'
105156
qkv = qkv.transpose(1, 2).contiguous() # b, s, d -> b, d, s
106-
qkv = self.act_fn(self.conv1d(qkv)[..., :seq_len])
157+
conv_out = F.conv1d(
158+
input=qkv,
159+
weight=conv1d_weight,
160+
bias=conv1d_bias,
161+
stride=self.conv1d.stride,
162+
padding=self.conv1d.padding,
163+
dilation=self.conv1d.dilation,
164+
)
165+
qkv = self.act_fn(conv_out[..., :seq_len])
107166
qkv = qkv.transpose(1, 2) # b, d, s -> b, s, d
108167
else:
109168
assert self.activation in ['silu', 'swish']
110169
qkv = causal_conv1d(
111170
x=qkv,
112-
weight=self.conv1d.weight.squeeze(1), # d, 1, w -> d, w
113-
bias=self.conv1d.bias,
171+
weight=conv1d_weight.squeeze(1), # d, 1, w -> d, w
172+
bias=conv1d_bias,
114173
activation=self.activation,
115174
cu_seqlens=cu_seqlens,
116175
)[0]
@@ -143,7 +202,12 @@ def forward(
143202

144203
# Calculate g and beta
145204
nvtx_range_push(suffix='g_and_beta')
146-
g = -self.A_log.exp() * F.softplus(alpha.float() + self.dt_bias) # In fp32
205+
if cp_size > 1:
206+
A_log_local_cp = get_parameter_local_cp(self.A_log, dim=0, cp_group=self.pg_collection.cp)
207+
dt_bias_local_cp = get_parameter_local_cp(self.dt_bias, dim=0, cp_group=self.pg_collection.cp)
208+
else:
209+
A_log_local_cp, dt_bias_local_cp = self.A_log, self.dt_bias
210+
g = -A_log_local_cp.exp() * F.softplus(alpha.float() + dt_bias_local_cp) # In fp32
147211
beta = beta.sigmoid()
148212
nvtx_range_pop(suffix='g_and_beta')
149213

@@ -183,6 +247,16 @@ def forward(
183247
# From bshd back to sbhd format
184248
norm_out = norm_out.reshape(batch, seq_len, -1)
185249
norm_out = norm_out.transpose(0, 1).contiguous()
250+
if cp_size > 1:
251+
if cu_seqlens is not None:
252+
unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens, dim=0)
253+
outputs = []
254+
for norm_out_i in unpacked_norm_out:
255+
norm_out_i = tensor_a2a_hp2cp(norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp)
256+
outputs.append(norm_out_i)
257+
norm_out = torch.cat(outputs, dim=0)
258+
else:
259+
norm_out = tensor_a2a_hp2cp(norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp)
186260

187261
# Output projection
188262
nvtx_range_push(suffix='out_proj')

0 commit comments

Comments
 (0)