Efficient, cross-platform chunked prefill attention for LLM inference with paged KV caching.
This project is a collection of high-performance paged attention implementations optimized for large language model (LLM) inference, supporting:
- ✅ Paged Attention (GPU/Metal-accelerated, memory-efficient key-value caching)
- ✅ Fused MoE: a WMMA-based fused MoE kernel for prefill and a tiling-based one for decoding
- ✅ Chunked Prefill for long sequences (prefill attention with KV cache, enabling context caching and chunked prefill)
- ✅ Cross-platform: CUDA (NVIDIA GPUs: V100, A100, H100, etc.) & Metal (Apple Silicon M1/M2/M3/M4)
- ✅ Flexible API-first design with clean, reusable attention operations
- ✅ Softcapping, ALiBi slopes, sliding window attention, and multi-query/grouped-query attention
- Uses paged memory allocation for key/value caches to avoid padding and reduce memory waste.
- Supports variable-length sequences and dynamic batching.
- Ideal for long-context LLM inference (e.g., 32K+ tokens).
- For the prefill phase with long sequences, splits attention into chunks (e.g., 4096 tokens) to avoid GPU memory limits.
- Avoids O(N²) memory blowup during prefill with dedicated `prefill_paged_attn` kernels on CUDA and Metal.
- Supports `cu_seqlens_q` to track the query lengths of multiple sequences (sub-query lengths per sequence). A plain-Rust sketch of the paged block/slot arithmetic follows this list.
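To make the paging concrete, here is a minimal plain-Rust sketch of the index arithmetic behind block tables and slot mappings; `build_slot_mapping` is an illustrative helper, not part of the attention.rs API:

```rust
// Sketch: how token positions map to paged KV-cache slots.
// `build_slot_mapping` is an illustrative name, not an attention.rs API.
fn build_slot_mapping(block_table: &[usize], block_size: usize, seq_len: usize) -> Vec<usize> {
    (0..seq_len)
        .map(|pos| {
            let block = block_table[pos / block_size]; // physical block for this token
            let offset = pos % block_size;             // position inside the block
            block * block_size + offset                // flat slot index in the cache
        })
        .collect()
}

fn main() {
    // A 10-token sequence stored in two (possibly non-contiguous) blocks of size 8.
    let block_table = [3, 7];
    let slots = build_slot_mapping(&block_table, 8, 10);
    println!("{slots:?}"); // [24..=31] from block 3, then [56, 57] from block 7
}
```

The same arithmetic underlies the `slot_mapping` tensor passed to the kernels in the usage example further below.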
| Backend | Supported? | Notes |
|---|---|---|
| CUDA (NVIDIA) | ✅ | Optimized CUDA kernels, supports both native and flash-attn |
| Metal (Apple M1/M2/M3/M4) | ✅ | Native Metal kernels for Apple Silicon |
- Unified `PagedAttention` API across backends.
- Uses `candle-core`'s `CustomOp` for dispatch to backend-specific kernels.
- No need to write separate code for GPU vs. Metal (see the device-selection sketch below).
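Since kernel dispatch happens inside the `CustomOp`, the only backend-specific choice on the caller's side is which `Device` to construct. A minimal sketch, assuming your crate forwards the `cuda`/`metal` features shown in the installation section below; `pick_device` is an illustrative helper, not an attention.rs API:

```rust
use candle_core::{Device, Result};

// Illustrative helper: pick a Device based on the Cargo features enabled for
// your crate. The attention code itself stays identical across backends.
fn pick_device() -> Result<Device> {
    if cfg!(feature = "cuda") {
        Device::new_cuda(0) // first NVIDIA GPU
    } else if cfg!(feature = "metal") {
        Device::new_metal(0) // Apple Silicon GPU
    } else {
        Ok(Device::Cpu) // CPU fallback for builds without GPU features
    }
}

fn main() -> Result<()> {
    let device = pick_device()?;
    println!("running on {device:?}");
    Ok(())
}
```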
- ✅ Softcapping (for stable attention output)
- ✅ Alibi positional embeddings (for long-range dependencies)
- ✅ Sliding window attention (for efficiency)
- ✅ Multi-Query (MQA) & Grouped-Query Attention (GQA) support (see the sketch after this list)
- ✅ Dynamic block size & block tables
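As a concrete illustration of two of the items above, the plain-Rust sketch below (not attention.rs code) shows the conventional GQA head mapping, where each group of `num_heads / num_kv_heads` query heads shares one KV head, and the usual softcapping transform `softcap * tanh(logit / softcap)` applied to attention logits before softmax:

```rust
// Sketch of the conventional GQA head mapping and softcapping formulas,
// shown as plain arithmetic; not taken from attention.rs itself.
fn main() {
    // Grouped-query attention: several query heads share one KV head.
    let num_heads = 32;   // query heads
    let num_kv_heads = 8; // key/value heads
    let group_size = num_heads / num_kv_heads; // 4 query heads per KV head
    for q_head in 0..num_heads {
        let kv_head = q_head / group_size; // KV head this query head attends with
        assert!(kv_head < num_kv_heads);
    }

    // Softcapping: logits are squashed into (-cap, cap) before softmax,
    // keeping extreme scores from destabilizing the attention output.
    let softcap = 50.0f32;
    let logit = 123.4f32;
    let capped = softcap * (logit / softcap).tanh();
    assert!(capped.abs() < softcap);
    println!("group_size = {group_size}, capped logit = {capped}");
}
```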
[dependencies]
metal = { version = "0.27.0", features = ["mps"], optional = true }
candle-core = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", rev = "42c9b42" }
candle-flash-attn = { git = "https://github.com/guoqingbao/candle.git", version = "0.8.3", optional = true, rev = "42c9b42" }
attention-rs = { git = "https://github.com/guoqingbao/attention.rs", version = "0.1.1", rev = "3058c20" }
# Your project features
[features]
cuda = ["candle-core/cuda", "candle-nn/cuda", "attention-rs/cuda"]
metal = ["candle-core/metal", "candle-nn/metal", "dep:metal", "attention-rs/metal"]
flash-attn = ["cuda", "dep:candle-flash-attn", "attention-rs/flash-attn"]

Note: Enable the `cuda` or `metal` feature based on your target platform (e.g., `cargo build --features cuda` or `cargo build --features metal`).
use candle_core::{DType, Device, Tensor};
use attention_rs::{InputMetadata, PagedAttention};
// Example:
// Step 1: Setup
let device = Device::new_cuda(0).unwrap(); // or Device::new_metal(0) for Apple Silicon
let num_heads = 32;
let head_size = 128;
let num_kv_heads = 8;
let scale = (head_size as f32).sqrt().recip();
let paged_attn = PagedAttention::new(
num_heads,
head_size,
scale,
Some(num_kv_heads),
None,
device.clone(), // clone so `device` can still be used to build tensors below
None,
)?;
// Input tensors (batch_size=1, seq_len=1024)
let query = Tensor::randn(0f32, 1f32, (1, 1024, num_heads * head_size), &device)?; // [1, 1024, 32*128]
let key = Tensor::randn(0f32, 1f32, (1, 1024, num_kv_heads * head_size), &device)?; // [1, 1024, 8*128]
let value = Tensor::randn(0f32, 1f32, (1, 1024, num_kv_heads * head_size), &device)?; // [1, 1024, 8*128]
// Paged KV cache setup (for the non-flash-attn path)
let num_blocks = 1024;
let block_size = 64;
let num_kv_heads = 8;
let key_cache = Tensor::zeros(
    (num_blocks, num_kv_heads, head_size / 8, block_size, 8),
    DType::F32,
    &device,
)?;
let value_cache = Tensor::zeros(
    (num_blocks, num_kv_heads, head_size, block_size),
    DType::F32,
    &device,
)?;
// Slot mapping: the flat cache slot (block * block_size + offset) each token is written to
// (the integer dtypes below are illustrative)
let slot_mapping = Tensor::from_slice(&[0i64, 1, 2, 3, 4, 5, 6, 7], (8,), &device)?;
// Context lengths per sequence
let context_lens = Tensor::from_slice(&[8u32], (1,), &device)?;
// Block tables: [num_sequences, max_num_blocks_per_seq]
let block_tables = Tensor::from_slice(&[0u32, 1], (1, 2), &device)?;
// Metadata for attention
let input_metadata = InputMetadata {
is_prefill: true,
slot_mapping,
block_tables: Some(block_tables),
context_lens: Some(context_lens),
cu_seqlens_q: None, // only needed for chunked prefill
cu_seqlens_k: None,
max_seqlen_q: 1024,
max_seqlen_k: 1024,
max_context_len: 1024,
};
// Step 2: Run attention
let output = paged_attn.forward(
&query,
&key,
&value,
None, // attention_mask
Some(key_cache),
Some(value_cache),
&input_metadata,
None, // softcapping
)?;

For very long sequences (e.g., 32K+ tokens), use `cu_seqlens_q` to split the query tensor into chunks (see vllm.rs for reference usage):
// Extra sequence lengths for the query (cumulative offsets per sub-sequence)
let cu_seqlens_q = Tensor::from_slice(&[0u32, 4096, 8192, 12288], (4,), &device)?; // three sub-sequences of 4096 tokens each (12288 tokens in total)
let input_metadata = InputMetadata {
is_prefill: true,
slot_mapping: slot_mapping,
block_tables: Some(block_tables),
context_lens: Some(context_lens),
cu_seqlens_q: Some(cu_seqlens_q),
cu_seqlens_k: None,
max_seqlen_q: 4096,
max_seqlen_k: 4096,
max_context_len: 12288,
};
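When chunk lengths are not all equal, `cu_seqlens_q` is simply the prefix sum of the per-chunk query lengths. Below is a small sketch of building it, using `u32` offsets as in the example above; `build_cu_seqlens_q` is an illustrative helper, not an attention.rs API:

```rust
use candle_core::{Device, Result, Tensor};

// Illustrative helper: turn per-chunk query lengths into the cumulative
// offsets carried in `cu_seqlens_q`, plus the largest chunk length.
fn build_cu_seqlens_q(chunk_lens: &[u32], device: &Device) -> Result<(Tensor, u32)> {
    let mut offsets = Vec::with_capacity(chunk_lens.len() + 1);
    let mut total = 0u32;
    offsets.push(0u32);
    for &len in chunk_lens {
        total += len;
        offsets.push(total); // cumulative token count after each chunk
    }
    let max_seqlen_q = chunk_lens.iter().copied().max().unwrap_or(0);
    let cu_seqlens_q = Tensor::from_slice(&offsets, (offsets.len(),), device)?;
    Ok((cu_seqlens_q, max_seqlen_q))
}

fn main() -> Result<()> {
    let device = Device::Cpu; // any backend works the same way
    // Three 4096-token chunks reproduce the [0, 4096, 8192, 12288] tensor above.
    let (cu_seqlens_q, max_seqlen_q) = build_cu_seqlens_q(&[4096, 4096, 4096], &device)?;
    println!("{:?} (max_seqlen_q = {max_seqlen_q})", cu_seqlens_q.to_vec1::<u32>()?);
    Ok(())
}
```

Plug the resulting tensor and `max_seqlen_q` into `InputMetadata { cu_seqlens_q: Some(..), max_seqlen_q, .. }` as shown above.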
This project is licensed under the MIT License.

Contributions welcome! Please open an issue or PR on GitHub.
💡 ❤️ Used in vllm.rs and candle-vllm