feat: support Prefix Caching for vLLM #252

Closed
qinganrice wants to merge 1 commit into ovg-project:main from qinganrice:feature/prefix_caching

Conversation

@qinganrice (Contributor) commented Feb 24, 2026

Add Prefix Caching Support for vLLM

Summary

Implements prefix caching for vLLM, enabling KV cache block reuse across requests that share a common prefix. Includes LRU eviction with refcount synchronization to prevent memory leaks.

Implementation

Core Changes

New Files:

  • kvcached/prefix_cache_manager.py - PrefixCacheManager with LRU eviction logic
  • benchmarks/bench_prefix_cache/ - Comprehensive test suite (3 tests, helpers, documentation)
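The PrefixCacheManager itself is not shown in this thread. As a rough illustration of the LRU-plus-refcount scheme the PR describes, a minimal sketch might look like this (class and method names are illustrative assumptions, not the actual kvcached API):

```python
from collections import OrderedDict


class PrefixCacheSketch:
    """Illustrative LRU prefix cache: maps block hashes to block IDs and
    only evicts least-recently-used entries whose refcount is zero."""

    def __init__(self, max_size: int = 1000):
        self.max_size = max_size
        self.entries: "OrderedDict[str, int]" = OrderedDict()  # hash -> block id
        self.refcounts: "dict[str, int]" = {}

    def insert(self, block_hash: str, block_id: int) -> None:
        if block_hash not in self.entries and len(self.entries) >= self.max_size:
            self._evict_lru()
        self.entries[block_hash] = block_id
        self.entries.move_to_end(block_hash)  # mark as most recently used
        self.refcounts[block_hash] = self.refcounts.get(block_hash, 0) + 1

    def lookup(self, block_hash: str):
        """Return the cached block id on a hit (and pin it), else None."""
        if block_hash in self.entries:
            self.entries.move_to_end(block_hash)
            self.refcounts[block_hash] += 1
            return self.entries[block_hash]
        return None

    def decrement_refcount(self, block_hash: str) -> bool:
        """Return True when the entry is no longer referenced and may be freed."""
        self.refcounts[block_hash] -= 1
        return self.refcounts[block_hash] == 0

    def _evict_lru(self) -> None:
        # Evict the oldest entry with refcount 0; blocks still in use
        # cannot be evicted (matching the "expected failures" in Test 3).
        for h in list(self.entries):
            if self.refcounts.get(h, 0) == 0:
                del self.entries[h]
                self.refcounts.pop(h, None)
                return
```

The key property is the last loop: an eviction attempt on an in-use block is skipped rather than treated as an error, which is why the eviction test counts "expected failures" separately from successful evictions.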

Modified Files:

  • kvcached/kv_cache_manager.py - Added refcount synchronization between physical blocks and cache entries
  • kvcached/integration/vllm/patches.py - Integrated prefix cache into vLLM request processing
  • kvcached/integration/vllm/interfaces.py - Pass PREFIX_CACHE_MAX_SIZE to KVCacheManager
  • kvcached/utils.py - Added PREFIX_CACHE_MAX_SIZE configuration constant
  • benchmarks/simple_bench/start_server.sh - Enabled prefix caching flag

Testing

Model: meta-llama/Llama-3.2-1B

Test 1: Basic Prefix Cache

  • Config: Default cache (1000 blocks), ~500 token prefix
  • Result: 1.08x speedup on cached requests

Test 2: Long Prefix Speedup

  • Config: Default cache (1000 blocks), ~2000 token prefix with 40+ examples
  • Result: 1.2-1.5x speedup (higher with larger models)

Test 3: LRU Eviction

  • Config: Small cache (10 blocks), 4-5 distinct prefixes
  • Result: 20+ successful evictions, 51 expected failures (blocks in use), eviction working correctly

Configuration

Set before starting server:

  • export KVCACHED_PREFIX_CACHE_MAX_SIZE=10 # Default: 1000
  • export KVCACHED_LOG_LEVEL=DEBUG # Optional: see cache operations
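On the server side, a constant like this is presumably read once at startup. A sketch of the parsing (the exact code in kvcached/utils.py may differ):

```python
import os

# Maximum number of cached prefix blocks; default of 1000 per the PR description.
PREFIX_CACHE_MAX_SIZE = int(os.getenv("KVCACHED_PREFIX_CACHE_MAX_SIZE", "1000"))
```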

Documentation

Complete test suite with:

  • 3 test scripts (basic, long prefix, LRU eviction)
  • 3 automation scripts
  • README with usage, troubleshooting, examples

To Do List:

  • SGLang Support: Extend prefix cache implementation to SGLang framework
  • Unified Cache Manager: Merge kv_cache_manager.py and prefix_cache_manager.py into a single file that supports both vLLM and SGLang frameworks

@cui36 (Collaborator) commented Feb 24, 2026

Cool! SGLang support is on the way.

@cui36 cui36 self-assigned this Feb 24, 2026
@cui36 (Collaborator) commented Feb 24, 2026

/gemini review

@gemini-code-assist bot left a comment

Code Review

This pull request introduces prefix caching for vLLM, a significant feature that can improve performance for prompts with shared prefixes. The implementation is comprehensive, including a new PrefixCacheManager with LRU eviction, careful reference counting in KVCacheManager to prevent memory issues, and clean integration with vLLM's block pool via patching. The addition of a thorough benchmark suite with tests for basic functionality, high speedup scenarios, and LRU eviction is excellent and greatly increases confidence in the correctness of this complex feature. The code is well-structured and the logic, especially the dual reference counting system, appears sound. I have one minor suggestion to improve the robustness of a test script. Overall, this is a high-quality contribution.

Comment on lines +28 to +34
cd "$(dirname "${BASH_SOURCE[0]}")/../simple_bench"

bash start_server.sh vllm \
--venv-path ../../engine_integration/vllm-pip-venv \
--model meta-llama/Llama-3.2-1B \
--port 12346 \
--tp 1

medium

For better robustness and maintainability, it's preferable to avoid changing the current working directory (cd) within scripts. Instead, you can define the script's directory and construct paths relative to it. This makes the script's behavior independent of where it's called from.

Suggested change
cd "$(dirname "${BASH_SOURCE[0]}")/../simple_bench"
bash start_server.sh vllm \
--venv-path ../../engine_integration/vllm-pip-venv \
--model meta-llama/Llama-3.2-1B \
--port 12346 \
--tp 1
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
bash "$SCRIPT_DIR/../simple_bench/start_server.sh" vllm \
--venv-path "$SCRIPT_DIR/../../engine_integration/vllm-pip-venv" \
--model meta-llama/Llama-3.2-1B \
--port 12346 \
--tp 1

# For now, kvcached only supports single group (group_id=0)
# Ignore kv_cache_group_ids parameter
if len(kv_cache_group_ids) > 1:
    logger.warning(f"ElasticBlockPool only supports single KV cache group, "

why can't it support multiple kv cache groups?

qinganrice (Author) replied:

Hi @lianghao208, since we only recently added support for hybrid attention in the latest release, the prefix caching implementation in this PR focuses on the single-group case; multi-group support could be added later. However, prefix caching prevents KVCached from releasing memory after a request finishes, which conflicts with KVCached's core design. We are currently investigating this and looking for a good balance to support both features. If you have any ideas or comments, please feel free to share. Thanks!

Thanks for the reply. If we simply store the prefix cache by BlockHashWithGroupId, just like the implementation in vLLM, is that viable, or are there other concerns?

# vllm/v1/core/kv_cache_utils.py
# `BlockHashWithGroupId` combines a `BlockHash` with its KV cache group ID.
# It is represented as raw bytes for compactness and efficiency. The helper
# functions below pack/unpack the `BlockHash` and group id into/from the key.
BlockHashWithGroupId = NewType("BlockHashWithGroupId", bytes)
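The vLLM comment quoted above refers to pack/unpack helpers that are not shown here. As a rough sketch of the idea (the byte layout below is an illustrative assumption, not vLLM's actual packing):

```python
import struct


def pack_block_hash_with_group_id(block_hash: bytes, group_id: int) -> bytes:
    # Append the group id as a fixed-width little-endian suffix so the
    # combined key stays compact raw bytes (illustrative layout only).
    return block_hash + struct.pack("<I", group_id)


def unpack_block_hash_with_group_id(key: bytes):
    # Split the fixed-width suffix back off the raw block hash.
    return key[:-4], struct.unpack("<I", key[-4:])[0]
```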

qinganrice (Author) replied:

Hi @lianghao208, if you store the prefix cache by BlockHashWithGroupId, KVCached will raise an error. I will sync this with the team and get back to you soon. Thanks for your patience.

Thanks, that makes sense. I just noticed that patches.py already filters out the multiple-KV-cache-group condition during KV cache initialization.
"We are currently investigating this and looking for a good balance to support both features" — that is our main concern as well.

can_free_from_cache = True
if idx in self.prefix_cache_manager.block_to_hash:
    block_hash = self.prefix_cache_manager.block_to_hash[idx]
    can_free_from_cache = self.prefix_cache_manager.decrement_refcount(block_hash)

If can_free_from_cache is true, it means the refcount in this prefix cache block entry is already 0, and the block will be freed immediately. Is the _evict_lru() still necessary?

# Initialize refcount for newly allocated blocks
if self.enable_prefix_cache:
    for idx in ret_index:
        self.block_refcounts[idx] = self.block_refcounts.get(idx, 0) + 1

I think a newly allocated block's refcount has already been increased in block_pool.cache_full_blocks() --> kv_cache_manager.cache_blocks().
The whole process:

kvcache_manager.allocate_slots()-->coordinator.allocate_new_blocks()-->block_pool.get_new_blocks()-->coordinator.cache_blocks()-->block_pool.cache_full_blocks()-->refcount++

Is this refcount increment duplicated?
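If both call sites do increment, as the comment suggests, each cached block would need two decrements before it can ever be freed. A toy illustration of that double-count hazard (hypothetical names, not the kvcached code):

```python
refcounts: dict = {}


def incref(block_id: int) -> None:
    refcounts[block_id] = refcounts.get(block_id, 0) + 1


# Hypothetical allocation path: both the allocator loop over ret_index
# and cache_full_blocks() bump the refcount for the same new block.
incref(0)  # in the allocate path
incref(0)  # again via block_pool.cache_full_blocks()

# A single matching release now leaves the block pinned at refcount 1,
# so it can never be evicted or freed -- an effective leak.
refcounts[0] -= 1
```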

cui36 (Collaborator) replied:

Hi @lianghao208 thanks for the great advice! And could we have a further discussion on our slack channel (https://join.slack.com/t/ovg-project/shared_invite/zt-3s9e1i5ad-Jpvl~1TfoxDbuTTVEftg5A) if that works for you?

@qinganrice qinganrice force-pushed the feature/prefix_caching branch from a5d7a2a to 37ef62c Compare March 13, 2026 00:37
@cui36 cui36 closed this Mar 18, 2026

4 participants