tests/pytorch/attention/test_attention.py (93 changes: 53 additions & 40 deletions)
@@ -240,7 +240,11 @@ def test_dot_product_attention(
         flash_attn_supported = True
 
     # Skip if only unfused backend is supported
-    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
+    # Count the CK backend twice, since its V2 and V3 kernels are compared against each other
+    has_ck_backend = IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends
+    if not has_ck_backend and (
+        len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported
+    ) < 2:
         pytest.skip("Less than two backends to compare.")
 
     # UnfusedDotProductAttention backend
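
Restated, the new skip condition runs the test whenever at least two comparable implementations exist, and a lone CK backend now qualifies on its own because its V2 and V3 kernels give two kernel paths to compare. The same predicate, written out as a hypothetical standalone helper (not part of the PR):

def enough_backends_to_compare(
    n_fused: int, has_flash: bool, has_unfused: bool, has_ck: bool
) -> bool:
    # A lone CK backend already yields two comparable paths (V2 and V3 kernels).
    if has_ck:
        return True
    return (n_fused + has_flash + has_unfused) >= 2
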
@@ -271,8 +275,8 @@ def test_dot_product_attention(
             )
         if len(fused_attn_backends) == 2:
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
-            os.environ["NVTE_FUSED_ATTN_CK"] = "0"
-            os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1"
+            os.environ["NVTE_FUSED_ATTN_CK"] = "1"
+            os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0"
             fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
                 dtype,
                 config,
@@ -284,11 +288,24 @@ def test_dot_product_attention(
                 is_training,
             )
             os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
+            os.environ["NVTE_FUSED_ATTN_CK"] = "0"
+            os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1"
+            fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
+                dtype,
+                config,
+                "FusedAttention",
+                ckpt_attn,
+                qkv_layout,
+                workspace_opt,
+                pad_between_seqs,
+                is_training,
+            )
+        if has_ck_backend:
             os.environ["NVTE_FUSED_ATTN_CK"] = "1"
             os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0"
-            os.environ["NVTE_CK_USES_FWD_V3"] = "1"
-            os.environ["NVTE_CK_USES_BWD_V3"] = "1"
-            fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
+            os.environ["NVTE_CK_USES_FWD_V3"] = "0"
+            os.environ["NVTE_CK_USES_BWD_V3"] = "0"
+            fused_attn_fwd_2, fused_attn_bwd_2 = _run_dot_product_attention(
                 dtype,
                 config,
                 "FusedAttention",
@@ -298,19 +315,6 @@ def test_dot_product_attention(
                 pad_between_seqs,
                 is_training,
             )
-        if IS_HIP_EXTENSION:
-            os.environ["NVTE_CK_USES_FWD_V3"] = "0"
-            os.environ["NVTE_CK_USES_BWD_V3"] = "0"
-            fused_attn_fwd_2, fused_attn_bwd_2 = _run_dot_product_attention(
-                dtype,
-                config,
-                "FusedAttention",
-                ckpt_attn,
-                qkv_layout,
-                workspace_opt,
-                pad_between_seqs,
-                is_training,
-            )
 
 
     # FlashAttention backend
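
Throughout both tests, backend selection happens through environment variables. Judging from their use in this diff, NVTE_FUSED_ATTN_CK and NVTE_FUSED_ATTN_AOTRITON enable one ROCm fused-attention sub-backend at a time, while NVTE_CK_USES_FWD_V3 and NVTE_CK_USES_BWD_V3 switch the CK backend between its V2 and V3 kernels. A minimal sketch of that toggling pattern, with a hypothetical helper name and semantics inferred from the test:

import os

def select_rocm_fused_backend(use_ck: bool, use_v3: bool = False) -> None:
    # Enable exactly one ROCm fused-attention sub-backend (assumed semantics,
    # inferred from how the test sets these variables).
    os.environ["NVTE_FUSED_ATTN_CK"] = "1" if use_ck else "0"
    os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0" if use_ck else "1"
    # Pick between the CK V2 and V3 kernels for forward and backward.
    os.environ["NVTE_CK_USES_FWD_V3"] = "1" if use_v3 else "0"
    os.environ["NVTE_CK_USES_BWD_V3"] = "1" if use_v3 else "0"
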
@@ -347,11 +351,11 @@ def test_dot_product_attention(
         torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
         for i, _ in enumerate(fused_attn_bwd):
             torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
-    if IS_HIP_EXTENSION:
-        logging.info("[test_dot_product_attention]: fused attn backend 0 vs 2")
-        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols)
-        for i, _ in enumerate(fused_attn_bwd):
-            torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols)
+    if has_ck_backend:  # Compare CK V2/V3 if both are available
+        logging.info("[test_dot_product_attention]: CK fused attn V2 vs V3")
+        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols)
+        for i, _ in enumerate(fused_attn_bwd):
+            torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols)
 
 
 @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
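
The comparisons above all go through torch.testing.assert_close with shared tolerances. A minimal illustration of that call, using assumed example tolerances rather than the test's actual values:

import torch

tols = {"rtol": 1e-3, "atol": 1e-3}  # assumed example values, not the test's actual tolerances
ref = torch.randn(4, 8)
out = ref + 1e-5 * torch.randn(4, 8)  # perturbation well within tolerance
torch.testing.assert_close(out, ref, **tols)  # returns silently; raises AssertionError on mismatch
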
@@ -1259,7 +1263,11 @@ def test_transformer_layer(
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
 
     # Skip if only unfused backend is supported
-    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
+    # Count the CK backend twice, since its V2 and V3 kernels are compared against each other
+    has_ck_backend = IS_HIP_EXTENSION and FusedAttnBackend["CK"] in fused_attn_backends
+    if not has_ck_backend and (
+        len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported
+    ) < 2:
         pytest.skip("Less than two backends to compare.")
     # Skip if qkv_format = thd and "padding" not in attn_mask_type
     if qkv_format == "thd" and "padding" not in config.attn_mask_type:
@@ -1281,7 +1289,7 @@ def test_transformer_layer(
 
     # FusedAttention backend
     if fused_attn_supported:
-        if len(fused_attn_backends) == 1 or not IS_HIP_EXTENSION:
+        if len(fused_attn_backends) == 1:
             fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
                 dtype,
                 config,
@@ -1293,9 +1301,9 @@ def test_transformer_layer(
                 RoPE,
                 is_training,
             )
-        elif len(fused_attn_backends) == 2:
-            os.environ["NVTE_FUSED_ATTN_CK"] = "0"
-            os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1"
+        elif IS_HIP_EXTENSION and len(fused_attn_backends) == 2:
+            os.environ["NVTE_FUSED_ATTN_CK"] = "1"
+            os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0"
             fused_attn_fwd, fused_attn_bwd = _run_transformer_layer(
                 dtype,
                 config,
@@ -1307,8 +1315,8 @@ def test_transformer_layer(
                 RoPE,
                 is_training,
             )
-            os.environ["NVTE_FUSED_ATTN_CK"] = "1"
-            os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0"
+            os.environ["NVTE_FUSED_ATTN_CK"] = "0"
+            os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1"
             fused_attn_fwd_1, fused_attn_bwd_1 = _run_transformer_layer(
                 dtype,
                 config,
@@ -1321,6 +1329,9 @@ def test_transformer_layer(
                 is_training,
             )
 
+        if has_ck_backend:
+            os.environ["NVTE_FUSED_ATTN_CK"] = "1"
+            os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0"
             os.environ["NVTE_CK_USES_FWD_V3"] = "0"
             os.environ["NVTE_CK_USES_BWD_V3"] = "0"
             fused_attn_fwd_2, fused_attn_bwd_2 = _run_transformer_layer(
@@ -1363,15 +1374,17 @@ def test_transformer_layer(
         logging.info("[test_transformer_layer]: fused attn vs flash attn")
         torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
         torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)
-    if IS_HIP_EXTENSION and fused_attn_supported and len(fused_attn_backends) == 2:
-        logging.info("[test_transformer_layer]: fused attn backend 0 vs 1")
-        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
-        for i, _ in enumerate(fused_attn_bwd):
-            torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
-        logging.info("[test_transformer_layer]: fused attn backend 0 vs 2")
-        torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols)
-        for i, _ in enumerate(fused_attn_bwd):
-            torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols)
+    if IS_HIP_EXTENSION and fused_attn_supported:
+        if len(fused_attn_backends) == 2:
+            logging.info("[test_transformer_layer]: fused attn backend 0 vs 1")
+            torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
+            for i, _ in enumerate(fused_attn_bwd):
+                torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
+        if has_ck_backend:
+            logging.info("[test_transformer_layer]: CK fused attn V2 vs V3")
+            torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols)
+            for i, _ in enumerate(fused_attn_bwd):
+                torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols)
 
 
 @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
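
One caveat: both tests mutate process-wide NVTE_* variables and leave the last values in place, so ordering between tests can matter. A restore-on-exit wrapper is one way to keep runs isolated; this context manager is a sketch, not part of the PR:

import os
from contextlib import contextmanager

@contextmanager
def env_overrides(**overrides: str):
    # Set environment variables for the duration of a block, then restore them.
    saved = {name: os.environ.get(name) for name in overrides}
    os.environ.update(overrides)
    try:
        yield
    finally:
        for name, old in saved.items():
            if old is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = old

For example, with env_overrides(NVTE_FUSED_ATTN_CK="1", NVTE_FUSED_ATTN_AOTRITON="0"): ... would scope the CK selection to a single run.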