Commit 502fbeb

[TRITON] Add attention sink support to Triton MHA kernels (ROCm#1576)
* Add attention sink support to forward pass
* Add attention sink forward pass support to benchmark script
* Add attention sink support to backward pass
* Add attention sink backward pass support to benchmark script
* Conditionally relax dv error tolerance on `gfx942`
* Decrease error tolerance for `dsink`
1 parent 3ba2cc1 commit 502fbeb

7 files changed: 527 additions & 57 deletions
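For context: an attention sink is a learnable per-head logit that participates in the softmax normalization but contributes no value vector, letting each head shed probability mass. A minimal PyTorch sketch of the forward semantics this commit implements in Triton (shapes and names here are illustrative assumptions, not this repo's API):

```python
import torch

def sdpa_with_sink(q, k, v, sink):
    """Reference attention with a per-head sink logit.

    q, k, v: (H, L, D); sink: (H,) learnable logit per head.
    """
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5      # (H, L, L)
    sink_col = sink[:, None, None].expand(*scores.shape[:-1], 1)
    p = torch.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)
    return p[..., :-1] @ v                                      # sink column carries no value
```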

aiter/ops/triton/_triton_kernels/mha.py

Lines changed: 10 additions & 1 deletion
```diff
@@ -271,6 +271,7 @@ def _attn_fwd_inner(
         "VARLEN",
         "NUM_XCD",
         "USE_INT64_STRIDES",
+        "ENABLE_SINK",
     ],
 )
```
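`ENABLE_SINK` is added both as a `tl.constexpr` parameter and to the specialization key list above, so each flag value gets its own compiled kernel and the sink path costs nothing when disabled. A standalone sketch of the pattern (toy kernel, not from this repo):

```python
import triton
import triton.language as tl

@triton.jit
def _toy_kernel(x_ptr, bias_ptr, out_ptr, n, BLOCK: tl.constexpr, ENABLE_BIAS: tl.constexpr):
    # ENABLE_BIAS plays the role of ENABLE_SINK: the branch below is resolved
    # at compile time, so the disabled variant contains no load or add at all.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if ENABLE_BIAS:
        x += tl.load(bias_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x, mask=mask)
```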

```diff
@@ -288,6 +289,7 @@ def _attn_fwd(
     s_dmask_ptr: torch.Tensor,
     dropout_mask_ptr: torch.Tensor,
     softmax_lse_ptr: torch.Tensor,
+    sink_ptr: torch.Tensor,
     stride_qz_in,
     stride_qh_in,
     stride_qm_in,
@@ -341,6 +343,7 @@ def _attn_fwd(
     BATCH,
     NUM_XCD: tl.constexpr,
     USE_INT64_STRIDES: tl.constexpr,
+    ENABLE_SINK: tl.constexpr,
 ):
     NUM_BLOCKS = (SEQLEN_Q + BLOCK_M - 1) // BLOCK_M
     # calculate offsets
@@ -631,7 +634,13 @@ def _attn_fwd(
         dropout_mask_ptrs = None
         philox_ptrs = None

-    m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
+    if ENABLE_SINK:
+        RCP_LN2: tl.constexpr = 1.4426950408889634
+        m_i_value = tl.load(sink_ptr + off_q_head).to(tl.float32) * RCP_LN2
+    else:
+        m_i_value = float("-inf")
+
+    m_i = tl.full([BLOCK_M], m_i_value, dtype=tl.float32)
     l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
     acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_POW2], dtype=tl.float32)
     if BLOCK_DMODEL == BLOCK_DMODEL_POW2:
```
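The forward change is compact: the kernel runs its online softmax in base-2 exponent space (hence `RCP_LN2 = 1/ln 2`, since `exp(x) = exp2(x * log2(e))`), and seeding the running max `m_i` with the scaled sink logit while `l_i` starts at 1.0 makes the sink behave like a pre-accumulated key with weight `exp(sink)` in the denominator and no value contribution. With the old `m_i = -inf` seed, that initial `l_i = 1.0` was wiped out by the first rescale factor `exp2(-inf - m_new) = 0`. A scalar sketch of the identity (plain Python, illustrative only):

```python
import math

def online_softmax_denominator(logits, sink=None):
    # Running (max, sum) pair as in flash attention; seeding with (sink, 1.0)
    # reproduces an extra exp(sink) term in the final denominator.
    m, l = (sink, 1.0) if sink is not None else (float("-inf"), 0.0)
    for x in logits:
        m_new = max(m, x)
        l = l * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return l * math.exp(m)  # == sum(exp(logits)) (+ exp(sink) if seeded)

xs = [0.3, -1.2, 2.0]
assert abs(online_softmax_denominator(xs, sink=0.5)
           - (sum(map(math.exp, xs)) + math.exp(0.5))) < 1e-9
```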

aiter/ops/triton/_triton_kernels/mha_onekernel_bwd.py

Lines changed: 38 additions & 12 deletions
```diff
@@ -334,7 +334,7 @@ def _bwd_dq_inner(
     V,
     do,
     m,
-    Delta,
+    Di,  # D (= delta) is pre-divided by ds_scale.
     sm_scale,  # input
     # shared by Q/K/V.
     stride_qm,
@@ -345,7 +345,6 @@ def _bwd_dq_inner(
     stride_vk,
     stride_dropoutm,
     stride_dropoutn,  # stride for dropout
-    stride_deltam,
     seqlen_q,
     seqlen_k,  #
     BLOCK_M2: tl.constexpr,  #
@@ -393,8 +392,6 @@ def _bwd_dq_inner(
     if HAS_PE:
         kT_pe_ptrs = K + offs_n[None, :] * stride_kn + offs_k_pe[:, None] * stride_kk
     vT_ptrs = V + offs_n[None, :] * stride_vn + offs_k[:, None] * stride_vk
-    # D (= delta) is pre-divided by ds_scale.
-    Di = tl.load(Delta + offs_m * stride_deltam, mask=mask_m, other=0.0)
     # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
     tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
     curr_n = start_n
@@ -514,6 +511,7 @@ def _bwd_dq_inner(
         "USE_EXP2",
         "IS_FP8",
         "USE_INT64_STRIDES",
+        "ENABLE_SINK",
     ],
 )
```

```diff
@@ -523,11 +521,13 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
     Q,
     K,
     V,
+    Sink,
     sm_scale,
     DO,
     DQ,
     DK,
     DV,
+    DSink,
     M,
     Delta,
     stride_qb_in,
@@ -603,6 +603,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
     DEBUG_TRITON: tl.constexpr,
     DEBUG_TRITON_DETAIL: tl.constexpr,
     USE_INT64_STRIDES: tl.constexpr,
+    ENABLE_SINK: tl.constexpr,
 ):
     if USE_INT64_STRIDES:
         stride_qb = tl.cast(stride_qb_in, tl.int64)
@@ -1053,8 +1054,20 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
     else:
         q_pe = None
     do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0)
-    m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q)
+    mask_m = offs_m < seqlen_q
+    m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=mask_m, other=0.0)
     m = m[:, None]
+    delta = tl.load(Delta_ptr + offs_m * stride_deltam, mask=mask_m, other=0.0)
+
+    if ENABLE_SINK:
+        sink = tl.load(Sink + hqid).to(tl.float32)
+        if USE_EXP2:
+            RCP_LN2: tl.constexpr = 1.4426950408889634
+            psink = tl.math.exp2(sink * RCP_LN2 - m * RCP_LN2)
+        else:
+            psink = tl.math.exp(sink - m)
+        dsink = tl.sum(-psink * delta[:, None])
+        tl.atomic_add(DSink + hqid, dsink, sem="relaxed")

     MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
     # start can only be 0 at minimum
```
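The `dsink` expression follows from the usual flash-attention backward identity `dlogit = p * (dp - Delta)` with `Delta = rowsum(dO * O)`: the sink is just one more logit whose value row is zero, so its `dp` term vanishes and `dsink = -sum_m p_sink(m) * Delta_m`, where `p_sink(m) = exp(sink - lse_m)` (here `m` holds the per-row log-sum-exp loaded from `M`). Because multiple programs handle rows of the same head, partial sums are combined into `DSink` with a relaxed `tl.atomic_add`. A small autograd cross-check of the formula (illustrative shapes, not this repo's API):

```python
import torch

torch.manual_seed(0)
L, D = 8, 4
q, k, v, do = (torch.randn(L, D) for _ in range(4))
sink = torch.randn((), requires_grad=True)

scores = q @ k.T                                       # (L, L) logits
aug = torch.cat([scores, sink.expand(L, 1)], dim=-1)   # sink as extra logit
p = torch.softmax(aug, dim=-1)
out = p[:, :L] @ v                                     # sink column has no value
(out * do).sum().backward()                            # autograd gradient for sink

with torch.no_grad():                                  # manual formula
    lse = torch.logsumexp(aug, dim=-1)                 # per-row log-sum-exp
    psink = torch.exp(sink - lse)                      # p_sink(m)
    delta = (out * do).sum(-1)                         # rowsum(dO * O) = Delta
    dsink_manual = -(psink * delta).sum()
assert torch.allclose(sink.grad, dsink_manual, atol=1e-6)
```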
```diff
@@ -1083,7 +1096,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
         V,
         do,
         m,
-        Delta_ptr,
+        delta,
         sm_scale,
         stride_qm,
         stride_qd,
@@ -1093,7 +1106,6 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
         stride_vd,
         stride_dropoutm,
         stride_dropoutn,
-        stride_deltam,
         seqlen_q,
         seqlen_k,
         BLOCK_M2,
@@ -1139,7 +1151,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
         V,
         do,
         m,
-        Delta_ptr,
+        delta,
         sm_scale,
         stride_qm,
         stride_qd,
@@ -1149,7 +1161,6 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
         stride_vd,
         stride_dropoutm,
         stride_dropoutn,
-        stride_deltam,
         seqlen_q,
         seqlen_k,
         BLOCK_M2,
@@ -1208,6 +1219,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
         "USE_EXP2",
         "IS_FP8",
         "USE_INT64_STRIDES",
+        "ENABLE_SINK",
     ],
 )
@@ -1217,11 +1229,13 @@ def bwd_kernel_noncausal(
     Q,
     K,
     V,
+    Sink,
     sm_scale,
     DO,
     DQ,
     DK,
     DV,
+    DSink,
     M,
     Delta,
     stride_qb_in,
@@ -1297,6 +1311,7 @@ def bwd_kernel_noncausal(
     DEBUG_TRITON: tl.constexpr,
     DEBUG_TRITON_DETAIL: tl.constexpr,
     USE_INT64_STRIDES: tl.constexpr,
+    ENABLE_SINK: tl.constexpr,
 ):
     if USE_INT64_STRIDES:
         stride_qb = tl.cast(stride_qb_in, tl.int64)
@@ -1613,8 +1628,20 @@ def bwd_kernel_noncausal(
     else:
         q_pe = None
     do = tl.load(DO + adj_do + offs_do, mask=mask_q, other=0.0)
-    m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=offs_m < seqlen_q)
+    mask_m = offs_m < seqlen_q
+    m = tl.load(M + adj_delta + offs_m * stride_deltam, mask=mask_m, other=0.0)
     m = m[:, None]
+    delta = tl.load(Delta_ptr + offs_m * stride_deltam, mask=mask_m, other=0.0)
+
+    if ENABLE_SINK:
+        sink = tl.load(Sink + hqid).to(tl.float32)
+        if USE_EXP2:
+            RCP_LN2: tl.constexpr = 1.4426950408889634
+            psink = tl.math.exp2(sink * RCP_LN2 - m * RCP_LN2)
+        else:
+            psink = tl.math.exp(sink - m)
+        dsink = tl.sum(-psink * delta[:, None])
+        tl.atomic_add(DSink + hqid, dsink, sem="relaxed")

     if IS_FP8:
         descale_q = tl.load(Descale_q + bid * stride_descale_q_z + hqid)
@@ -1643,7 +1670,7 @@ def bwd_kernel_noncausal(
         V,
         do,
         m,
-        Delta_ptr,
+        delta,
         sm_scale,
         stride_qm,
         stride_qd,
@@ -1653,7 +1680,6 @@ def bwd_kernel_noncausal(
         stride_vd,
         stride_dropoutm,
         stride_dropoutn,
-        stride_deltam,
         seqlen_q,
         seqlen_k,
         BLOCK_M2,
```
