diff --git a/skyrl-tx/tx/models/qwen3.py b/skyrl-tx/tx/models/qwen3.py
index 9f01afbae..d9983f904 100644
--- a/skyrl-tx/tx/models/qwen3.py
+++ b/skyrl-tx/tx/models/qwen3.py
@@ -122,13 +122,21 @@ def __call__(
         updated_cache = (k, v)
 
         # Attention (causal only during prefill, GQA handled natively by dot_product_attention)
+        # Expand mask to [B, 1, Q_len, KV_len] for cuDNN compatibility (no broadcasting)
+        Q_len = q.shape[1]
+        KV_len = k.shape[1]
+        mask = jnp.broadcast_to(
+            attention_mask[:, None, None, :].astype(bool),
+            (B, 1, Q_len, KV_len)
+        )
         attn_output = jax.nn.dot_product_attention(
             q,
             k,
             v,
             scale=1.0 / self.head_dim**0.5,
-            mask=attention_mask[:, None, None, :].astype(bool),
+            mask=mask,
             is_causal=kv_cache is None,
+            implementation="cudnn",
         )
 
         output = attn_output.reshape(B, T, self.num_heads * self.head_dim)
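
For reference, a minimal standalone sketch (not part of the patch) of the mask-expansion pattern used above. The shapes, dummy inputs, and use of the default (XLA) implementation are assumptions so the snippet runs without a GPU; the patch itself requests implementation="cudnn", which requires a cuDNN-capable GPU and is the reason the mask is materialized as an explicit [B, 1, Q_len, KV_len] array rather than left as a broadcastable [B, 1, 1, KV_len] one.

# Standalone sketch with assumed shapes and dummy data (not taken from the repo).
import jax
import jax.numpy as jnp

B, Q_len, KV_len, num_heads, head_dim = 2, 5, 5, 4, 64

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, Q_len, num_heads, head_dim))
k = jax.random.normal(kk, (B, KV_len, num_heads, head_dim))
v = jax.random.normal(kv, (B, KV_len, num_heads, head_dim))

# Per-token padding mask of shape [B, KV_len]; nonzero = attend, 0 = padding.
attention_mask = jnp.ones((B, KV_len), dtype=jnp.int32)

# Materialize the full [B, 1, Q_len, KV_len] boolean mask instead of relying on
# the backend to broadcast a [B, 1, 1, KV_len] mask over the query dimension.
mask = jnp.broadcast_to(
    attention_mask[:, None, None, :].astype(bool),
    (B, 1, Q_len, KV_len),
)

attn_output = jax.nn.dot_product_attention(
    q, k, v,
    scale=1.0 / head_dim**0.5,
    mask=mask,
    is_causal=True,  # prefill path (no KV cache), as in the patched code
    # implementation="cudnn" would select the cuDNN fused kernel on GPU.
)
print(attn_output.shape)  # (2, 5, 4, 64)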