From e01e9b565d9d4a802febcafe339af5fe6b57b5cd Mon Sep 17 00:00:00 2001
From: Philipp Moritz
Date: Fri, 12 Dec 2025 12:29:39 -0800
Subject: [PATCH 1/2] [tx] Try cudnn attention implementation

---
 skyrl-tx/tx/models/qwen3.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py
index 9f01afbae..f8c194442 100644
--- a/skyrl-tx/tx/models/qwen3.py
+++ b/skyrl-tx/tx/models/qwen3.py
@@ -129,6 +129,7 @@ def __call__(
             scale=1.0 / self.head_dim**0.5,
             mask=attention_mask[:, None, None, :].astype(bool),
             is_causal=kv_cache is None,
+            implementation="cudnn",
         )
 
         output = attn_output.reshape(B, T, self.num_heads * self.head_dim)

From b6a89fa97473a21bd47d0bca5f87c514a9377354 Mon Sep 17 00:00:00 2001
From: Philipp Moritz
Date: Fri, 12 Dec 2025 12:44:57 -0800
Subject: [PATCH 2/2] update

---
 skyrl-tx/tx/models/qwen3.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py
index f8c194442..d9983f904 100644
--- a/skyrl-tx/tx/models/qwen3.py
+++ b/skyrl-tx/tx/models/qwen3.py
@@ -122,12 +122,19 @@ def __call__(
         updated_cache = (k, v)
 
         # Attention (causal only during prefill, GQA handled natively by dot_product_attention)
+        # Expand mask to [B, 1, Q_len, KV_len] for cuDNN compatibility (no broadcasting)
+        Q_len = q.shape[1]
+        KV_len = k.shape[1]
+        mask = jnp.broadcast_to(
+            attention_mask[:, None, None, :].astype(bool),
+            (B, 1, Q_len, KV_len)
+        )
         attn_output = jax.nn.dot_product_attention(
             q,
             k,
             v,
             scale=1.0 / self.head_dim**0.5,
-            mask=attention_mask[:, None, None, :].astype(bool),
+            mask=mask,
             is_causal=kv_cache is None,
             implementation="cudnn",
         )
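
For reference, below is a minimal standalone sketch of the pattern PATCH 2/2 introduces: expanding a [B, KV_len] padding mask to [B, 1, Q_len, KV_len] before calling jax.nn.dot_product_attention with the cuDNN backend. The dummy shapes, the random inputs, and the fallback to the XLA backend on non-GPU hosts are assumptions for illustration only, not part of the patch.

# Standalone sketch: expand a [B, KV_len] padding mask to [B, 1, Q_len, KV_len]
# and call jax.nn.dot_product_attention, requesting the cuDNN backend on GPU.
import jax
import jax.numpy as jnp

B, Q_len, KV_len = 2, 8, 8
num_heads, num_kv_heads, head_dim = 8, 2, 64  # GQA: 8 query heads share 2 KV heads

rng = jax.random.PRNGKey(0)
kq, kk, kv = jax.random.split(rng, 3)
q = jax.random.normal(kq, (B, Q_len, num_heads, head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(kk, (B, KV_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(kv, (B, KV_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)

# Padding mask over KV positions: True = attend, False = masked out.
attention_mask = jnp.ones((B, KV_len), dtype=jnp.int32)

# Expand the mask to the full [B, 1, Q_len, KV_len] shape instead of relying
# on implicit broadcasting over the query dimension.
mask = jnp.broadcast_to(attention_mask[:, None, None, :].astype(bool), (B, 1, Q_len, KV_len))

# The cuDNN flash-attention path needs an NVIDIA GPU; fall back to XLA elsewhere (assumption).
impl = "cudnn" if jax.default_backend() == "gpu" else "xla"

attn_output = jax.nn.dot_product_attention(
    q, k, v,
    scale=1.0 / head_dim**0.5,
    mask=mask,
    is_causal=True,       # prefill case (no KV cache), as in the patch
    implementation=impl,
)
print(attn_output.shape)  # (2, 8, 8, 64)

The explicit broadcast matches the comment added in the patch: the cuDNN backend expects the mask at its full rank rather than a shape it has to broadcast, whereas the XLA backend accepts either form.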