Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ci/jax.sh
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ run_test_config() {
export NVTE_JAX_UNITTEST_LEVEL=L0 # this env variable controls parameters set for some tests
run_default_fa 1 test_custom_call_compute.py
run_default_fa 1 test_functions.py
run 1 test_fused_attn.py
NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py # Using FAv2 for forward and backward pass
run 1 test_fused_attn.py -k 'not test_ck_unfused_smallseq_backend' # skip smallseq in normal flow
XLA_FLAGS='--xla_gpu_graph_level=0' run 1 test_fused_attn.py -k 'test_ck_unfused_smallseq_backend' # CK smallseq with GPU graph disabled
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For our new ROCm 7.2 image, disabling the XLA cudagraph needs to use

XLA_FLAGS="--xla_gpu_enable_command_buffer="

NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py -k 'not test_ck_unfused_smallseq_backend' # Using FAv2 for forward and backward pass
run_default_fa 1 test_helper.py
run_default_fa 1 test_layer.py #it effectevly always uses unfused attention
run_default_fa 1 test_sanity_import.py
Expand Down
169 changes: 149 additions & 20 deletions tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from functools import partial
from math import sqrt
from typing import Tuple, Optional, Dict
import os
import random

import jax
Expand Down Expand Up @@ -329,7 +330,11 @@ class FusedAttnRunner:
# generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases.
def _get_max_segments_per_sequence(self):
if self.qkv_layout.is_thd():
if 90400 <= get_cudnn_version() < 90500:
if (
90400 <= get_cudnn_version() < 90500
or ( is_hip_extension() and
os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1")
):
return self.num_segments_per_seq
else:
# +1 for testing runtime_segments < max_segments
Expand Down Expand Up @@ -418,6 +423,57 @@ def _check_configs(self):
"the F16_arbitrary_seqlen backend."
)

def _setup_thd_segments_ck_smallseq(self, generate_random_segment_ids):
    """
    Build THD segment descriptors for the CK small-seq path (NVTE_FUSED_ATTN_CK_SMALLSEQ=1).

    Uses num_segments_per_seq = max_seqlen_q for both Q and KV. For Q: if max_seqlen_q == 1,
    uses a fixed layout (one token per batch, cu_seqlens [0, 1, ..., batch_size]); otherwise
    generates random segments. For KV: always generates random segments.

    Args:
        generate_random_segment_ids: callable taking
            (batch_size, max_seqlen, num_segments, seed=..., min_segment_len=...) and
            returning (segment_ids, segment_pos, segment_pad).

    Returns:
        Tuple of (num_segments_per_seq,
                  segment_ids_q, segment_pos_q, pad_q, seqlens_q, offsets_q,
                  segment_ids_kv, segment_pos_kv, pad_kv, seqlens_kv, offsets_kv).
    """
    num_segments_per_seq = self.max_seqlen_q
    # NOTE(review): would calling generate_random_segment_ids directly when
    # max_seqlen_q == 1 cause problems? Confirm whether this special case is required.
    if self.max_seqlen_q == 1:
        # Q: deterministic - one segment of length 1 per batch -> cu_seqlen [0,1,2,...,batch_size]
        segment_ids_q = jnp.ones((self.batch_size, self.max_seqlen_q), dtype=jnp.int32)
        segment_pos_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32)
        pad_q = jnp.zeros((self.batch_size, self.max_seqlen_q), dtype=jnp.int32)
        seqlens_q = jnp.ones((self.batch_size, 1), dtype=jnp.int32)
        # Offsets are [token_index, -1] per batch row; -1 marks the unused slot.
        offsets_q = jnp.concatenate(
            [
                jnp.arange(self.batch_size, dtype=jnp.int32)[:, None],
                jnp.full((self.batch_size, 1), -1, dtype=jnp.int32),
            ],
            axis=1,
        )
    else:
        segment_ids_q, segment_pos_q, pad_q = generate_random_segment_ids(
            self.batch_size, self.max_seqlen_q, num_segments_per_seq, seed=42
        )
        seqlens_q, offsets_q = get_seqlens_and_offsets(segment_ids_q)

    # Force kv_len >= q_len for sliding-window attention; otherwise leave unconstrained.
    min_segment_len = None if self.window_size is None else seqlens_q
    segment_ids_kv, segment_pos_kv, pad_kv = generate_random_segment_ids(
        self.batch_size,
        self.max_seqlen_kv,
        num_segments_per_seq,
        seed=2024,
        min_segment_len=min_segment_len,
    )
    seqlens_kv, offsets_kv = get_seqlens_and_offsets(segment_ids_kv)
    return (
        num_segments_per_seq,
        segment_ids_q,
        segment_pos_q,
        pad_q,
        seqlens_q,
        offsets_q,
        segment_ids_kv,
        segment_pos_kv,
        pad_kv,
        seqlens_kv,
        offsets_kv,
    )

def _setup_inputs(self):
self._check_configs()

Expand Down Expand Up @@ -539,27 +595,42 @@ def generate_random_segment_ids(
return segment_ids, segment_pos, segment_pad

if self.qkv_layout.is_thd():
self.num_segments_per_seq = 2
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
# TODO(rewang): record only self attention and find the reason of cross attention
if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
self.segment_ids_kv = self.segment_ids_q
self.segment_pos_kv = self.segment_pos_q
self.pad_kv = self.pad_q
else:
# Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support
min_segment_len = None if self.window_size is None else self.seqlens_q
self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
self.batch_size,
self.max_seqlen_kv,
if is_hip_extension() and os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ", "0") == "1":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe put the small_seq into self.config to replace the checking with ENV?

(
self.num_segments_per_seq,
seed=2024,
min_segment_len=min_segment_len,
self.segment_ids_q,
self.segment_pos_q,
self.pad_q,
self.seqlens_q,
self.offsets_q,
self.segment_ids_kv,
self.segment_pos_kv,
self.pad_kv,
self.seqlens_kv,
self.offsets_kv,
) = self._setup_thd_segments_ck_smallseq(generate_random_segment_ids)
else:
self.num_segments_per_seq = 2
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
# TODO(rewang): record only self attention and find the reason of cross attention
if self.qkv_layout == QKVLayout.T3HD or self.max_seqlen_q == self.max_seqlen_kv:
self.segment_ids_kv = self.segment_ids_q
self.segment_pos_kv = self.segment_pos_q
self.pad_kv = self.pad_q
else:
# Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support
min_segment_len = None if self.window_size is None else self.seqlens_q
self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
self.batch_size,
self.max_seqlen_kv,
self.num_segments_per_seq,
seed=2024,
min_segment_len=min_segment_len,
)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
else:
self.num_segments_per_seq = 1
self.segment_ids_q, self.pad_q = gen_valid(
Expand Down Expand Up @@ -1214,3 +1285,61 @@ def test_jax_new_rng():
)
runner = FusedAttnRunner(**kwargs)
runner.test_forward()


# ROCm CK small-seq varlen tests.
@pytest.fixture
def ck_smallseq_env(monkeypatch):
    """Enable CK small-seq path and disable XLA GPU graphs for these tests.

    Skips the test unless XLA GPU graphs were already disabled via XLA_FLAGS,
    then sets NVTE_FUSED_ATTN_CK_SMALLSEQ=1 for the duration of the test
    (monkeypatch restores the environment afterwards).
    """
    # The CK small-seq kernels require GPU graphs to be off; skip rather than
    # silently exercising the wrong code path.
    if "xla_gpu_graph_level=0" not in os.environ.get("XLA_FLAGS", ""):
        pytest.skip("Test must be run with XLA_FLAGS='--xla_gpu_graph_level=0'")
    # NOTE(review): on newer ROCm 7.2 images the graph-disable flag may instead be
    # XLA_FLAGS="--xla_gpu_enable_command_buffer=" — confirm whether this flag change
    # comes from ROCm or from JAX; if JAX, a jax version check could pick the flag.
    monkeypatch.setenv("NVTE_FUSED_ATTN_CK_SMALLSEQ", "1")
    yield

@pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16], ids=["BF16", "FP16"])
@pytest.mark.parametrize(
    "b, s_q, s_kv, h_q, h_kv, d_qk, d_v",
    [
        pytest.param(4000, 1, 2, 16, 16, 128, 128, id="4000-1-2-16-16-128-128"),
        pytest.param(4000, 1, 4, 16, 16, 128, 128, id="4000-1-4-16-16-128-128"),
        pytest.param(4000, 1, 6, 16, 16, 128, 128, id="4000-1-6-16-16-128-128"),
        pytest.param(4000, 1, 8, 16, 16, 128, 128, id="4000-1-8-16-16-128-128"),
        pytest.param(4000, 1, 12, 16, 16, 128, 128, id="4000-1-12-16-16-128-128"),
        pytest.param(4000, 1, 16, 16, 16, 128, 128, id="4000-1-16-16-16-128-128"),
        pytest.param(2048, 2, 4, 16, 16, 128, 128, id="seqpack-2048-2-4-16-16-128-128"),
        pytest.param(2, 4096, 8192, 16, 16, 128, 128, id="seqpack-2-4096-8192-16-16-128-128"),
    ],
)
@pytest.mark.skipif(
    not is_hip_extension(), reason="CK unfused smallseq backend only available on AMD hardware"
)
def test_ck_unfused_smallseq_backend(
    ck_smallseq_env, b, s_q, s_kv, h_q, h_kv, d_qk, d_v, dtype
):
    """
    Test the CK unfused small-seq (varlen) path on ROCm: s_q=1, s_kv<=16, THD layout.

    Uses THD_THD_THD (Q, K, V all THD). ck_smallseq_env sets NVTE_FUSED_ATTN_CK_SMALLSEQ=1
    and restores it after the test.
    """
    runner = FusedAttnRunner(
        batch_size=b,
        max_seqlen_q=s_q,
        max_seqlen_kv=s_kv,
        num_heads_q=h_q,
        num_heads_kv=h_kv,
        head_dim_qk=d_qk,
        head_dim_v=d_v,
        attn_bias_type=AttnBiasType.NO_BIAS,
        attn_mask_type=AttnMaskType.PADDING_MASK,
        dropout_prob=0.0,
        use_old_rng=True,
        dtype=dtype,
        is_training=True,
        qkv_layout=QKVLayout.THD_THD_THD,
        bias_shape=None,
        window_size=None,
        seq_desc_format=SeqDescFormat.Seqlens,
    )
    # Build THD segment descriptors before invoking the kernel.
    # TODO(review): confirm whether test_backward() already calls _setup_inputs()
    # internally, which would make this explicit call redundant.
    runner._setup_inputs()
    runner.test_backward()
1 change: 1 addition & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ else()
fused_attn_rocm/fused_attn.cpp
fused_attn_rocm/fused_attn_aotriton.cpp
fused_attn_rocm/fused_attn_ck.cpp
fused_attn_rocm/fused_attn_smallseq.cpp
fused_attn_rocm/utils.cpp
gemm/rocm_gemm.cu
amd_detail/system.cpp)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*************************************************************************
* Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
*
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/
Expand Down Expand Up @@ -168,6 +168,12 @@ hipError_t ck_attn_varlen_bwd(
int how_v3_bf16_cvt,
hipStream_t stream);

uint64_t get_runtime_max_seqlen(uint64_t b,
const void* cu_seqlen_ptr,
const void* cu_seqlen_padded_ptr,
void* workspace,
hipStream_t stream);

}//namespace ck_fused_attn
#endif // CK_FUSED_ATTN_H

74 changes: 71 additions & 3 deletions transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*************************************************************************
* Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
*
* License for AMD contributions = MIT. See LICENSE for more information
************************************************************************/
Expand All @@ -9,6 +9,7 @@
#include <numeric> // Required for std::accumulate
#ifdef USE_FUSED_ATTN_CK
#include <ck_fused_attn/ck_fused_attn.hpp>
#include "fused_attn_smallseq.h"
#endif // USE_FUSED_ATTN_CK
#include "../util/cuda_runtime.h"
#include "../util/system.h"
Expand Down Expand Up @@ -614,6 +615,40 @@ void fused_attn_ck_fwd_impl(
// denote the next available section of workspace from upstream
void* workspace_next = workspace;

const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ");
if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add another filter s_q !=s_kv here

void* max_seqlen_workspace = workspace_next;
size_t runtime_max_seqlen_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
static_cast<uint64_t>(b), devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream));
size_t runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we don't pass cu_seqlen_padded into the runtime max-seqlen check. What if max_seqlen without padding satisfies the CK kernel condition but with padding it does not? Can the CK kernel handle those corner cases?

static_cast<uint64_t>(b), devPtrCuSeqlensKV, nullptr, max_seqlen_workspace, stream));
workspace_next = static_cast<void*>(static_cast<int8_t*>(workspace_next) + sizeof(uint64_t));

if (nvte_log_ck_config) {
std::cout << std::endl << "attn_fwd(ck small-seq): ";
std::cout << "b: " << b << ", ";
std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", ";
std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << ", ";
std::cout << "flow: "
<< (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 &&
runtime_max_seqlen_kv <= 16
? "ck-smallseq"
: "regular ck/aiter")
<< std::endl;
}

if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) {
fused_attn_rocm::fused_attn_smallseq_fwd(
b, h, hg, runtime_max_seqlen_kv, d_qk, d_v,
is_training, scaling_factor, dropout_probability,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxAux,
devPtrCuSeqlensKV, devPtrSeqOffsetsKV,
devPtrDropoutSeed, devPtrDropoutOffset,
dtype, workspace, workspace_size, stream);
return;
}
}

std::array<uint64_t, 4> q_stride;
std::array<uint64_t, 4> k_stride;
std::array<uint64_t, 4> v_stride;
Expand Down Expand Up @@ -916,6 +951,40 @@ void fused_attn_ck_bwd_impl(
// denote the next available section of workspace from upstream
void* workspace_next = workspace;

const char* nvte_smallseq = std::getenv("NVTE_FUSED_ATTN_CK_SMALLSEQ");
if (is_ragged && nvte_smallseq && std::string(nvte_smallseq) == "1") {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, let's add s_q!=s_kv here

void* max_seqlen_workspace = workspace_next;
size_t runtime_max_seqlen_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
b, devPtrCuSeqlensQ, nullptr, max_seqlen_workspace, stream));
size_t runtime_max_seqlen_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
b, devPtrCuSeqlensKV, nullptr, max_seqlen_workspace, stream));
workspace_next = static_cast<void*>(static_cast<int8_t*>(workspace_next) + sizeof(uint64_t));

if (nvte_log_ck_config) {
std::cout << std::endl << "attn_bwd(ck small-seq): ";
std::cout << "b: " << b << ", ";
std::cout << "runtime_max_seqlen_q: " << runtime_max_seqlen_q << ", ";
std::cout << "runtime_max_seqlen_kv: " << runtime_max_seqlen_kv << ", ";
std::cout << "flow: "
<< (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 &&
runtime_max_seqlen_kv <= 16
? "ck-smallseq"
: "regular ck/aiter")
<< std::endl;
}

if (runtime_max_seqlen_q == 1 && runtime_max_seqlen_kv >= 2 && runtime_max_seqlen_kv <= 16) {
fused_attn_rocm::fused_attn_smallseq_bwd(
b, h, hg, runtime_max_seqlen_kv, d_qk, d_v,
scaling_factor, dropout_probability,
devPtrQ, devPtrK, devPtrV, devPtrO, devPtrdO, devPtrSoftmaxAux,
devPtrdQ, devPtrdK, devPtrdV,
devPtrCuSeqlensKV, devPtrSeqOffsetsKV,
dtype, workspace, workspace_size, stream);
return;
}
}

std::array<uint64_t, 4> q_stride;
std::array<uint64_t, 4> k_stride;
std::array<uint64_t, 4> v_stride;
Expand Down Expand Up @@ -1828,7 +1897,7 @@ void fused_attn_ck_fwd(
size_t max_tokens_q = std::accumulate((input_Q->data).shape.begin(), (input_Q->data).shape.end(), static_cast<size_t>(1), std::multiplies<size_t>())/h_q/d_qk;
size_t max_tokens_kv = std::accumulate((input_K->data).shape.begin(), (input_K->data).shape.end(), static_cast<size_t>(1), std::multiplies<size_t>())/h_kv/d_qk;

bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD;
bool is_ragged = nvte_get_qkv_format(qkv_layout)==NVTE_QKV_Format::NVTE_THD;
if (Aux_CTX_Tensors->size == 0) {
if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) {
Aux_CTX_Tensors->size = 3;
Expand Down Expand Up @@ -1883,7 +1952,6 @@ void fused_attn_ck_fwd(
bool is_padding = (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);

fused_attn_ck_fwd_impl(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, bias_b, bias_h,
max_tokens_q, max_tokens_kv,
Expand Down
Loading
Loading