Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/aiter
Submodule aiter updated 1167 files
6 changes: 6 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,12 @@ Note that when using `THD` format tensors with CK Fused Attention, one should pa
to indicate that there is no padding between sequences. Otherwise, passing proper tensors will indicate padding between sequences. This is the case
for both the `FusedAttention` and `DotProductAttention` modules.

Certain settings can be enabled to potentially optimize workloads depending on the nature of the inputs and expected outputs:

* NVTE_CK_RUNTIME_NUM_SEGMENTS - by default 0, if set to 1 then the JAX integration will calculate the number of segments at runtime. Enabling this requires also disabling the GPU graph by setting `XLA_FLAGS="--xla_gpu_graph_level=0"`.
* NVTE_CK_RUNTIME_MAX_SEQLEN - by default 0, if set to 1 then the max sequence length will be calculated at runtime. This can result in speedups in cases where there are many zero-length sequences. Enabling this while using the JAX integration requires also disabling the GPU graph by setting `XLA_FLAGS="--xla_gpu_graph_level=0"`.
* NVTE_CK_ZERO_OUT_PAD - by default 1, if set to 0 then the output of the FA forward pass will not be initialized to zero, meaning invalid regions (representing padding) may take nonzero values. Only used if input has padding.

AITER FA v3 Kernels
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ROCm TE supports flash-attention v3 fwd/bwd kernels on gfx942 and gfx950 using AITER backend.
Expand Down
24 changes: 24 additions & 0 deletions tests/pytorch/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,30 @@ def reset_attn_backend():
param_types.append(torch.bfloat16)
param_types_lean = [torch.bfloat16]

# TODO: Enable config support in other backend(s) -- currently only the CK
# backend is capable of supporting it.
@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.")
def test_gqa_mla_thd():
"""
Explicitly test dk_or_dv_reduce_thd as part of TE's CK integration
post-processing for BWD FA with native padding support.
"""
# b, sq, h, dqk
config = ModelConfig(8, 128, 16, 128, num_gqa_groups= 4, head_dim_v=64, attn_mask_type="padding")
qkv_layout = "thd_thd_thd"
dtype = torch.float16
_, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=True,
)
if FusedAttnBackend["CK"] not in fused_attn_backends:
pytest.skip("This test requires the CK fused attention backend.")

test_dot_product_attention(dtype, {"layout_1": config}, "layout_1", False, False, qkv_layout, False, True, False)

@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.")
def test_dot_product_mem_calc():
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ hipError_t ck_attn_fwd(
uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o,
void* lse_ptr,
bool uses_fwd_v3,
int how_v3_bf16_cvt,
hipStream_t stream);

hipError_t ck_attn_varlen_fwd(
Expand All @@ -72,6 +73,7 @@ hipError_t ck_attn_varlen_fwd(
const void* v_ptr,
uint64_t stride_h_v, uint64_t stride_s_v,
const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr,
const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr,
bool is_training,
float scaling_factor,
float dropout_probability,
Expand All @@ -82,6 +84,7 @@ hipError_t ck_attn_varlen_fwd(
uint64_t stride_h_o, uint64_t stride_s_o,
void* lse_thd_ptr,
bool uses_fwd_v3,
int how_v3_bf16_cvt,
hipStream_t stream);

hipError_t ck_attn_bwd(
Expand Down Expand Up @@ -137,6 +140,7 @@ hipError_t ck_attn_varlen_bwd(
const void* v_ptr,
uint64_t stride_h_v, uint64_t stride_s_v,
const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr,
const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr,
const void* o_ptr,
uint64_t stride_h_o, uint64_t stride_s_o,
const void* lse_thd_ptr,
Expand Down
132 changes: 99 additions & 33 deletions transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,24 @@

namespace ck_fused_attn{

// TODO: unify with binary search in TE/common/fused_attn(rocm)/util
// no device std::upper_bound
// in an increasing array with given size len, search for the index that:
// array[index] <= target < array[index+1]
// guaranteed that target >=0 and target <= cu_seqlen[end-1]
__forceinline__ __device__ int binary_search(int32_t target, const int32_t *array, uint64_t len) {
int left = 1, right = len - 1;
while (left < right) {
int mid = (left + right) / 2;
if (array[mid] <= target) {
left = mid + 1;
} else {
right = mid;
}
}
return left - 1;
}

// define dk_dv_reduce function only for fp16 and bf16 types
template<typename DataType>
__global__ void dk_dv_reduce(
Expand Down Expand Up @@ -109,8 +127,9 @@ __global__ void dk_or_dv_reduce(
// define dk_dv_reduce function in THD layout only for fp16 and bf16 types
template<typename DataType>
__global__ void dk_dv_reduce_thd(
uint64_t h, uint64_t hg, uint64_t d,
const int32_t* total_seqlen_kv_ptr,
uint64_t b, uint64_t h, uint64_t hg, uint64_t d,
const int32_t* cu_seqlen_kv_ptr,
const int32_t* cu_seqlen_kv_padded_ptr,
const DataType *dk_expanded,
const DataType *dv_expanded,
uint64_t stride_h_dkv_expanded, uint64_t stride_s_dkv_expanded,
Expand All @@ -124,11 +143,17 @@ __global__ void dk_dv_reduce_thd(
uint64_t hdim_idx = threadIdx.x;

assert(hdim_idx<d);

if(seqlen_idx >= *total_seqlen_kv_ptr){
if(seqlen_idx >= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){
return;
}

if(cu_seqlen_kv_padded_ptr){
uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1);
uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx];
if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){
return;
}
}
// h guaranteed to be multiples of hg
uint64_t head_idx_offset = h / hg;

Expand Down Expand Up @@ -164,8 +189,9 @@ __global__ void dk_dv_reduce_thd(
// When d_qk != d_v, we need to reduce dk and dv separately
template<typename DataType>
__global__ void dk_or_dv_reduce_thd(
uint64_t h, uint64_t hg, uint64_t d,
const int32_t* total_seqlen_kv_ptr,
uint64_t b, uint64_t h, uint64_t hg, uint64_t d,
const int32_t* cu_seqlen_kv_ptr,
const int32_t* cu_seqlen_kv_padded_ptr,
const DataType *dk_or_dv_expanded,
uint64_t stride_h_dk_or_dv_expanded, uint64_t stride_s_dk_or_dv_expanded,
DataType *dk_or_dv,
Expand All @@ -178,10 +204,16 @@ __global__ void dk_or_dv_reduce_thd(

assert(hdim_idx<d);

if(seqlen_idx >= *total_seqlen_kv_ptr){
if(seqlen_idx >= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){
return;
}

if(cu_seqlen_kv_padded_ptr){
uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1);
uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx];
if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){
return;
}
}
// h guaranteed to be multiples of hg
uint64_t head_idx_offset = h / hg;

Expand Down Expand Up @@ -323,7 +355,7 @@ void log_bwd_config(const char* func_name,
std::cout<<std::endl<<func_name<<std::endl;

// fmha_traits debug
std::cout<<"fmha_traits: "<<std::endl;
std::cout<<std::endl<<"fmha_traits: "<<std::endl;
std::cout<<"hdim_q: "<<fmha_args.hdim_q<<std::endl;
std::cout<<"hdim_v: "<<fmha_args.hdim_v<<std::endl;
std::cout<<"data_type: "<<data_type_str<<std::endl;
Expand All @@ -339,7 +371,7 @@ void log_bwd_config(const char* func_name,
std::cout<<"how_v3_bf16_cvt: "<<how_v3_bf16_cvt<<std::endl;

// fmha_args debug
std::cout<<"fmha_args: "<<std::endl;
std::cout<<std::endl<<"fmha_args: "<<std::endl;
std::cout<<"q_ptr: "<<fmha_args.q_ptr<<std::endl;
std::cout<<"k_ptr: "<<fmha_args.k_ptr<<std::endl;
std::cout<<"v_ptr: "<<fmha_args.v_ptr<<std::endl;
Expand All @@ -353,9 +385,15 @@ void log_bwd_config(const char* func_name,
std::cout<<"dk_ptr: "<<fmha_args.dk_ptr<<std::endl;
std::cout<<"dv_ptr: "<<fmha_args.dv_ptr<<std::endl;
std::cout<<"dbias_ptr: "<<fmha_args.dbias_ptr<<std::endl;
std::cout<<"dq_acc_ptr: "<<fmha_args.dq_acc_ptr<<std::endl;

std::cout<<"seqstart_q_ptr: "<<fmha_args.seqstart_q_ptr<<std::endl;
std::cout<<"seqstart_k_ptr: "<<fmha_args.seqstart_k_ptr<<std::endl;
std::cout<<"seqlen_q_ptr: "<<fmha_args.seqlen_q_ptr<<std::endl;
std::cout<<"seqlen_k_ptr: "<<fmha_args.seqlen_k_ptr<<std::endl;
std::cout<<"cu_seqlen_q_ptr: "<<fmha_args.cu_seqlen_q_ptr<<std::endl;
std::cout<<"cu_seqlen_k_ptr: "<<fmha_args.cu_seqlen_k_ptr<<std::endl;

std::cout<<"seqlen_q: "<<fmha_args.seqlen_q<<std::endl;
std::cout<<"seqlen_k: "<<fmha_args.seqlen_k<<std::endl;
std::cout<<"batch: "<<fmha_args.batch<<std::endl;
Expand Down Expand Up @@ -572,9 +610,12 @@ hipError_t ck_attn_bwd(
is_mqa_gqa? dv_expanded_ptr:dv_ptr,
has_dbias? (bias_shape==BiasShape::kBHSS ? dbias_ptr: dbias_expanded_ptr): nullptr,
dq_acc_ptr, //dq_acc_buf
nullptr,//cu_seqlen_q
nullptr,//cu_seqlen_kv
nullptr,//seqstart_q_ptr
nullptr,//seqstart_k_ptr
nullptr, /* seqlen_q_ptr */
nullptr, /* seqlen_k_ptr */
nullptr, //cu_seqlen_q_ptr
nullptr, //cu_seqlen_k_ptr
shape_seqlen_q,
shape_seqlen_k,
batch,
Expand Down Expand Up @@ -784,6 +825,7 @@ hipError_t ck_attn_varlen_bwd(
const void* v_ptr,
uint64_t stride_h_v, uint64_t stride_s_v,
const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr,
const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr,
const void* o_ptr,
uint64_t stride_h_o, uint64_t stride_s_o,
const void* lse_thd_ptr,
Expand Down Expand Up @@ -915,11 +957,14 @@ hipError_t ck_attn_varlen_bwd(
dq_ptr,
is_mqa_gqa? dk_expanded_ptr:dk_ptr,
is_mqa_gqa? dv_expanded_ptr:dv_ptr,
nullptr,
nullptr, //dbias_ptr
dq_acc_ptr, //dq_acc_buf
cu_seqlen_q_ptr,//cu_seqlen_q
cu_seqlen_kv_ptr,//cu_seqlen_kv
cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr, //seqstart_q_ptr
cu_seqlen_kv_padded_ptr==nullptr? cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr, //seqstart_k_ptr
nullptr, /* seqlen_q_ptr */
nullptr, /* seqlen_k_ptr */
cu_seqlen_q_ptr, //cu_seqlen_q_ptr
cu_seqlen_kv_ptr, //cu_seqlen_k_ptr
max_seqlen_q, //seqlen_q, unused in group mode
max_seqlen_k, //seqlen_kv, unused in group mode
batch,
Expand Down Expand Up @@ -977,6 +1022,18 @@ hipError_t ck_attn_varlen_bwd(
std::pair<const void*, const void*>{philox_seed_ptr, philox_offset_ptr}};
}();

// modify the max_seqlen_q for better performance in 0-length cases
// lse_thd_ptr used as buffer
if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) {
if(std::string(env_p) == "1"){
if(ck_fused_attn_log_config){
std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization.";
}
fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream);
fmha_args.max_seqlen_k = get_runtime_max_seqlen(b, cu_seqlen_kv_ptr, nullptr, lse_workspace_ptr, stream);
}
}

// print ck traits and args when needed
log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args);
if (uses_bwd_v3)
Expand All @@ -985,17 +1042,17 @@ hipError_t ck_attn_varlen_bwd(
}

float average_runtime = aiter::mha_bwd(fmha_args,
stream_config,
data_type_str,
is_group_mode,
mask_type,
bias_enum::no_bias,
has_dbias,
s_randval,
deterministic,
uses_bwd_v3,
is_v3_atomic_fp32,
how_v3_bf16_cvt);
stream_config,
data_type_str,
is_group_mode,
mask_type,
bias_enum::no_bias,
has_dbias,
s_randval,
deterministic,
uses_bwd_v3,
is_v3_atomic_fp32,
how_v3_bf16_cvt);
if(average_runtime < 0){
//TODO: better error out system
throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass.");
Expand All @@ -1006,6 +1063,8 @@ hipError_t ck_attn_varlen_bwd(
dim3 block(d_qk);
if (ck_fused_attn_log_config){
std::cout<<std::endl<<"run dk_dv_reduce_thd: "<<std::endl;
std::cout<<"cu_seqlen_kv_ptr: "<<cu_seqlen_kv_ptr<<std::endl;
std::cout<<"cu_seqlen_kv_padded_ptr: "<<cu_seqlen_kv_padded_ptr<<std::endl;
std::cout<<"dk_expanded_ptr: "<<dk_expanded_ptr<<std::endl;
std::cout<<"dv_expanded_ptr: "<<dv_expanded_ptr<<std::endl;
std::cout<<"stride_h_dkv_expanded: "<<stride_h_dk_expanded<<std::endl;
Expand All @@ -1018,8 +1077,9 @@ hipError_t ck_attn_varlen_bwd(
CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE,
hipLaunchKernelGGL(
dk_dv_reduce_thd<CK_TILE_TYPE>, grid, block, 0, stream,
h, hg, d_qk,
static_cast<const int32_t*>(cu_seqlen_kv_ptr)+b,
b, h, hg, d_qk,
static_cast<const int32_t*>(cu_seqlen_kv_ptr),
static_cast<const int32_t*>(cu_seqlen_kv_padded_ptr),
static_cast<CK_TILE_TYPE*>(dk_expanded_ptr),
static_cast<CK_TILE_TYPE*>(dv_expanded_ptr),
stride_h_dk_expanded, stride_s_dk_expanded,
Expand All @@ -1030,6 +1090,8 @@ hipError_t ck_attn_varlen_bwd(
dim3 block_dk(d_qk);
if (ck_fused_attn_log_config){
std::cout<<std::endl<<"run dk_or_dv_reduce_thd on dk: "<<std::endl;
std::cout<<"cu_seqlen_kv_ptr: "<<cu_seqlen_kv_ptr<<std::endl;
std::cout<<"cu_seqlen_kv_padded_ptr: "<<cu_seqlen_kv_padded_ptr<<std::endl;
std::cout<<"dk_expanded_ptr: "<<dk_expanded_ptr<<std::endl;
std::cout<<"stride_h_dk_expanded: "<<stride_h_dk_expanded<<std::endl;
std::cout<<"stride_s_dk_expanded: "<<stride_s_dk_expanded<<std::endl;
Expand All @@ -1040,8 +1102,9 @@ hipError_t ck_attn_varlen_bwd(
CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE,
hipLaunchKernelGGL(
dk_or_dv_reduce_thd<CK_TILE_TYPE>, grid, block_dk, 0, stream,
h, hg, d_qk,
static_cast<const int32_t*>(cu_seqlen_kv_ptr)+b,
b, h, hg, d_qk,
static_cast<const int32_t*>(cu_seqlen_kv_ptr),
static_cast<const int32_t*>(cu_seqlen_kv_padded_ptr),
static_cast<CK_TILE_TYPE*>(dk_expanded_ptr),
stride_h_dk_expanded, stride_s_dk_expanded,
static_cast<CK_TILE_TYPE*>(dk_ptr),
Expand All @@ -1050,6 +1113,8 @@ hipError_t ck_attn_varlen_bwd(
dim3 block_dv(d_v);
if (ck_fused_attn_log_config){
std::cout<<std::endl<<"run dk_or_dv_reduce_thd on dv: "<<std::endl;
std::cout<<"cu_seqlen_kv_ptr: "<<cu_seqlen_kv_ptr<<std::endl;
std::cout<<"cu_seqlen_kv_padded_ptr: "<<cu_seqlen_kv_padded_ptr<<std::endl;
std::cout<<"dv_expanded_ptr: "<<dv_expanded_ptr<<std::endl;
std::cout<<"stride_h_dv_expanded: "<<stride_h_dv_expanded<<std::endl;
std::cout<<"stride_s_dv_expanded: "<<stride_s_dv_expanded<<std::endl;
Expand All @@ -1060,8 +1125,9 @@ hipError_t ck_attn_varlen_bwd(
CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE,
hipLaunchKernelGGL(
dk_or_dv_reduce_thd<CK_TILE_TYPE>, grid, block_dv, 0, stream,
h, hg, d_v,
static_cast<const int32_t*>(cu_seqlen_kv_ptr)+b,
b, h, hg, d_v,
static_cast<const int32_t*>(cu_seqlen_kv_ptr),
static_cast<const int32_t*>(cu_seqlen_kv_padded_ptr),
static_cast<CK_TILE_TYPE*>(dv_expanded_ptr),
stride_h_dv_expanded, stride_s_dv_expanded,
static_cast<CK_TILE_TYPE*>(dv_ptr),
Expand Down
Loading