From 7edb0f66db7ed0588064939df88c95be96937fd7 Mon Sep 17 00:00:00 2001 From: valarLip <340077269@qq.com> Date: Tue, 24 Feb 2026 15:00:02 +0000 Subject: [PATCH] fix: resolve prefix caching crashes with MTP speculative decoding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix GPU memory access fault caused by double conversion of block_tables in cached prefill path. kv_indices_generate_triton applies block_ratio internally, but was receiving already-converted block_tables (via block_tables_converted), causing indices to be multiplied by block_ratio twice (e.g. block_id*256 instead of block_id*16), exceeding KV cache bounds. Key changes: - Use raw block_tables for kv_indices generation in aiter_mla prefill - Route cached prefill through paged MLA attention (supports Q≠K) instead of flash_attn_varlen_func (requires Q==K) - Track has_cached flag through AttentionMetaData for path selection - Fix block_manager: hash table leak, can_allocate cache-hit accounting, can_append for multi-token decode, O(1) free block tracking - Add CacheStats to scheduler for prefix cache hit rate monitoring - Add comprehensive block_manager tests (119 passing) Verified: gsm8k 1319 samples, 95.83% accuracy, 0 GPU faults. --- atom/model_engine/block_manager.py | 95 +++++++------ atom/model_engine/scheduler.py | 73 +++++++++- atom/model_ops/attention_mla.py | 2 +- atom/model_ops/attentions/aiter_mla.py | 20 ++- atom/model_ops/attentions/backends.py | 31 ++++- atom/utils/forward_context.py | 3 + tests/test_block_manager.py | 172 +++++++++++++++++++++++ tests/test_prefix_cache_accuracy.py | 185 +++++++++++++++++++++++++ 8 files changed, 528 insertions(+), 53 deletions(-) create mode 100644 tests/test_prefix_cache_accuracy.py diff --git a/atom/model_engine/block_manager.py b/atom/model_engine/block_manager.py index 937130265..184539764 100644 --- a/atom/model_engine/block_manager.py +++ b/atom/model_engine/block_manager.py @@ -37,6 +37,7 @@ def __init__(self, config: Config): self.blocks: list[Block] = [Block(i) for i in range(num_blocks)] self.hash_to_block_id: dict[int, int] = dict() self.free_block_ids: deque[int] = deque(range(num_blocks)) + self.free_block_ids_set: set[int] = set(range(num_blocks)) self.used_block_ids: set[int] = set() self.enable_prefix_caching = config.enable_prefix_caching @@ -48,11 +49,23 @@ def compute_hash(cls, token_ids: list[int], prefix: int = -1): h.update(np.array(token_ids).tobytes()) return h.intdigest() + def _pop_free_block(self) -> int: + """Pop the next available free block id from the FIFO queue (lazy cleanup).""" + while self.free_block_ids: + block_id = self.free_block_ids.popleft() + if block_id in self.free_block_ids_set: + self.free_block_ids_set.discard(block_id) + return block_id + raise AssertionError("No free blocks available") + def _allocate_block(self, block_id: int) -> Block: block = self.blocks[block_id] assert block.ref_count == 0 + # Evict stale hash entry before resetting + if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id: + del self.hash_to_block_id[block.hash] block.reset() - self.free_block_ids.remove(block_id) + self.free_block_ids_set.discard(block_id) self.used_block_ids.add(block_id) return self.blocks[block_id] @@ -60,10 +73,28 @@ def _deallocate_block(self, block_id: int): assert self.blocks[block_id].ref_count == 0 self.used_block_ids.remove(block_id) self.free_block_ids.append(block_id) - # self.free_block_ids.appendleft(block_id) + self.free_block_ids_set.add(block_id) def can_allocate(self, seq: Sequence) -> bool: - return len(self.free_block_ids) >= seq.num_blocks + if not self.enable_prefix_caching: + return len(self.free_block_ids_set) >= seq.num_blocks + # Dry-run: count how many blocks would be cache hits + h = -1 + cache_miss = False + needed_free = 0 + for i in range(seq.num_blocks): + token_ids = seq.block(i) + h = ( + self.compute_hash(token_ids, h) + if len(token_ids) == self.block_size + else -1 + ) + block_id = self.hash_to_block_id.get(h, -1) + if block_id == -1 or self.blocks[block_id].token_ids != token_ids: + cache_miss = True + if cache_miss: + needed_free += 1 + return len(self.free_block_ids_set) >= needed_free def allocate(self, seq: Sequence): assert not seq.block_table @@ -82,7 +113,7 @@ def allocate(self, seq: Sequence): if block_id == -1 or self.blocks[block_id].token_ids != token_ids: cache_miss = True if cache_miss: - block_id = self.free_block_ids[0] + block_id = self._pop_free_block() block = self._allocate_block(block_id) else: seq.num_cached_tokens += self.block_size @@ -105,12 +136,17 @@ def deallocate(self, seq: Sequence): seq.num_cached_tokens = 0 seq.block_table.clear() - def can_append(self, seq: Sequence) -> bool: - return len(self.free_block_ids) >= (len(seq) % self.block_size == 1) + def can_append(self, seq: Sequence, num_new_tokens: int = 1) -> bool: + seq_len = len(seq) + current_blocks = len(seq.block_table) + needed_blocks = ( + seq_len + num_new_tokens + self.block_size - 1 + ) // self.block_size + new_blocks_needed = max(0, needed_blocks - current_blocks) + return len(self.free_block_ids_set) >= new_blocks_needed def may_append(self, seq: Sequence, num_new_tokens: int = 1): block_table = seq.block_table - last_block = self.blocks[block_table[-1]] seq_len = len(seq) # Check if we need to allocate a new block # When len(seq) % block_size == 1, we need a new block for the next token @@ -118,42 +154,9 @@ def may_append(self, seq: Sequence, num_new_tokens: int = 1): if 0 < seq_len % self.block_size <= num_new_tokens or self.block_size == 1: needed_blocks = (seq_len + self.block_size - 1) // self.block_size while len(block_table) < needed_blocks: - # For block_size == 1, we need to update hash for each new block - # For block_size > 1, the previous block should have hash != -1 (unless it's the first block) - if self.block_size == 1: - # Allocate new block and update hash immediately (like allocate does for full blocks) - block_id = self.free_block_ids[0] - block = self._allocate_block(block_id) - block_table.append(block_id) - token_ids = [seq[-1]] - prefix = ( - self.blocks[block_table[-2]].hash - if len(block_table) > 1 - else -1 - ) - h = self.compute_hash(token_ids, prefix) - block.update(h, token_ids) - self.hash_to_block_id[h] = block_id - else: - # For block_size > 1, we only allocate new block when needed - # The hash will be updated when the block becomes full - block_id = self.free_block_ids[0] - block = self._allocate_block(block_id) - block_table.append(block_id) - last_block = block - elif seq_len % self.block_size == 0: - # Last block is now full, update its hash (similar to allocate) - # TODO: fix hash - token_ids = seq.block(seq.num_blocks - 1) - if len(token_ids) == self.block_size: - prefix = ( - self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1 - ) - h = self.compute_hash(token_ids, prefix) - last_block.update(h, token_ids) - self.hash_to_block_id[h] = last_block.block_id - else: - pass - # Last block is not full and not at the boundary - # Hash remains -1 until block is full (consistent with allocate logic) - # assert last_block.hash == -1, last_block.block_id + # Decode-generated blocks: token not finalized yet (depends on + # sampling / speculative verification), so we cannot compute a + # correct hash here. Just allocate the block without hashing. + block_id = self._pop_free_block() + self._allocate_block(block_id) + block_table.append(block_id) diff --git a/atom/model_engine/scheduler.py b/atom/model_engine/scheduler.py index e2af320b7..af85f2aca 100644 --- a/atom/model_engine/scheduler.py +++ b/atom/model_engine/scheduler.py @@ -106,6 +106,70 @@ def _log(self) -> None: ) +class CacheStats: + """Tracks prefix caching hit statistics.""" + + __slots__ = ( + "_log_interval", + "total_requests", + "total_cached_tokens", + "total_full_tokens", + "_interval_requests", + "_interval_cached_tokens", + "_interval_full_tokens", + ) + + def __init__(self, log_interval: int = 100): + self._log_interval = log_interval + self.total_requests: int = 0 + self.total_cached_tokens: int = 0 + self.total_full_tokens: int = 0 + self._interval_requests: int = 0 + self._interval_cached_tokens: int = 0 + self._interval_full_tokens: int = 0 + + def update(self, num_cached_tokens: int, num_full_tokens: int) -> None: + """Record cache stats for one prefill sequence.""" + self.total_requests += 1 + self.total_cached_tokens += num_cached_tokens + self.total_full_tokens += num_full_tokens + self._interval_requests += 1 + self._interval_cached_tokens += num_cached_tokens + self._interval_full_tokens += num_full_tokens + + if self.total_requests % self._log_interval == 0: + self._log() + self._reset_interval() + + @property + def hit_rate(self) -> float: + if self.total_full_tokens == 0: + return 0.0 + return self.total_cached_tokens / self.total_full_tokens + + def _reset_interval(self) -> None: + self._interval_requests = 0 + self._interval_cached_tokens = 0 + self._interval_full_tokens = 0 + + def _log(self) -> None: + iv_rate = ( + self._interval_cached_tokens / self._interval_full_tokens + if self._interval_full_tokens > 0 + else 0.0 + ) + logger.info( + f"[Cache Stats Interval] Reqs: {self._interval_requests}, " + f"Cached/Total tokens: {self._interval_cached_tokens}/{self._interval_full_tokens}, " + f"Hit rate: {iv_rate:.2%}" + ) + logger.info( + f"[Cache Stats ] Reqs: {self.total_requests}, " + f"Cached/Total tokens: {self.total_cached_tokens}/{self.total_full_tokens}, " + f"Hit rate: {self.hit_rate:.2%}" + ) + + class ScheduledBatch: def __init__( self, @@ -233,6 +297,9 @@ def __init__(self, config: Config): self.spec_stats: Optional[SpecStats] = ( SpecStats(mtp_k=self.mtp_k) if self.use_spec else None ) + self.cache_stats: Optional[CacheStats] = ( + CacheStats() if config.enable_prefix_caching else None + ) def is_finished(self): return not self.waiting and not self.running @@ -270,6 +337,10 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: break num_seqs_prefill += 1 self.block_manager.allocate(seq) + # Recompute after allocate: num_cached_tokens may have increased + num_new_tokens = seq.num_tokens - seq.num_cached_tokens + if self.cache_stats: + self.cache_stats.update(seq.num_cached_tokens, seq.num_tokens) num_batched_tokens += num_new_tokens seq.status = SequenceStatus.RUNNING seq.type = SequenceType.PREFILL @@ -303,7 +374,7 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]: num_seqs_decode = 0 while self.running and num_seqs_decode < self.max_num_seqs: seq = self.running.popleft() - while not self.block_manager.can_append(seq): + while not self.block_manager.can_append(seq, self.mtp_k + 1): if self.running: self.preempt(self.running.pop()) else: diff --git a/atom/model_ops/attention_mla.py b/atom/model_ops/attention_mla.py index c164f63fd..a7de78c4f 100644 --- a/atom/model_ops/attention_mla.py +++ b/atom/model_ops/attention_mla.py @@ -582,7 +582,7 @@ def forward( kv_cache_data = forward_context.kv_cache_data kv_cache = kv_cache_data[f"layer_{self.layer_num}"].k_cache - if context.is_prefill and not use_prefill_mla: + if context.is_prefill and not use_prefill_mla and not attn_metadata.has_cached: prefill_q = self.q_proj(q, x_scale=q_scale).view( -1, self.num_heads, self.qk_head_dim ) diff --git a/atom/model_ops/attentions/aiter_mla.py b/atom/model_ops/attentions/aiter_mla.py index 356bac5e0..225a4928e 100644 --- a/atom/model_ops/attentions/aiter_mla.py +++ b/atom/model_ops/attentions/aiter_mla.py @@ -242,12 +242,30 @@ def prepare_prefill(self, batch: ScheduledBatch): sum_scheduled_tokens + 1 ) - if hasattr(self.model_runner, "drafter"): + if hasattr(self.model_runner, "drafter") or attn_metadata.has_cached: attn_metadata.kv_indices = var["kv_indices"].gpu attn_metadata.kv_indptr = var["kv_indptr"].gpu[: bs + 1] + attn_metadata.kv_indptr[0] = 0 attn_metadata.kv_indptr[1 : bs + 1] = torch.cumsum( attn_metadata.context_lens, 0 ) + attn_metadata.kv_last_page_lens = var["kv_last_page_lens"].gpu[:bs] + + if attn_metadata.has_cached: + # Ensure raw block_tables (not block_tables_converted) are on GPU. + # kv_indices_generate_triton applies block_ratio internally, so we + # must pass raw model-level block_ids to avoid double conversion. + if not batch.block_tables: + self.prepare_block_tables(batch) + raw_block_tables = var["block_tables"].copy_to_gpu(bs) + max_seqlen_k = int(attn_metadata.context_lens.max().item()) + kv_indices_generate_triton( + raw_block_tables, + attn_metadata.kv_indices, + attn_metadata.kv_indptr, + self.block_ratio, + max_seqlen_k, + ) return attn_metadata, positions diff --git a/atom/model_ops/attentions/backends.py b/atom/model_ops/attentions/backends.py index c02b7a707..544af5ee0 100644 --- a/atom/model_ops/attentions/backends.py +++ b/atom/model_ops/attentions/backends.py @@ -1,11 +1,14 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +import logging from abc import ABC, abstractmethod from typing import Any, Dict, Generic, Optional, Type, TypeVar import torch from atom.model_engine.scheduler import ScheduledBatch + +logger = logging.getLogger("atom") from atom.model_ops.attention_mla import MLAModules from atom.utils import CpuGpuBuffer from atom.utils.block_convert import block_table_convert_triton @@ -141,18 +144,23 @@ def prepare_prefill(self, batch: ScheduledBatch): sum_scheduled_tokens = batch.total_tokens_num_prefill var = self.model_runner.forward_vars positions = [] + cu_seqlens_q = [0] cu_seqlens_k = [0] max_seqlen_q = 0 max_seqlen_k = 0 slot_mapping = [] + has_cached = False # seqs = list(batch.seqs.values()) # seqs = seqs[:bs] for i in range(bs): seqlen = batch.context_lens[i] cached_seqlen = batch.num_cached_tokens[i] + if cached_seqlen > 0: + has_cached = True positions.extend(list(range(cached_seqlen, seqlen))) seqlen_q = seqlen - cached_seqlen seqlen_k = seqlen + cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q) cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k) max_seqlen_q = max(seqlen_q, max_seqlen_q) max_seqlen_k = max(seqlen_k, max_seqlen_k) @@ -166,17 +174,29 @@ def prepare_prefill(self, batch: ScheduledBatch): ) // self.model_runner.block_size last_block_tokens = batch.last_block_num_tokens[i] block_table = batch.block_tables[i] - for i in range(num_cached_blocks, num_blocks): - start = block_table[i] * self.model_runner.block_size - if i != num_blocks - 1: + for blk_idx in range(num_cached_blocks, num_blocks): + start = block_table[blk_idx] * self.model_runner.block_size + if blk_idx != num_blocks - 1: end = start + self.model_runner.block_size else: end = start + last_block_tokens slot_mapping.extend(list(range(start, end))) - if cu_seqlens_k[-1] > batch.total_tokens_num: # prefix cache + if has_cached: self.prepare_block_tables(batch) + # Validate metadata consistency + assert ( + len(positions) == sum_scheduled_tokens + ), f"positions length {len(positions)} != sum_scheduled_tokens {sum_scheduled_tokens}" + if batch.block_tables: + assert ( + len(slot_mapping) == sum_scheduled_tokens + ), f"slot_mapping length {len(slot_mapping)} != sum_scheduled_tokens {sum_scheduled_tokens}" + assert ( + cu_seqlens_q[-1] == sum_scheduled_tokens + ), f"cu_seqlens_q[-1]={cu_seqlens_q[-1]} != sum_scheduled_tokens={sum_scheduled_tokens}" var["positions"].np[:sum_scheduled_tokens] = positions var["slot_mapping"].np[: len(slot_mapping)] = slot_mapping + var["cu_seqlens_q"].np[: bs + 1] = cu_seqlens_q cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True) var["context_lens"].np[:bs] = batch.context_lens[:bs] min_seqlen_q = 0 @@ -186,6 +206,8 @@ def prepare_prefill(self, batch: ScheduledBatch): ("slot_mapping", len(slot_mapping)), ("context_lens", bs), ] + if has_cached: + vars_used.append(("block_tables", bs)) ctx = {el: var[el].copy_to_gpu(num) for el, num in vars_used} if self.block_ratio > 1 and "block_tables" in ctx: @@ -202,6 +224,7 @@ def prepare_prefill(self, batch: ScheduledBatch): max_seqlen_k=max_seqlen_k, min_seqlen_q=min_seqlen_q, dropout_p=dropout_p, + has_cached=has_cached, **ctx, ) positions = var["positions"].copy_to_gpu(sum_scheduled_tokens) diff --git a/atom/utils/forward_context.py b/atom/utils/forward_context.py index 341cef129..7ccd82b1d 100644 --- a/atom/utils/forward_context.py +++ b/atom/utils/forward_context.py @@ -186,6 +186,7 @@ class AttentionMetaData: reduce_partial_map: Optional[torch.Tensor] = None block_tables_converted: Optional[torch.Tensor] = None + has_cached: bool = False def __init__( self, @@ -213,7 +214,9 @@ def __init__( block_tables_converted: Optional[torch.Tensor] = None, sparse_cu_seqlens_q: Optional[torch.Tensor] = None, token_to_seq_idxs: Optional[torch.Tensor] = None, + has_cached: bool = False, ): + self.has_cached = has_cached self.cu_seqlens_q = cu_seqlens_q self.cu_seqlens_k = cu_seqlens_k self.max_seqlen_q = max_seqlen_q diff --git a/tests/test_block_manager.py b/tests/test_block_manager.py index 9dfb5d484..60a6c1061 100644 --- a/tests/test_block_manager.py +++ b/tests/test_block_manager.py @@ -173,3 +173,175 @@ def test_block_size_1(self, seq_factory): seq.append_token(3) bm.may_append(seq) assert len(seq.block_table) == 3 + + +# ── Prefix caching: can_allocate with cache hits ───────────────────────── + + +class TestCanAllocateWithPrefixCaching: + def test_can_allocate_accounts_for_cache_hits(self, seq_factory): + """With 3 blocks total, allocate 2-block seq, deallocate, then a new + 2-block seq sharing block 1 should need only 1 free block.""" + cfg = MockConfig( + num_kvcache_blocks=3, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + s1 = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) + bm.allocate(s1) + bm.deallocate(s1) # blocks freed, hashes retained + + # Use up 2 of the 3 free blocks + filler = seq_factory([50, 51, 52, 53, 60, 61, 62, 63]) + bm.allocate(filler) + # Only 1 free block left; s2 needs 2 blocks but first is cached + s2 = seq_factory([1, 2, 3, 4, 9, 10, 11, 12]) + assert bm.can_allocate(s2) + + def test_can_allocate_no_false_positive(self, seq_factory): + """can_allocate should return False when even with cache hits + there aren't enough free blocks.""" + cfg = MockConfig( + num_kvcache_blocks=2, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + s1 = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) + bm.allocate(s1) + # 0 free blocks; new seq shares prefix but needs 1 new block + s2 = seq_factory([1, 2, 3, 4, 9, 10, 11, 12]) + assert not bm.can_allocate(s2) + + +# ── Hash table cleanup ─────────────────────────────────────────────────── + + +class TestHashTableCleanup: + def test_stale_hash_entries_evicted_on_reuse(self, seq_factory): + """When a cached block is reused for a different hash, the old + hash_to_block_id entry should be cleaned up.""" + cfg = MockConfig( + num_kvcache_blocks=2, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + s1 = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) + bm.allocate(s1) + h1 = bm.blocks[s1.block_table[0]].hash + bm.deallocate(s1) + + # Allocate with completely different tokens — should overwrite blocks + s2 = seq_factory([90, 91, 92, 93, 94, 95, 96, 97]) + bm.allocate(s2) + # Old hash should no longer point to a valid block + assert bm.hash_to_block_id.get(h1) != s2.block_table[0] + + def test_hash_table_bounded_growth(self, seq_factory): + """hash_to_block_id should not grow beyond num_kvcache_blocks.""" + cfg = MockConfig( + num_kvcache_blocks=4, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + for i in range(20): + tokens = list(range(i * 4, i * 4 + 4)) + seq = seq_factory(tokens) + if bm.can_allocate(seq): + bm.allocate(seq) + bm.deallocate(seq) + assert len(bm.hash_to_block_id) <= cfg.num_kvcache_blocks + + +# ── can_append with multi-token decode (speculative decoding) ──────────── + + +class TestCanAppendMultiToken: + def test_can_append_multi_token_within_block(self, block_manager, seq_factory): + """Appending 3 tokens that stay within the current block.""" + seq = seq_factory([1]) + block_manager.allocate(seq) + seq.append_token(2) + seq.append_token(3) + assert block_manager.can_append(seq, num_new_tokens=3) + + def test_can_append_multi_token_crossing_boundary(self, seq_factory): + """block_size=4, seq_len=14 (3.5 blocks=4 blocks allocated), + appending 5 tokens crosses into block 5 — needs 1 new block.""" + cfg = MockConfig(num_kvcache_blocks=6, kv_cache_block_size=4) + bm = BlockManager(cfg) + seq = seq_factory(list(range(14))) + bm.allocate(seq) + # seq_len=14, 4 blocks. Appending 5 tokens: positions 14..18 → need block 5 + for t in range(14, 19): + seq.append_token(t) + assert bm.can_append(seq, num_new_tokens=5) + + def test_cannot_append_multi_token_no_free(self, seq_factory): + """block_size=4, 4 blocks total, seq fills 4 blocks (16 tokens), + appending 5 tokens needs 2 new blocks but only 0 free.""" + cfg = MockConfig(num_kvcache_blocks=4, kv_cache_block_size=4) + bm = BlockManager(cfg) + seq = seq_factory(list(range(14))) + bm.allocate(seq) + for t in range(14, 19): + seq.append_token(t) + assert not bm.can_append(seq, num_new_tokens=5) + + +# ── Prefix caching + preemption ────────────────────────────────────────── + + +class TestPrefixCachingPreemption: + def test_preempt_and_reschedule_reuses_cache(self, seq_factory): + """Preempted sequence re-discovers cache hits on re-allocation.""" + cfg = MockConfig( + num_kvcache_blocks=10, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + s1 = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) + bm.allocate(s1) + # Simulate preemption + bm.deallocate(s1) + assert s1.num_cached_tokens == 0 + assert s1.block_table == [] + + # Re-allocate — should get cache hits on both blocks + s1_retry = seq_factory([1, 2, 3, 4, 5, 6, 7, 8]) + bm.allocate(s1_retry) + assert s1_retry.num_cached_tokens == 8 # both blocks cached + + +# ── Edge cases ─────────────────────────────────────────────────────────── + + +class TestPrefixCachingEdgeCases: + def test_single_token_no_cache(self, seq_factory): + """Single token seq (shorter than block_size) — hash is -1, no caching.""" + cfg = MockConfig( + num_kvcache_blocks=4, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + s1 = seq_factory([42]) + bm.allocate(s1) + bm.deallocate(s1) + s2 = seq_factory([42]) + bm.allocate(s2) + # Partial block → hash is -1 → no caching + assert s2.num_cached_tokens == 0 + + def test_exact_block_size_fully_cached(self, seq_factory): + """Sequence with exactly block_size tokens — fully cached on reuse.""" + cfg = MockConfig( + num_kvcache_blocks=4, kv_cache_block_size=4, enable_prefix_caching=True + ) + bm = BlockManager(cfg) + s1 = seq_factory([1, 2, 3, 4]) + bm.allocate(s1) + bm.deallocate(s1) + s2 = seq_factory([1, 2, 3, 4]) + bm.allocate(s2) + assert s2.num_cached_tokens == 4 + + def test_free_block_ids_set_consistent(self, block_manager, seq_factory): + """free_block_ids_set stays consistent through allocate/deallocate.""" + s1 = seq_factory([1, 2, 3, 4]) + block_manager.allocate(s1) + initial_free = len(block_manager.free_block_ids_set) + block_manager.deallocate(s1) + assert len(block_manager.free_block_ids_set) == initial_free + 1 diff --git a/tests/test_prefix_cache_accuracy.py b/tests/test_prefix_cache_accuracy.py new file mode 100644 index 000000000..e10d97cf2 --- /dev/null +++ b/tests/test_prefix_cache_accuracy.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +""" +Test prefix caching accuracy with high cache-hit workloads. + +Sends batches of requests that share long common prefixes, then verifies: +1. Responses are correct (math problems with known answers) +2. Cache hit rate is high (visible in server logs) +3. Repeated identical requests produce consistent results +""" + +import argparse +import concurrent.futures +import json +import re +import sys +import time + +import requests + +BASE_URL = "http://localhost:8000" + +# Long shared prefix: 5-shot math examples (~2000 tokens) +MATH_PREFIX = """You are a precise math assistant. Solve each problem step by step and give the final numerical answer after ####. + +Question: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today? +Answer: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. #### 6 + +Question: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot? +Answer: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. #### 5 + +Question: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total? +Answer: Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. #### 39 + +Question: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny? +Answer: Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. #### 8 + +Question: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now? +Answer: Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 2 + 2 = 4 more toys. 5 + 4 = 9. #### 9 + +""" + +# Test questions with known answers +TEST_QUESTIONS = [ + ("A farmer has 17 sheep. He buys 5 more and then sells 8. How many sheep does he have?", 14), + ("A store had 45 apples. They sold 12 in the morning and 18 in the afternoon. How many apples are left?", 15), + ("Tom has 8 marbles. Jerry has 3 times as many. How many marbles do they have together?", 32), + ("A classroom has 6 rows of desks with 5 desks in each row. If 7 desks are removed, how many remain?", 23), + ("Sarah baked 24 cookies. She gave 1/3 to her neighbor and ate 4 herself. How many cookies does she have left?", 12), + ("A train travels 60 miles per hour for 3 hours, then 40 miles per hour for 2 hours. What is the total distance?", 260), + ("Mike has 50 dollars. He spends 15 dollars on lunch and 20 dollars on a book. How much money does he have left?", 15), + ("A garden has 9 rose bushes. Each bush has 12 roses. If 25 roses are picked, how many roses remain?", 83), + ("Lisa read 35 pages on Monday and twice as many on Tuesday. How many pages did she read in total?", 105), + ("A box contains 100 balls. 40 are red, 35 are blue, and the rest are green. How many green balls are there?", 25), +] + + +def extract_answer(text: str): + """Extract numerical answer after the FIRST #### marker.""" + match = re.search(r"####\s*(-?\d+(?:\.\d+)?)", text) + if match: + return float(match.group(1)) + return None + + +def get_model_name(base_url: str) -> str: + """Get the model name from the server.""" + resp = requests.get(f"{base_url}/v1/models", timeout=5) + resp.raise_for_status() + return resp.json()["data"][0]["id"] + + +def send_completion(prompt: str, max_tokens: int = 256, base_url: str = BASE_URL, model: str = "") -> str: + """Send a completion request to the server.""" + resp = requests.post( + f"{base_url}/v1/completions", + json={ + "model": model, + "prompt": prompt, + "max_tokens": max_tokens, + "temperature": 0.0, + }, + timeout=120, + ) + resp.raise_for_status() + return resp.json()["choices"][0]["text"] + + +def run_batch(questions, prefix, base_url=BASE_URL, model="", label=""): + """Run a batch of questions and return (correct, total, results).""" + results = [] + + def ask(q_and_a): + question, expected = q_and_a + prompt = prefix + f"Question: {question}\nAnswer:" + try: + response = send_completion(prompt, base_url=base_url, model=model) + answer = extract_answer(response) + correct = answer is not None and abs(answer - expected) < 0.01 + return (question, expected, answer, correct, response.strip()) + except Exception as e: + return (question, expected, None, False, f"ERROR: {e}") + + with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool: + results = list(pool.map(ask, questions)) + + num_correct = sum(1 for r in results if r[3]) + return num_correct, len(results), results + + +def main(): + parser = argparse.ArgumentParser(description="Test prefix cache accuracy") + parser.add_argument("--rounds", type=int, default=3, help="Number of rounds to repeat") + parser.add_argument("--base-url", type=str, default=BASE_URL) + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args() + + base_url = args.base_url + + # Health check + try: + r = requests.get(f"{base_url}/health", timeout=5) + r.raise_for_status() + except Exception as e: + print(f"Server not reachable at {base_url}: {e}") + sys.exit(1) + + model = get_model_name(base_url) + print(f"=== Prefix Cache Accuracy Test ===") + print(f"Server: {base_url}") + print(f"Model: {model}") + print(f"Questions per round: {len(TEST_QUESTIONS)}") + print(f"Rounds: {args.rounds}") + print(f"Shared prefix length: ~{len(MATH_PREFIX)} chars") + print() + + all_round_results = [] + + for round_num in range(1, args.rounds + 1): + t0 = time.time() + correct, total, results = run_batch(TEST_QUESTIONS, MATH_PREFIX, base_url=base_url, model=model, label=f"Round {round_num}") + elapsed = time.time() - t0 + accuracy = 100.0 * correct / total + all_round_results.append((correct, total, accuracy, elapsed)) + + print(f"Round {round_num}: {correct}/{total} correct ({accuracy:.1f}%) in {elapsed:.1f}s") + + if args.verbose: + for q, expected, got, ok, resp in results: + status = "OK" if ok else "WRONG" + print(f" [{status}] {q[:60]}... expected={expected} got={got}") + if not ok: + # Show first 200 chars of response for debugging + print(f" response: {resp[:200]}") + print() + + print() + print("=== Summary ===") + total_correct = sum(r[0] for r in all_round_results) + total_questions = sum(r[1] for r in all_round_results) + overall_accuracy = 100.0 * total_correct / total_questions + + # Check consistency: same questions should give same answers across rounds + print(f"Overall: {total_correct}/{total_questions} ({overall_accuracy:.1f}%)") + for i, (c, t, a, e) in enumerate(all_round_results, 1): + cache_note = "(cold)" if i == 1 else "(cache warm)" + print(f" Round {i}: {c}/{t} ({a:.1f}%) {e:.1f}s {cache_note}") + + # Verify rounds 2+ should be faster (cache hits) + if args.rounds >= 2: + r1_time = all_round_results[0][3] + r2_time = all_round_results[1][3] + speedup = r1_time / r2_time if r2_time > 0 else 0 + print(f"\n Speedup round 2 vs round 1: {speedup:.2f}x") + + # Pass/fail + if overall_accuracy >= 80.0: + print(f"\nPASS: accuracy {overall_accuracy:.1f}% >= 80%") + return 0 + else: + print(f"\nFAIL: accuracy {overall_accuracy:.1f}% < 80%") + return 1 + + +if __name__ == "__main__": + sys.exit(main())