From f00a8c264532f47d47f0cb0a929c97113a102d0e Mon Sep 17 00:00:00 2001 From: Bruce Changlong Xu Date: Sun, 1 Mar 2026 00:43:48 -0700 Subject: [PATCH] Fix block allocation for multi-token decode (speculative decoding) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit can_append and may_append assumed at most 1 new token per decode step. With speculative decoding (mtp_k > 0), the scheduler generates mtp_k + 1 tokens per step, which can cross multiple block boundaries. The old code under-allocated blocks, leading to out-of-bounds KV cache writes. can_append: - Old: boolean expression (len(seq) % block_size == 1) checked for 0 or 1 free blocks regardless of how many tokens are about to be generated. - New: accepts num_new_tokens, computes exact block deficit for the upcoming tokens. may_append: - Old: needed_blocks = ceil(seq_len / block_size) — only accounted for current sequence length, not the tokens about to be generated. Also, the elif branch at seq_len % block_size == 0 only updated the hash but did not allocate new blocks for upcoming tokens. - New: two-phase approach — (1) register hash if the last block just became full, (2) allocate blocks for ceil((seq_len + num_new_tokens) / block_size). Scheduler: - Hoisted num_new_tokens = mtp_k + 1 before the can_append loop so the check uses the correct token count. - Initialized num_rejected = 0 before the speculative decoding branch in postprocess to fix an UnboundLocalError on the non-speculative path. Tests: - Added 8 new can_append tests covering block boundaries, multi-token allocation, exact-fit, and insufficient-block scenarios. - Added 5 new may_append tests for multi-token allocation, boundary hash registration, and block_size=1 with multiple tokens. - Added TestPrefixCachingDecode class (3 tests): hash registration during decode, cache reuse by new sequences, and multi-step prefix building across decode iterations. 
- Fixed ScheduledBatchOutput constructor calls in test_scheduler.py to include num_rejected and num_bonus parameters. --- atom/model_engine/block_manager.py | 72 +++++------ atom/model_engine/scheduler.py | 5 +- tests/test_block_manager.py | 195 ++++++++++++++++++++++++++++- tests/test_scheduler.py | 12 +- 4 files changed, 239 insertions(+), 45 deletions(-) diff --git a/atom/model_engine/block_manager.py b/atom/model_engine/block_manager.py index df6ebd2c3..4cab2c10f 100644 --- a/atom/model_engine/block_manager.py +++ b/atom/model_engine/block_manager.py @@ -122,45 +122,21 @@ def deallocate(self, seq: Sequence): self._deallocate_block(block_id) seq.mamba_block_table.clear() - def can_append(self, seq: Sequence) -> bool: - return len(self.free_block_ids) >= (len(seq) % self.block_size == 1) + def can_append(self, seq: Sequence, num_new_tokens: int = 1) -> bool: + needed_blocks = ( + (len(seq) + num_new_tokens + self.block_size - 1) // self.block_size + ) + blocks_to_allocate = needed_blocks - len(seq.block_table) + return len(self.free_block_ids) >= max(0, blocks_to_allocate) def may_append(self, seq: Sequence, num_new_tokens: int = 1): block_table = seq.block_table last_block = self.blocks[block_table[-1]] seq_len = len(seq) - # Check if we need to allocate a new block - # When len(seq) % block_size == 1, we need a new block for the next token - # When block_size == 1, every token needs a new block - if 0 < seq_len % self.block_size <= num_new_tokens or self.block_size == 1: - needed_blocks = (seq_len + self.block_size - 1) // self.block_size - while len(block_table) < needed_blocks: - # For block_size == 1, we need to update hash for each new block - # For block_size > 1, the previous block should have hash != -1 (unless it's the first block) - if self.block_size == 1: - # Allocate new block and update hash immediately (like allocate does for full blocks) - block_id = self.free_block_ids[0] - block = self._allocate_block(block_id) - block_table.append(block_id) - 
token_ids = [seq[-1]] - prefix = ( - self.blocks[block_table[-2]].hash - if len(block_table) > 1 - else -1 - ) - h = self.compute_hash(token_ids, prefix) - block.update(h, token_ids) - self.hash_to_block_id[h] = block_id - else: - # For block_size > 1, we only allocate new block when needed - # The hash will be updated when the block becomes full - block_id = self.free_block_ids[0] - block = self._allocate_block(block_id) - block_table.append(block_id) - last_block = block - elif seq_len % self.block_size == 0: - # Last block is now full, update its hash (similar to allocate) - # TODO: fix hash + + # Phase 1: If the last block just became full, register its hash + # so it can be reused for prefix caching on future sequences. + if seq_len % self.block_size == 0 and self.block_size > 1: token_ids = seq.block(seq.num_blocks - 1) if len(token_ids) == self.block_size: prefix = ( @@ -169,8 +145,26 @@ def may_append(self, seq: Sequence, num_new_tokens: int = 1): h = self.compute_hash(token_ids, prefix) last_block.update(h, token_ids) self.hash_to_block_id[h] = last_block.block_id - else: - pass - # Last block is not full and not at the boundary - # Hash remains -1 until block is full (consistent with allocate logic) - # assert last_block.hash == -1, last_block.block_id + + # Phase 2: Allocate new blocks for the upcoming tokens. 
+ needed_blocks = ( + (seq_len + num_new_tokens + self.block_size - 1) // self.block_size + ) + while len(block_table) < needed_blocks: + if self.block_size == 1: + block_id = self.free_block_ids[0] + block = self._allocate_block(block_id) + block_table.append(block_id) + token_ids = [seq[-1]] + prefix = ( + self.blocks[block_table[-2]].hash + if len(block_table) > 1 + else -1 + ) + h = self.compute_hash(token_ids, prefix) + block.update(h, token_ids) + self.hash_to_block_id[h] = block_id + else: + block_id = self.free_block_ids[0] + block = self._allocate_block(block_id) + block_table.append(block_id) diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index 46405682f..97b37e2e1 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -307,9 +307,10 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: # decode num_seqs_decode = 0 + num_new_tokens = self.mtp_k + 1 while self.running and num_seqs_decode < self.max_num_seqs: seq = self.running.popleft() - while not self.block_manager.can_append(seq): + while not self.block_manager.can_append(seq, num_new_tokens): if self.running: self.preempt(self.running.pop()) else: @@ -319,7 +320,6 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: if seq.spec_token_ids.size > 0: scheduled_spec_decode_tokens[seq.id] = seq.spec_token_ids num_seqs_decode += 1 - num_new_tokens = self.mtp_k + 1 self.block_manager.may_append(seq, num_new_tokens) scheduled_seqs[seq.id] = seq seq.type = SequenceType.DECODE @@ -382,6 +382,7 @@ def postprocess( if self.spec_stats: self.spec_stats.update(num_new_token) idx = fwd_output.req_ids.index(seq.id) + num_rejected = 0 if is_deferred_out or self.use_spec: num_rejected = fwd_output.num_rejected[idx] num_bonus = fwd_output.num_bonus[idx] diff --git a/tests/test_block_manager.py b/tests/test_block_manager.py index 9dfb5d484..5732a20d7 100644 --- a/tests/test_block_manager.py +++ b/tests/test_block_manager.py @@ -149,12 
+149,85 @@ def test_cannot_append_no_free(self, seq_factory): seq.append_token(5) assert not bm.can_append(seq) + def test_at_block_boundary_needs_block(self, seq_factory): + """seq_len=4, block_size=4 → at boundary, 1 new token needs 1 new block.""" + cfg = MockConfig(num_kvcache_blocks=2, kv_cache_block_size=4) + bm = BlockManager(cfg) + seq = seq_factory([1, 2, 3, 4]) + bm.allocate(seq) + assert bm.can_append(seq, num_new_tokens=1) + + def test_at_block_boundary_no_free(self, seq_factory): + """seq_len=4, block_size=4, 0 free blocks → cannot append.""" + cfg = MockConfig(num_kvcache_blocks=1, kv_cache_block_size=4) + bm = BlockManager(cfg) + seq = seq_factory([1, 2, 3, 4]) + bm.allocate(seq) + assert not bm.can_append(seq, num_new_tokens=1) + + def test_multi_token_needs_two_blocks(self, seq_factory): + """seq_len=7, block_size=4, num_new_tokens=4 → total 11, needs 3 blocks, + 2 allocated, need 1 more. With only 1 free block, should succeed.""" + cfg = MockConfig(num_kvcache_blocks=3, kv_cache_block_size=4) + bm = BlockManager(cfg) + seq = seq_factory([1, 2, 3, 4, 5, 6, 7]) + bm.allocate(seq) + assert len(seq.block_table) == 2 + assert bm.can_append(seq, num_new_tokens=4) + + def test_multi_token_not_enough_free(self, seq_factory): + """seq_len=5, block_size=4, num_new_tokens=4 → total 9, needs 3 blocks, + 2 allocated, need 1 more. With 0 free blocks, should fail.""" + cfg = MockConfig(num_kvcache_blocks=2, kv_cache_block_size=4) + bm = BlockManager(cfg) + seq = seq_factory([1, 2, 3, 4, 5]) + bm.allocate(seq) + assert len(seq.block_table) == 2 + assert not bm.can_append(seq, num_new_tokens=4) + + def test_multi_token_enough_free(self, seq_factory): + """seq_len=7, block_size=4, num_new_tokens=4 → needs 1 more block. 
+ With enough free blocks, should succeed.""" + cfg = MockConfig(num_kvcache_blocks=10, kv_cache_block_size=4) + bm = BlockManager(cfg) + seq = seq_factory([1, 2, 3, 4, 5, 6, 7]) + bm.allocate(seq) + assert bm.can_append(seq, num_new_tokens=4) + + def test_multi_token_crosses_two_boundaries(self, seq_factory): + """seq_len=5, block_size=4, num_new_tokens=4 → total 9, needs 3 blocks, + but only 2 allocated. Need 1 more free block.""" + cfg = MockConfig(num_kvcache_blocks=10, kv_cache_block_size=4) + bm = BlockManager(cfg) + seq = seq_factory([1, 2, 3, 4, 5]) + bm.allocate(seq) + assert len(seq.block_table) == 2 + assert bm.can_append(seq, num_new_tokens=4) + + def test_multi_token_exact_fit(self, seq_factory): + """seq_len=4, block_size=4, num_new_tokens=4 → total 8, needs 2 blocks. + With exactly 1 free block, should succeed.""" + cfg = MockConfig(num_kvcache_blocks=2, kv_cache_block_size=4) + bm = BlockManager(cfg) + seq = seq_factory([1, 2, 3, 4]) + bm.allocate(seq) + assert bm.can_append(seq, num_new_tokens=4) + + def test_multi_token_one_short(self, seq_factory): + """seq_len=4, block_size=4, num_new_tokens=5 → total 9, needs 3 blocks. + With only 1 free block, should fail.""" + cfg = MockConfig(num_kvcache_blocks=2, kv_cache_block_size=4) + bm = BlockManager(cfg) + seq = seq_factory([1, 2, 3, 4]) + bm.allocate(seq) + assert not bm.can_append(seq, num_new_tokens=5) + class TestMayAppend: def test_no_new_block_within_boundary(self, block_manager, seq_factory): - seq = seq_factory([1, 2, 3]) + seq = seq_factory([1, 2]) block_manager.allocate(seq) - seq.append_token(4) + seq.append_token(3) block_manager.may_append(seq) assert len(seq.block_table) == 1 @@ -166,10 +239,128 @@ def test_new_block_on_boundary_crossing(self, block_manager, seq_factory): assert len(seq.block_table) == 2 def test_block_size_1(self, seq_factory): + """block_size=1: seq=[1,2] → 2 blocks. append(3) → seq_len=3. 
+ may_append(num_new_tokens=1) → needs ceil((3+1)/1) = 4 blocks.""" cfg = MockConfig(num_kvcache_blocks=10, kv_cache_block_size=1) bm = BlockManager(cfg) seq = seq_factory([1, 2], block_size=1) bm.allocate(seq) seq.append_token(3) bm.may_append(seq) + assert len(seq.block_table) == 4 + + def test_multi_token_allocates_enough_blocks(self, seq_factory): + """seq_len=5, block_size=4, num_new_tokens=4 → total 9, needs 3 blocks.""" + cfg = MockConfig(num_kvcache_blocks=10, kv_cache_block_size=4) + bm = BlockManager(cfg) + seq = seq_factory([1, 2, 3, 4, 5]) + bm.allocate(seq) + assert len(seq.block_table) == 2 + bm.may_append(seq, num_new_tokens=4) + assert len(seq.block_table) == 3 + + def test_multi_token_at_boundary(self, seq_factory): + """seq_len=4, block_size=4, num_new_tokens=4 → total 8, needs 2 blocks.""" + cfg = MockConfig(num_kvcache_blocks=10, kv_cache_block_size=4) + bm = BlockManager(cfg) + seq = seq_factory([1, 2, 3, 4]) + bm.allocate(seq) + assert len(seq.block_table) == 1 + bm.may_append(seq, num_new_tokens=4) + assert len(seq.block_table) == 2 + + def test_multi_token_crosses_two_boundaries(self, seq_factory): + """seq_len=4, block_size=4, num_new_tokens=5 → total 9, needs 3 blocks.""" + cfg = MockConfig(num_kvcache_blocks=10, kv_cache_block_size=4) + bm = BlockManager(cfg) + seq = seq_factory([1, 2, 3, 4]) + bm.allocate(seq) + assert len(seq.block_table) == 1 + bm.may_append(seq, num_new_tokens=5) assert len(seq.block_table) == 3 + + def test_hash_registered_at_boundary(self, seq_factory): + """When seq fills a block exactly, may_append should register its hash.""" + cfg = MockConfig( + num_kvcache_blocks=10, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + seq = seq_factory([1, 2, 3]) + bm.allocate(seq) + seq.append_token(4) + bm.may_append(seq, num_new_tokens=1) + last_block = bm.blocks[seq.block_table[0]] + assert last_block.hash != -1 + assert last_block.hash in bm.hash_to_block_id + + def 
test_block_size_1_multi_token(self, seq_factory): + """block_size=1: seq=[1,2] → 2 blocks. append(3) → seq_len=3. + may_append(num_new_tokens=3) → needs ceil((3+3)/1) = 6 blocks.""" + cfg = MockConfig(num_kvcache_blocks=10, kv_cache_block_size=1) + bm = BlockManager(cfg) + seq = seq_factory([1, 2], block_size=1) + bm.allocate(seq) + assert len(seq.block_table) == 2 + seq.append_token(3) + bm.may_append(seq, num_new_tokens=3) + assert len(seq.block_table) == 6 + + +# ── Prefix caching during decode ────────────────────────────────────────── + + +class TestPrefixCachingDecode: + def test_hash_registered_during_decode(self, seq_factory): + """Block completed during decode should register its hash for reuse.""" + cfg = MockConfig( + num_kvcache_blocks=10, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + seq = seq_factory([1, 2, 3]) + bm.allocate(seq) + seq.append_token(4) + bm.may_append(seq, num_new_tokens=1) + + block = bm.blocks[seq.block_table[0]] + expected_hash = BlockManager.compute_hash([1, 2, 3, 4]) + assert block.hash == expected_hash + assert bm.hash_to_block_id[expected_hash] == block.block_id + + def test_decode_block_reused_by_new_sequence(self, seq_factory): + """A block completed and hashed during decode should be a cache hit + for a new sequence with the same prefix.""" + cfg = MockConfig( + num_kvcache_blocks=10, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + + s1 = seq_factory([1, 2, 3]) + bm.allocate(s1) + s1.append_token(4) + bm.may_append(s1, num_new_tokens=1) + bm.deallocate(s1) + + s2 = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) + bm.allocate(s2) + assert s2.num_cached_tokens == 4 + + def test_multi_step_decode_builds_prefix(self, seq_factory): + """Simulate multiple decode steps filling blocks, then verify + a new sequence gets cache hits on the completed blocks.""" + cfg = MockConfig( + num_kvcache_blocks=10, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = 
BlockManager(cfg) + + seq = seq_factory([1, 2, 3, 4]) + bm.allocate(seq) + + for tok in [5, 6, 7, 8]: + seq.append_token(tok) + bm.may_append(seq, num_new_tokens=1) + + bm.deallocate(seq) + + s2 = seq_factory([1, 2, 3, 4, 5, 6, 7, 8, 9]) + bm.allocate(s2) + assert s2.num_cached_tokens == 8 diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 324c10a9c..55129ca24 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -121,7 +121,10 @@ def _prefill(self, scheduler, seq): def _output(self, seq_id, tokens): return ScheduledBatchOutput( - token_ids={seq_id: tuple(tokens)}, draft_token_ids=None + token_ids={seq_id: tuple(tokens)}, + num_rejected=None, + num_bonus=None, + draft_token_ids=None, ) def test_appends_token(self, scheduler, seq_factory): @@ -166,7 +169,12 @@ def test_stop_token_ids(self, seq_factory): sched.schedule() finished = sched.postprocess( list(sched.running), - ScheduledBatchOutput(token_ids={seq.id: (99,)}, draft_token_ids=None), + ScheduledBatchOutput( + token_ids={seq.id: (99,)}, + num_rejected=None, + num_bonus=None, + draft_token_ids=None, + ), ) assert len(finished) == 1 assert "stop_99" in finished[0].leave_reason