From c3e7fb505048bde7e8a5bf63d51373422633f215 Mon Sep 17 00:00:00 2001 From: LeeWant <785498771@qq.com> Date: Thu, 2 Apr 2026 15:22:33 +0800 Subject: [PATCH] optimize FlashAttention kernel --- src/myvllm/layers/attention.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/myvllm/layers/attention.py b/src/myvllm/layers/attention.py index a50eb38..38ba2f8 100644 --- a/src/myvllm/layers/attention.py +++ b/src/myvllm/layers/attention.py @@ -156,11 +156,9 @@ def flash_attention_varlen_kernel( m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - 1e10 acc = tl.zeros([BLOCK_M, head_dim], dtype=tl.float32) - # Number of blocks to process - num_blocks = tl.cdiv(seq_len, BLOCK_N) # Loop over K, V blocks - for block_n in range(num_blocks): + for block_n in range(start_m + 1): start_n = block_n * BLOCK_N offs_n = start_n + tl.arange(0, BLOCK_N)