diff --git a/src/myvllm/layers/attention.py b/src/myvllm/layers/attention.py
index a50eb38..38ba2f8 100644
--- a/src/myvllm/layers/attention.py
+++ b/src/myvllm/layers/attention.py
@@ -156,11 +156,9 @@ def flash_attention_varlen_kernel(
     m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - 1e10
     acc = tl.zeros([BLOCK_M, head_dim], dtype=tl.float32)

-    # Number of blocks to process
-    num_blocks = tl.cdiv(seq_len, BLOCK_N)
     # Loop over K, V blocks
-    for block_n in range(num_blocks):
+    for block_n in range(start_m + 1):
         start_n = block_n * BLOCK_N
         offs_n = start_n + tl.arange(0, BLOCK_N)
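For context on why the tighter loop bound is sufficient: under a causal mask, the query block at index `start_m` (rows `start_m * BLOCK_M` through `(start_m + 1) * BLOCK_M - 1`) can only attend to keys at positions less than or equal to its own, so when `BLOCK_M == BLOCK_N` the key/value blocks `0..start_m` are exactly the ones that can contribute. A minimal pure-Python sketch of that arithmetic, assuming equal block sizes (the checking harness is illustrative, not part of the kernel):

```python
import math

def causal_blocks_needed(start_m: int, BLOCK_M: int, BLOCK_N: int) -> int:
    """Number of K/V blocks a causal query block actually needs.

    The last query row in block `start_m` sits at position
    (start_m + 1) * BLOCK_M - 1, so every key block touching a
    position up to there (inclusive) can contribute.
    """
    last_query_pos = (start_m + 1) * BLOCK_M - 1
    return math.ceil((last_query_pos + 1) / BLOCK_N)

# When BLOCK_M == BLOCK_N, the bound reduces to start_m + 1,
# matching the new `range(start_m + 1)` loop in the diff.
BLOCK = 64
for start_m in range(8):
    assert causal_blocks_needed(start_m, BLOCK, BLOCK) == start_m + 1
print("range(start_m + 1) covers exactly the causal K/V blocks")
```

The payoff of this change is that fully-masked K/V blocks above the diagonal are skipped outright rather than computed and discarded, roughly halving the work of the old `tl.cdiv(seq_len, BLOCK_N)` loop for long sequences. The per-element causal mask on the diagonal block (and any `offs_n < seq_len` boundary mask) presumably still applies inside the loop body, which the shown hunk ends before.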