Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion skyrl-tx/tx/models/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,21 @@ def __call__(
updated_cache = (k, v)

# Attention (causal only during prefill, GQA handled natively by dot_product_attention)
# Expand mask to [B, 1, Q_len, KV_len] for cuDNN compatibility (no broadcasting)
Q_len = q.shape[1]
KV_len = k.shape[1]
mask = jnp.broadcast_to(
attention_mask[:, None, None, :].astype(bool),
(B, 1, Q_len, KV_len)
)
attn_output = jax.nn.dot_product_attention(
q,
k,
v,
scale=1.0 / self.head_dim**0.5,
mask=attention_mask[:, None, None, :].astype(bool),
mask=mask,
is_causal=kv_cache is None,
implementation="cudnn",
)
Comment on lines 132 to 140
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

According to the JAX documentation for jax.nn.dot_product_attention with implementation="cudnn", the deterministic parameter must be set to True. Without this, the call might fail or not use the cuDNN kernel as intended.

Suggested change
attn_output = jax.nn.dot_product_attention(
q,
k,
v,
scale=1.0 / self.head_dim**0.5,
mask=attention_mask[:, None, None, :].astype(bool),
mask=mask,
is_causal=kv_cache is None,
implementation="cudnn",
)
attn_output = jax.nn.dot_product_attention(
q,
k,
v,
scale=1.0 / self.head_dim**0.5,
mask=mask,
is_causal=kv_cache is None,
implementation="cudnn",
deterministic=True,
)


output = attn_output.reshape(B, T, self.num_heads * self.head_dim)
Expand Down
Loading