@@ -334,7 +334,7 @@ def _bwd_dq_inner(
334334 V ,
335335 do ,
336336 m ,
337- Delta ,
337+ Di , # D (= delta) is pre-divided by ds_scale.
338338 sm_scale , # input
339339 # shared by Q/K/V.
340340 stride_qm ,
@@ -345,7 +345,6 @@ def _bwd_dq_inner(
345345 stride_vk ,
346346 stride_dropoutm ,
347347 stride_dropoutn , # stride for dropout
348- stride_deltam ,
349348 seqlen_q ,
350349 seqlen_k , #
351350 BLOCK_M2 : tl .constexpr , #
@@ -393,8 +392,6 @@ def _bwd_dq_inner(
393392 if HAS_PE :
394393 kT_pe_ptrs = K + offs_n [None , :] * stride_kn + offs_k_pe [:, None ] * stride_kk
395394 vT_ptrs = V + offs_n [None , :] * stride_vn + offs_k [:, None ] * stride_vk
396- # D (= delta) is pre-divided by ds_scale.
397- Di = tl .load (Delta + offs_m * stride_deltam , mask = mask_m , other = 0.0 )
398395 # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
399396 tl .static_assert (BLOCK_M2 % BLOCK_N2 == 0 )
400397 curr_n = start_n
@@ -514,6 +511,7 @@ def _bwd_dq_inner(
514511 "USE_EXP2" ,
515512 "IS_FP8" ,
516513 "USE_INT64_STRIDES" ,
514+ "ENABLE_SINK" ,
517515 ],
518516)
519517
@@ -523,11 +521,13 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
523521 Q ,
524522 K ,
525523 V ,
524+ Sink ,
526525 sm_scale ,
527526 DO ,
528527 DQ ,
529528 DK ,
530529 DV ,
530+ DSink ,
531531 M ,
532532 Delta ,
533533 stride_qb_in ,
@@ -603,6 +603,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
603603 DEBUG_TRITON : tl .constexpr ,
604604 DEBUG_TRITON_DETAIL : tl .constexpr ,
605605 USE_INT64_STRIDES : tl .constexpr ,
606+ ENABLE_SINK : tl .constexpr ,
606607):
607608 if USE_INT64_STRIDES :
608609 stride_qb = tl .cast (stride_qb_in , tl .int64 )
@@ -1053,8 +1054,20 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
10531054 else :
10541055 q_pe = None
10551056 do = tl .load (DO + adj_do + offs_do , mask = mask_q , other = 0.0 )
1056- m = tl .load (M + adj_delta + offs_m * stride_deltam , mask = offs_m < seqlen_q )
1057+ mask_m = offs_m < seqlen_q
1058+ m = tl .load (M + adj_delta + offs_m * stride_deltam , mask = mask_m , other = 0.0 )
10571059 m = m [:, None ]
1060+ delta = tl .load (Delta_ptr + offs_m * stride_deltam , mask = mask_m , other = 0.0 )
1061+
1062+ if ENABLE_SINK :
1063+ sink = tl .load (Sink + hqid ).to (tl .float32 )
1064+ if USE_EXP2 :
1065+ RCP_LN2 : tl .constexpr = 1.4426950408889634
1066+ psink = tl .math .exp2 (sink * RCP_LN2 - m * RCP_LN2 )
1067+ else :
1068+ psink = tl .math .exp (sink - m )
1069+ dsink = tl .sum (- psink * delta [:, None ])
1070+ tl .atomic_add (DSink + hqid , dsink , sem = "relaxed" )
10581071
10591072 MASK_BLOCK_N2 : tl .constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
10601073 # start can only be 0 at minimum
@@ -1083,7 +1096,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
10831096 V ,
10841097 do ,
10851098 m ,
1086- Delta_ptr ,
1099+ delta ,
10871100 sm_scale ,
10881101 stride_qm ,
10891102 stride_qd ,
@@ -1093,7 +1106,6 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
10931106 stride_vd ,
10941107 stride_dropoutm ,
10951108 stride_dropoutn ,
1096- stride_deltam ,
10971109 seqlen_q ,
10981110 seqlen_k ,
10991111 BLOCK_M2 ,
@@ -1139,7 +1151,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
11391151 V ,
11401152 do ,
11411153 m ,
1142- Delta_ptr ,
1154+ delta ,
11431155 sm_scale ,
11441156 stride_qm ,
11451157 stride_qd ,
@@ -1149,7 +1161,6 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
11491161 stride_vd ,
11501162 stride_dropoutm ,
11511163 stride_dropoutn ,
1152- stride_deltam ,
11531164 seqlen_q ,
11541165 seqlen_k ,
11551166 BLOCK_M2 ,
@@ -1208,6 +1219,7 @@ def bwd_kernel_causal( # grid = (tl.cdiv(max_seqlen_q // BLOCK_M2), batch, nhea
12081219 "USE_EXP2" ,
12091220 "IS_FP8" ,
12101221 "USE_INT64_STRIDES" ,
1222+ "ENABLE_SINK" ,
12111223 ],
12121224)
12131225
@@ -1217,11 +1229,13 @@ def bwd_kernel_noncausal(
12171229 Q ,
12181230 K ,
12191231 V ,
1232+ Sink ,
12201233 sm_scale ,
12211234 DO ,
12221235 DQ ,
12231236 DK ,
12241237 DV ,
1238+ DSink ,
12251239 M ,
12261240 Delta ,
12271241 stride_qb_in ,
@@ -1297,6 +1311,7 @@ def bwd_kernel_noncausal(
12971311 DEBUG_TRITON : tl .constexpr ,
12981312 DEBUG_TRITON_DETAIL : tl .constexpr ,
12991313 USE_INT64_STRIDES : tl .constexpr ,
1314+ ENABLE_SINK : tl .constexpr ,
13001315):
13011316 if USE_INT64_STRIDES :
13021317 stride_qb = tl .cast (stride_qb_in , tl .int64 )
@@ -1613,8 +1628,20 @@ def bwd_kernel_noncausal(
16131628 else :
16141629 q_pe = None
16151630 do = tl .load (DO + adj_do + offs_do , mask = mask_q , other = 0.0 )
1616- m = tl .load (M + adj_delta + offs_m * stride_deltam , mask = offs_m < seqlen_q )
1631+ mask_m = offs_m < seqlen_q
1632+ m = tl .load (M + adj_delta + offs_m * stride_deltam , mask = mask_m , other = 0.0 )
16171633 m = m [:, None ]
1634+ delta = tl .load (Delta_ptr + offs_m * stride_deltam , mask = mask_m , other = 0.0 )
1635+
1636+ if ENABLE_SINK :
1637+ sink = tl .load (Sink + hqid ).to (tl .float32 )
1638+ if USE_EXP2 :
1639+ RCP_LN2 : tl .constexpr = 1.4426950408889634
1640+ psink = tl .math .exp2 (sink * RCP_LN2 - m * RCP_LN2 )
1641+ else :
1642+ psink = tl .math .exp (sink - m )
1643+ dsink = tl .sum (- psink * delta [:, None ])
1644+ tl .atomic_add (DSink + hqid , dsink , sem = "relaxed" )
16181645
16191646 if IS_FP8 :
16201647 descale_q = tl .load (Descale_q + bid * stride_descale_q_z + hqid )
@@ -1643,7 +1670,7 @@ def bwd_kernel_noncausal(
16431670 V ,
16441671 do ,
16451672 m ,
1646- Delta_ptr ,
1673+ delta ,
16471674 sm_scale ,
16481675 stride_qm ,
16491676 stride_qd ,
@@ -1653,7 +1680,6 @@ def bwd_kernel_noncausal(
16531680 stride_vd ,
16541681 stride_dropoutm ,
16551682 stride_dropoutn ,
1656- stride_deltam ,
16571683 seqlen_q ,
16581684 seqlen_k ,
16591685 BLOCK_M2 ,
0 commit comments