
@raulchen (Contributor)

Summary

  • Add cuDNN flash attention for training and inference prefill, reducing memory from O(S²) to O(S)
  • Handles both left-padded (inference) and right-padded (training) sequences automatically
  • Falls back to mask-based attention for decode phase and CPU/TPU backends
  • Add GPU CI workflow for skyrl-tx

Changes

New shared attention module (tx/models/attention.py)

  • dot_product_attention(): automatically selects cuDNN flash attention for causal attention on GPU, and falls back to mask-based attention otherwise
  • Handles left-padded sequences by shifting them to right-padded format for cuDNN compatibility, then shifting the output back
  • Decode phase uses mask-based attention, since flash attention provides minimal benefit for single-token queries (a rough sketch of the dispatch follows this list)
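
A rough sketch of the selection logic, assuming the device check uses jax.default_backend(), the cuDNN path goes through jax.nn.dot_product_attention with implementation="cudnn", and decode is detected by a single-token query. The signature follows the tests further down, but this is illustrative rather than the actual tx/models/attention.py code:

import jax
import jax.numpy as jnp

def dot_product_attention(q, k, v, attention_mask, *, is_causal, head_dim):
    # Illustrative dispatch only; the real helper also shifts left-padded
    # batches to right-padded form before the cuDNN call.
    scale = head_dim ** -0.5
    is_decode = q.shape[1] == 1  # single-token query: flash attention gains little
    if jax.default_backend() == "gpu" and is_causal and not is_decode:
        seq_lengths = attention_mask.sum(axis=1).astype(jnp.int32)
        return jax.nn.dot_product_attention(
            q, k, v,
            scale=scale,
            is_causal=True,
            query_seq_lengths=seq_lengths,
            key_value_seq_lengths=seq_lengths,
            implementation="cudnn",  # flash attention: O(S) memory
        )
    # CPU/TPU or decode: build an explicit padding mask and use the default path
    mask = attention_mask[:, None, None, :].astype(bool)  # (B, 1, 1, S)
    return jax.nn.dot_product_attention(
        q, k, v,
        mask=mask,
        scale=scale,
        # For single-token decode the padding mask alone suffices; for
        # full-length sequences keep causal masking as well.
        is_causal=is_causal and not is_decode,
    )

This is where the O(S²) → O(S) memory reduction in the summary comes from: the cuDNN kernel never materializes the full attention matrix.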

Model updates (llama3.py, qwen3.py)

  • Use shared dot_product_attention instead of inline jax.nn.dot_product_attention
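
For illustration, a call site in the models might now look roughly like this (the import path follows the tx/models/attention.py location named above; the actual arguments in llama3.py and qwen3.py may differ):

import jax.numpy as jnp
from tx.models.attention import dot_product_attention  # shared helper added by this PR

B, S, H, D = 2, 16, 4, 64
q = k = v = jnp.zeros((B, S, H, D))
attention_mask = jnp.ones((B, S), dtype=jnp.float32)

# The models previously called jax.nn.dot_product_attention inline; they now
# route through the shared helper, which picks the cuDNN or mask-based path.
attn_out = dot_product_attention(q, k, v, attention_mask, is_causal=True, head_dim=D)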

Tests

  • CPU tests: _shift_sequences correctness, basic attention, GQA
  • GPU tests: Numerical equivalence between cuDNN and mask-based paths

CI

  • Add gpu_skyrl_tx.yaml workflow using Anyscale
  • Update cpu_skyrl_tx.yaml to exclude GPU tests

Test plan

  • Tested training/inference with longer sequences.

raulchen and others added 6 commits January 14, 2026 11:54
- Use seq_lengths instead of attention_mask for attention computation
- On GPU: use cuDNN flash attention with query_seq_lengths/key_value_seq_lengths
- On CPU/TPU: fall back to mask-based attention (construct mask from seq_lengths)
- cuDNN flash attention provides O(seq) memory vs O(seq²) for standard attention

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Extract shared attention logic to tx/models/attention.py
- Use cuDNN flash attention only for right-padded sequences on GPU
- Fall back to mask-based attention for left-padded sequences (generation) or on CPU/TPU
- Fixes a generation bug where cuDNN received the wrong valid positions

Shift left-padded sequences to right-padded before applying cuDNN flash attention, then shift the output back. This enables O(S²) -> O(S) memory savings for inference prefill while keeping mask-based attention for decode (where flash attention provides minimal benefit).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Use argmax to find first valid token position (0 for right-padded, >0 for left-padded)
- Always apply shift (no-op when shift=0), avoiding dual-branch compilation (see the sketch after this commit)
- Document that attention_mask must have at least one valid token per batch

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
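
For illustration, the argmax-based shift described in the commit above might look roughly like the following (a minimal sketch with hypothetical wrapper names; the real _shift_sequences in tx/models/attention.py may differ):

import jax
import jax.numpy as jnp

def _shift_sequences(x, shift):
    # Roll each batch row left by its own shift amount; a no-op where shift == 0.
    # x: (B, T, ...), shift: (B,) integer offsets
    return jax.vmap(lambda row, s: jnp.roll(row, -s, axis=0))(x, shift)

def to_right_padded(q, k, v, attention_mask):
    # Hypothetical wrapper, not from the PR. argmax over the 0/1 mask returns the
    # index of the first valid token: 0 for right-padded rows, > 0 for left-padded
    # ones, so the shift is a no-op when the batch is already right-padded.
    # Requires at least one valid token per batch row.
    shift = jnp.argmax(attention_mask, axis=1)
    q, k, v = (_shift_sequences(t, shift) for t in (q, k, v))
    seq_lengths = attention_mask.sum(axis=1).astype(jnp.int32)
    # The same shift (with opposite sign) restores the original left-padded
    # layout of the attention output after the cuDNN call.
    return q, k, v, seq_lengths, shift

Because the shift is always applied, right- and left-padded batches share a single compiled path, which is the point of the "no-op when shift=0" note above.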
- Add CPU tests for _shift_sequences and basic attention (tests/models/)
- Add GPU tests for cuDNN vs mask-based numerical equivalence (tests/gpu/)
- Add gpu_skyrl_tx.yaml workflow using Anyscale for GPU testing
- Update cpu_skyrl_tx.yaml to exclude tests/gpu/

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces an excellent optimization by adding support for cuDNN flash attention, which will significantly reduce memory usage during training and prefill on GPUs. The implementation is robust, automatically handling both left- and right-padded sequences by cleverly converting them to the right-padded format required by cuDNN. The fallback to a standard mask-based attention for the decode phase and non-GPU backends is a sound design choice. The new functionality is accompanied by a comprehensive set of unit tests for both CPU and GPU, ensuring numerical correctness across different scenarios. The addition of a dedicated GPU CI workflow is also a great step towards maintaining the stability of this feature. I have one suggestion to further improve the test coverage for an edge case.

Comment on lines +84 to +99
B, T, H, D = 4, 128, 4, 64
q = jax.random.normal(jax.random.key(0), (B, T, H, D))
k = jax.random.normal(jax.random.key(1), (B, T, H, D))
v = jax.random.normal(jax.random.key(2), (B, T, H, D))

# Left-padded with different lengths
seq_lengths = jnp.array([128, 96, 64, 32])
padding = T - seq_lengths
mask = (jnp.arange(T)[None, :] >= padding[:, None]).astype(jnp.float32)

result = dot_product_attention(q, k, v, mask, is_causal=True, head_dim=D)
expected = mask_based_attention(q, k, v, mask, is_causal=True, head_dim=D)

for b in range(B):
    pad_len = int(padding[b])
    assert jnp.allclose(result[b, pad_len:], expected[b, pad_len:], atol=1e-5), f"Mismatch at batch {b}"
Severity: medium

This test for mixed sequence lengths is great. To make it even more robust, I suggest including a zero-length (fully padded) sequence in the batch. This ensures that the attention mechanism correctly handles this edge case, which can occur in practice.

The current assertion only checks the valid parts of the sequence. For a zero-length sequence, this is an empty set, so the test would pass vacuously. Adding an explicit check that the output for the fully padded sequence is all zeros would be a valuable addition.

Suggested change (replacing the snippet quoted above):

B, T, H, D = 5, 128, 4, 64
q = jax.random.normal(jax.random.key(0), (B, T, H, D))
k = jax.random.normal(jax.random.key(1), (B, T, H, D))
v = jax.random.normal(jax.random.key(2), (B, T, H, D))

# Left-padded with different lengths, including a zero-length sequence
seq_lengths = jnp.array([128, 96, 64, 32, 0])
padding = T - seq_lengths
mask = (jnp.arange(T)[None, :] >= padding[:, None]).astype(jnp.float32)

result = dot_product_attention(q, k, v, mask, is_causal=True, head_dim=D)
expected = mask_based_attention(q, k, v, mask, is_causal=True, head_dim=D)

for b in range(B):
    pad_len = int(padding[b])
    valid_len = int(seq_lengths[b])
    assert jnp.allclose(result[b, pad_len:], expected[b, pad_len:], atol=1e-5), f"Mismatch at batch {b}"
    if valid_len == 0:
        assert jnp.allclose(result[b], 0, atol=1e-6), f"Output for fully padded sequence at batch {b} should be zero."

@pcmoritz added the tx label on Jan 15, 2026