Skip to content

optimize FlashAttention kernel#74

Open
LeeWant wants to merge 1 commit intoWenyueh:mainfrom
LeeWant:flashattention
Open

optimize FlashAttention kernel#74
LeeWant wants to merge 1 commit intoWenyueh:mainfrom
LeeWant:flashattention

Conversation

@LeeWant
Copy link
Copy Markdown
Collaborator

@LeeWant LeeWant commented Apr 2, 2026

I noticed that the FlashAttention kernel performs some redundant calculations when computing qk. I have modified this and provided testing code to support it (see attention.py).

The original KV block traversal logic is shown below. Each Q block traverses all KV blocks, so with variable-length sequences a large difference between sequence lengths produces a large amount of redundant work. In the prefill stage, a given Q block only needs to visit its own KV block and the preceding ones; because of the causal mask, all later KV blocks contribute nothing.

    # Number of blocks to process
    num_blocks = tl.cdiv(seq_len, BLOCK_N)
    
    # Loop over K, V blocks
    for block_n in range(num_blocks):
        start_n = block_n * BLOCK_N
        offs_n = start_n + tl.arange(0, BLOCK_N)

optimized:

    # Loop over K, V blocks
    for block_n in range(start_m + 1):
        start_n = block_n * BLOCK_N
        offs_n = start_n + tl.arange(0, BLOCK_N)

I conducted performance testing on this, and the performance testing code is shown below.

import torch
import time
import triton 
import triton.language as tl



# ============================================================================
#  Flash Attention (O(N) memory)
# ============================================================================
@triton.jit
def flash_attention_kernel(
    Q, K, V, O,
    cu_seqlens_q_ptr,
    scale,
    is_performance,
    num_heads: tl.constexpr,
    num_kv_heads: tl.constexpr,
    head_dim: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Causal varlen Flash Attention forward pass via online softmax (O(N) memory).

    Grid: (num_q_blocks, num_heads, num_seqs) — each program computes one
    BLOCK_M slice of query rows for one head of one sequence.

    Args:
        Q, K, V, O: packed (total_tokens, heads, head_dim) tensors.
        cu_seqlens_q_ptr: cumulative sequence boundaries, shape (num_seqs+1,),
            starting at 0; sequence i occupies tokens [cu[i], cu[i+1]).
        scale: softmax scaling, typically 1/sqrt(head_dim).
        is_performance: when true, skip KV blocks lying entirely beyond the
            causal diagonal of this Q block (they are fully masked anyway).
    """
    start_m = tl.program_id(0)
    off_h = tl.program_id(1)
    seq_idx = tl.program_id(2)

    # GQA: several query heads share one KV head.
    kv_head_idx = off_h // (num_heads // num_kv_heads)

    seq_start = tl.load(cu_seqlens_q_ptr + seq_idx)
    seq_end = tl.load(cu_seqlens_q_ptr + seq_idx + 1)
    seq_len = seq_end - seq_start

    # Grid is sized for the longest sequence; shorter sequences exit early.
    if start_m * BLOCK_M >= seq_len:
        return

    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, head_dim)

    q_ptrs = Q + (seq_start + offs_m[:, None]) * num_heads * head_dim + off_h * head_dim + offs_d[None, :]
    mask_m = offs_m < seq_len
    q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0)

    # Online-softmax running state: row sum, row max, output accumulator.
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - 1e10

    acc = tl.zeros([BLOCK_M, head_dim], dtype=tl.float32)

    if is_performance:
        # Causal attention: KV positions past this Q block's last row are
        # fully masked, so stop after the block containing that row.  This
        # general formula is correct for any BLOCK_M / BLOCK_N ratio (the
        # original `start_m + 1` was only valid when BLOCK_M == BLOCK_N).
        max_m = tl.minimum((start_m + 1) * BLOCK_M, seq_len)
        num_blocks = tl.cdiv(max_m, BLOCK_N)
    else:
        num_blocks = tl.cdiv(seq_len, BLOCK_N)

    for block_n in range(num_blocks):
        start_n = block_n * BLOCK_N
        offs_n = start_n + tl.arange(0, BLOCK_N)
        mask_n = offs_n < seq_len

        k_ptrs = K + (seq_start + offs_n[None, :]) * num_kv_heads * head_dim + kv_head_idx * head_dim + offs_d[:, None]
        k = tl.load(k_ptrs, mask=mask_n[None, :], other=0.0)

        qk = tl.dot(q, k) * scale

        # Causal mask within the sequence (seq_start cancels on both sides,
        # so compare the in-sequence offsets directly).
        mask_causal = offs_m[:, None] >= offs_n[None, :]
        qk = tl.where(mask_causal & mask_n[None, :], qk, -1e10)

        # Online softmax update: rescale running stats by the new row max.
        m_ij = tl.max(qk, axis=1)
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        p = tl.exp(qk - m_i_new[:, None])

        acc = acc * alpha[:, None]

        v_ptrs = V + (seq_start + offs_n[:, None]) * num_kv_heads * head_dim + kv_head_idx * head_dim + offs_d[None, :]
        v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0)

        acc = acc + tl.dot(p.to(v.dtype), v)

        l_i = l_i * alpha + tl.sum(p, axis=1)
        m_i = m_i_new

    # Final normalization by the softmax denominator.
    acc = acc / l_i[:, None]

    o_ptrs = O + (seq_start + offs_m[:, None]) * num_heads * head_dim + off_h * head_dim + offs_d[None, :]
    tl.store(o_ptrs, acc.to(O.dtype.element_ty), mask=mask_m[:, None])


def flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    scale: float,
    is_performance: bool,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
) -> torch.Tensor:
    """Launch the varlen Flash Attention kernel and return the output tensor.

    Args:
        q, k, v: packed (total_tokens, heads, head_dim) tensors.
        cu_seqlens: cumulative sequence boundaries, shape (num_seqs + 1,).
        scale: softmax scaling factor.
        is_performance: forwarded to the kernel's causal block-skip switch.
        num_heads / num_kv_heads / head_dim: attention geometry.

    Returns:
        Output tensor with the same shape and dtype as ``q``.
    """
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
    output = torch.empty_like(q)

    # Smaller tiles for larger head dims keep per-program resources bounded.
    if head_dim <= 64:
        tile = 64
    elif head_dim <= 128:
        tile = 32
    else:
        tile = 16
    BLOCK_M = BLOCK_N = tile

    num_seqs = cu_seqlens.size(0) - 1
    bounds = cu_seqlens.cpu()
    longest = (bounds[1:] - bounds[:-1]).max().item()

    grid = (triton.cdiv(longest, BLOCK_M), num_heads, num_seqs)
    flash_attention_kernel[grid](
        q, k, v, output,
        cu_seqlens,
        scale,
        is_performance,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
    )

    return output


def setup_data(seq_lens, num_heads, num_kv_heads, head_dim, device='cuda'):
    """Build random Q/K/V tensors and cumulative sequence boundaries.

    Args:
        seq_lens: per-sequence token counts.
        num_heads / num_kv_heads: query / key-value head counts.
        head_dim: per-head dimension.
        device: torch device for all tensors (default 'cuda', as before).

    Returns:
        (q, k, v, cu_seq_lens, scale).  ``cu_seq_lens`` is the cumulative
        prefix [0, s0, s0+s1, ...] that the kernel indexes with
        (seq_start = cu[i], seq_end = cu[i+1]).  The original code passed
        the RAW lengths here, so every sequence boundary read by the
        kernel was wrong.
    """
    # Compute total number of tokens across all sequences
    total_tokens = sum(seq_lens)

    # Initialize random Q/K/V tensors in float32
    q = torch.randn(total_tokens, num_heads, head_dim, device=device, dtype=torch.float32)
    k = torch.randn(total_tokens, num_kv_heads, head_dim, device=device, dtype=torch.float32)
    v = torch.randn(total_tokens, num_kv_heads, head_dim, device=device, dtype=torch.float32)

    # Scaling factor for scaled dot-product attention
    scale = 1.0 / (head_dim ** 0.5)

    # Cumulative boundaries with a leading 0: sequence i spans [cu[i], cu[i+1]).
    lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
    cu_seq_lens = torch.cat([
        torch.zeros(1, dtype=torch.int32, device=device),
        torch.cumsum(lens, dim=0, dtype=torch.int32),
    ])

    return q, k, v, cu_seq_lens, scale


def _time_flash(q, k, v, cu_seqlens, scale, is_performance,
                num_heads, num_kv_heads, head_dim, num_iter, warmup=5):
    """Warm up, then time num_iter kernel calls; return (sec/iter, last output)."""
    for _ in range(warmup):
        flash_attention(q, k, v, cu_seqlens, scale, is_performance,
                        num_heads, num_kv_heads, head_dim)

    torch.cuda.synchronize()
    start = time.perf_counter()
    out = None
    for _ in range(num_iter):
        out = flash_attention(q, k, v, cu_seqlens, scale, is_performance,
                              num_heads, num_kv_heads, head_dim)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / num_iter, out


def benchmark(seq_lens=(15, 90), num_heads=16, num_kv_heads=8, head_dim=128, num_iter=50):
    """Benchmark Flash Attention in performance vs. naive KV-traversal mode.

    Args:
        seq_lens: per-sequence token counts (immutable tuple default — the
            original used a mutable list default argument).
        num_heads / num_kv_heads / head_dim: attention geometry.
        num_iter: timed iterations per mode (plus 5 warm-up calls each).

    Returns:
        dict mapping mode label -> average seconds per iteration.
    """
    print(f"\n{'='*80}")
    print(f"Benchmark: {seq_lens} avg tokens (total: {sum(seq_lens)} tokens)")
    print(f"Heads: {num_heads}/{num_kv_heads}, Dim: {head_dim}")
    print(f"{'='*80}")

    q, k, v, cu_seqlens, scale = setup_data(seq_lens, num_heads, num_kv_heads, head_dim)

    results = {}
    outputs = {}

    # Performance mode: skip fully-masked KV blocks past the causal diagonal.
    print("\n[Performance mode] Flash Attention (O(N), online softmax)...")
    t, outputs['flash_perf'] = _time_flash(q, k, v, cu_seqlens, scale, True,
                                           num_heads, num_kv_heads, head_dim, num_iter)
    results['Flash Attention (O(N)) - Performance'] = t
    print(f"      {t*1000:.3f} ms")

    # Naive mode: every Q block visits every KV block.
    print("\n[Naive mode]       Flash Attention (O(N), online softmax)...")
    t, outputs['flash_naive'] = _time_flash(q, k, v, cu_seqlens, scale, False,
                                            num_heads, num_kv_heads, head_dim, num_iter)
    results['Flash Attention (O(N)) - Naive'] = t
    print(f"      {t*1000:.3f} ms")

    return results
    
def summarize_benchmark(results_list, seq_lens_list, num_heads, num_kv_heads, head_dim):
    """Print a consolidated table of all benchmark runs.

    Args:
        results_list: one {method: seconds} dict per run.
        seq_lens_list: the sequence-length group used for each run,
            positionally aligned with ``results_list``.
        num_heads / num_kv_heads / head_dim: attention geometry (header only).
    """
    banner = "=" * 80
    print("\n" + banner)
    print(f"SUMMARY OF BENCHMARKS")
    print(f"Heads: {num_heads}/{num_kv_heads}, Dim: {head_dim}")
    print(banner)

    for position, run in enumerate(results_list):
        lengths = seq_lens_list[position]
        print(f"Sequence lengths: {lengths} (total tokens: {sum(lengths)})")
        for method, seconds in run.items():
            print(f"  {method:<35} : {seconds*1000:.3f} ms per iteration")
        print("-" * 80)


if __name__ == "__main__":
    print("\n" + "="*80)
    print("PREFILL ATTENTION BENCHMARK")
    # Fixed banner: only the two Flash (O(N)) modes are actually compared;
    # the original text claimed a PyTorch and naive-Triton comparison that
    # this script never runs.
    print("Comparing: Flash (O(N)) performance mode | Flash (O(N)) naive mode")
    print("="*80)

    # Mix of uniform and highly skewed length groups; the skewed groups
    # expose the benefit of skipping post-diagonal KV blocks.
    seq_lens_list = [[64, 128, 256], [128, 256, 512], [1024, 1024, 1024],
                     [64, 78, 2048], [78, 40, 8192]]
    results_list = []

    for seq_lens in seq_lens_list:
        results = benchmark(seq_lens=seq_lens, head_dim=128, num_iter=50)
        results_list.append(results)

    summarize_benchmark(results_list, seq_lens_list, num_heads=16, num_kv_heads=8, head_dim=128)

Compared with the original FlashAttention, multiple test cases with different length sequences were set up. The output result is shown below:

================================================================================
SUMMARY OF BENCHMARKS
Heads: 16/8, Dim: 128
================================================================================
Sequence lengths: [64, 128, 256] (total tokens: 448)
  Flash Attention (O(N)) - Performance : 0.145 ms per iteration
  Flash Attention (O(N)) - Naive      : 0.143 ms per iteration
--------------------------------------------------------------------------------
Sequence lengths: [128, 256, 512] (total tokens: 896)
  Flash Attention (O(N)) - Performance : 0.219 ms per iteration
  Flash Attention (O(N)) - Naive      : 0.332 ms per iteration
--------------------------------------------------------------------------------
Sequence lengths: [1024, 1024, 1024] (total tokens: 3072)
  Flash Attention (O(N)) - Performance : 0.161 ms per iteration
  Flash Attention (O(N)) - Naive      : 0.167 ms per iteration
--------------------------------------------------------------------------------
Sequence lengths: [64, 78, 2048] (total tokens: 2190)
  Flash Attention (O(N)) - Performance : 2.476 ms per iteration
  Flash Attention (O(N)) - Naive      : 3.912 ms per iteration
--------------------------------------------------------------------------------
Sequence lengths: [78, 40, 8192] (total tokens: 8310)
  Flash Attention (O(N)) - Performance : 31.140 ms per iteration
  Flash Attention (O(N)) - Naive      : 56.613 ms per iteration
--------------------------------------------------------------------------------

When the length difference of the input sequence is not significant, the performance is basically the same. The larger the difference in sequence length within the group, the more obvious the performance optimization.

@LeeWant
Copy link
Copy Markdown
Collaborator Author

LeeWant commented Apr 2, 2026

Note that the current solution only applies when BLOCK_M == BLOCK_N. When BLOCK_M and BLOCK_N differ, the following variant can be used instead.

max_m = tl.minimum((start_m + 1) * BLOCK_M, seq_len)
num_kv_blocks = tl.cdiv(max_m, BLOCK_N)

for block_n in range(num_kv_blocks):

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant