[tx] Add cuDNN flash attention #879
Conversation
- Use seq_lengths instead of attention_mask for attention computation
- On GPU: use cuDNN flash attention with query_seq_lengths/key_value_seq_lengths
- On CPU/TPU: fall back to mask-based attention (construct mask from seq_lengths)
- cuDNN flash attention provides O(seq) memory vs O(seq²) for standard attention

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Extract shared attention logic to tx/models/attention.py
- Use cuDNN flash attention only for right-padded sequences on GPU
- Fall back to mask-based attention for left-padded (generation) or CPU/TPU
- Fixes generation bug where cuDNN received wrong valid positions
Shift left-padded sequences to right-padded before applying cuDNN flash attention, then shift the output back. This enables O(S^2) -> O(S) memory savings for inference prefill while keeping mask-based attention for decode (where flash attention provides minimal benefit).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Use argmax to find first valid token position (0 for right-padded, >0 for left-padded)
- Always apply shift (no-op when shift=0), avoiding dual-branch compilation
- Document that attention_mask must have at least one valid token per batch

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

(An illustrative sketch of this shift-and-dispatch flow follows these commit notes.)
- Add CPU tests for _shift_sequences and basic attention (tests/models/)
- Add GPU tests for cuDNN vs mask-based numerical equivalence (tests/gpu/)
- Add gpu_skyrl_tx.yaml workflow using Anyscale for GPU testing
- Update cpu_skyrl_tx.yaml to exclude tests/gpu/

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
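The sketch below is a rough illustration of the shift-and-dispatch flow described in the commits above, not the code in this PR. The names `_shift_sequences` and `dot_product_attention` are reused from the PR, but their bodies here are assumptions; the GPU path is expressed through `jax.nn.dot_product_attention`, whose `query_seq_lengths`/`key_value_seq_lengths` and `implementation="cudnn"` arguments exist in recent JAX releases, and the CPU/TPU fallback here simply uses the XLA implementation with a boolean mask rather than the repository's own mask-based helper.

```python
import jax
import jax.numpy as jnp


def _shift_sequences(x, shifts):
    # Roll each batch row left by its per-row shift along the sequence axis.
    # A shift of 0 is a no-op, so right- and left-padded rows share one path.
    return jax.vmap(lambda row, s: jnp.roll(row, -s, axis=0))(x, shifts)


def dot_product_attention(q, k, v, attention_mask, *, is_causal, head_dim):
    # q, k, v: (batch, seq, heads, head_dim); attention_mask: (batch, seq),
    # 1.0 for valid tokens, 0.0 for padding, at least one valid token per row.
    seq_lengths = attention_mask.sum(axis=-1).astype(jnp.int32)
    # Index of the first valid token: 0 for right-padded rows, >0 for left-padded.
    shift = jnp.argmax(attention_mask, axis=-1).astype(jnp.int32)

    if jax.default_backend() == "gpu":
        # Shift to the right-padded layout the cuDNN kernel expects,
        # run flash attention with per-row valid lengths, then shift back.
        q_r = _shift_sequences(q, shift)
        k_r = _shift_sequences(k, shift)
        v_r = _shift_sequences(v, shift)
        out = jax.nn.dot_product_attention(
            q_r, k_r, v_r,
            is_causal=is_causal,
            scale=head_dim ** -0.5,
            query_seq_lengths=seq_lengths,
            key_value_seq_lengths=seq_lengths,
            implementation="cudnn",
        )
        return _shift_sequences(out, -shift)

    # CPU/TPU fallback: standard mask-based attention via the XLA path.
    return jax.nn.dot_product_attention(
        q, k, v,
        mask=attention_mask.astype(bool)[:, None, None, :],
        is_causal=is_causal,
        scale=head_dim ** -0.5,
        implementation="xla",
    )
```

Applying the shift unconditionally (a roll by zero for right-padded rows) keeps a single compiled branch, matching the intent stated in the commit notes.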
Code Review
This pull request introduces an excellent optimization by adding support for cuDNN flash attention, which will significantly reduce memory usage during training and prefill on GPUs. The implementation is robust, automatically handling both left- and right-padded sequences by cleverly converting them to the right-padded format required by cuDNN. The fallback to standard mask-based attention for the decode phase and non-GPU backends is a sound design choice. The new functionality is accompanied by a comprehensive set of unit tests for both CPU and GPU, ensuring numerical correctness across different scenarios. The addition of a dedicated GPU CI workflow is also a great step towards maintaining the stability of this feature. I have one suggestion to further improve the test coverage for an edge case.
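For context on the comparison in the test excerpt below: the tests check the new `dot_product_attention` against a `mask_based_attention` reference. The repository's actual helper is not shown in this excerpt, so the following is only a plausible reference implementation, assuming (B, T, H, D) inputs and a (B, T) float mask with 1.0 marking valid tokens, as in the test code.

```python
import jax
import jax.numpy as jnp


def mask_based_attention(q, k, v, mask, *, is_causal, head_dim):
    # Plain O(T^2) attention used as a numerical reference.
    # q, k, v: (batch, seq, heads, head_dim); mask: (batch, seq) of 0.0/1.0.
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim)

    # Disallow attending to padded key positions via a large negative bias.
    bias = jnp.where(mask[:, None, None, :] > 0, 0.0, -1e30)
    if is_causal:
        seq_len = q.shape[1]
        causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
        bias = jnp.where(causal[None, None, :, :], bias, -1e30)

    weights = jax.nn.softmax(logits + bias, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)
```

Within the valid region, such a reference should agree with the cuDNN path up to small numerical differences, which is what the tolerance-based assertions below check.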
```python
B, T, H, D = 4, 128, 4, 64
q = jax.random.normal(jax.random.key(0), (B, T, H, D))
k = jax.random.normal(jax.random.key(1), (B, T, H, D))
v = jax.random.normal(jax.random.key(2), (B, T, H, D))

# Left-padded with different lengths
seq_lengths = jnp.array([128, 96, 64, 32])
padding = T - seq_lengths
mask = (jnp.arange(T)[None, :] >= padding[:, None]).astype(jnp.float32)

result = dot_product_attention(q, k, v, mask, is_causal=True, head_dim=D)
expected = mask_based_attention(q, k, v, mask, is_causal=True, head_dim=D)

for b in range(B):
    pad_len = int(padding[b])
    assert jnp.allclose(result[b, pad_len:], expected[b, pad_len:], atol=1e-5), f"Mismatch at batch {b}"
```
This test for mixed sequence lengths is great. To make it even more robust, I suggest including a zero-length (fully padded) sequence in the batch. This ensures that the attention mechanism correctly handles this edge case, which can occur in practice.
The current assertion only checks the valid part of each sequence. For a zero-length sequence that slice is empty, so the test would pass vacuously. Adding an explicit check that the output for the fully padded sequence is all zeros would be a valuable addition.
```diff
-B, T, H, D = 4, 128, 4, 64
+B, T, H, D = 5, 128, 4, 64
 q = jax.random.normal(jax.random.key(0), (B, T, H, D))
 k = jax.random.normal(jax.random.key(1), (B, T, H, D))
 v = jax.random.normal(jax.random.key(2), (B, T, H, D))

-# Left-padded with different lengths
-seq_lengths = jnp.array([128, 96, 64, 32])
+# Left-padded with different lengths, including a zero-length sequence
+seq_lengths = jnp.array([128, 96, 64, 32, 0])
 padding = T - seq_lengths
 mask = (jnp.arange(T)[None, :] >= padding[:, None]).astype(jnp.float32)

 result = dot_product_attention(q, k, v, mask, is_causal=True, head_dim=D)
 expected = mask_based_attention(q, k, v, mask, is_causal=True, head_dim=D)

 for b in range(B):
     pad_len = int(padding[b])
+    valid_len = int(seq_lengths[b])
     assert jnp.allclose(result[b, pad_len:], expected[b, pad_len:], atol=1e-5), f"Mismatch at batch {b}"
+    if valid_len == 0:
+        assert jnp.allclose(result[b], 0, atol=1e-6), f"Output for fully padded sequence at batch {b} should be zero."
```
Summary
Changes
New shared attention module (tx/models/attention.py)
Model updates (llama3.py, qwen3.py)
Tests
CI
Test plan