Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 49 additions & 46 deletions atom/model_engine/block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, config: Config):
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
self.hash_to_block_id: dict[int, int] = dict()
self.free_block_ids: deque[int] = deque(range(num_blocks))
self.free_block_ids_set: set[int] = set(range(num_blocks))
self.used_block_ids: set[int] = set()
self.enable_prefix_caching = config.enable_prefix_caching

Expand All @@ -48,22 +49,52 @@ def compute_hash(cls, token_ids: list[int], prefix: int = -1):
h.update(np.array(token_ids).tobytes())
return h.intdigest()

def _pop_free_block(self) -> int:
"""Pop the next available free block id from the FIFO queue (lazy cleanup)."""
while self.free_block_ids:
block_id = self.free_block_ids.popleft()
if block_id in self.free_block_ids_set:
self.free_block_ids_set.discard(block_id)
return block_id
raise AssertionError("No free blocks available")

def _allocate_block(self, block_id: int) -> Block:
block = self.blocks[block_id]
assert block.ref_count == 0
# Evict stale hash entry before resetting
if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id:
del self.hash_to_block_id[block.hash]
block.reset()
self.free_block_ids.remove(block_id)
self.free_block_ids_set.discard(block_id)
self.used_block_ids.add(block_id)
return self.blocks[block_id]

def _deallocate_block(self, block_id: int):
assert self.blocks[block_id].ref_count == 0
self.used_block_ids.remove(block_id)
self.free_block_ids.append(block_id)
# self.free_block_ids.appendleft(block_id)
self.free_block_ids_set.add(block_id)

def can_allocate(self, seq: Sequence) -> bool:
    """Return True if enough free blocks exist to admit *seq* for prefill.

    Without prefix caching every block of the sequence needs a fresh free
    block. With prefix caching, dry-run the same hash walk that
    ``allocate`` performs and count only the blocks that would miss the
    cache. (The leftover pre-caching ``return`` line that shadowed this
    body — diff residue — is removed.)
    """
    if not self.enable_prefix_caching:
        return len(self.free_block_ids_set) >= seq.num_blocks
    h = -1
    cache_miss = False  # sticky: once one block misses, all later blocks miss
    needed_free = 0
    for i in range(seq.num_blocks):
        token_ids = seq.block(i)
        # Only full blocks get a real chained hash; partial blocks use -1.
        h = (
            self.compute_hash(token_ids, h)
            if len(token_ids) == self.block_size
            else -1
        )
        block_id = self.hash_to_block_id.get(h, -1)
        if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
            cache_miss = True
        if cache_miss:
            needed_free += 1
    return len(self.free_block_ids_set) >= needed_free

def allocate(self, seq: Sequence):
assert not seq.block_table
Expand All @@ -82,7 +113,7 @@ def allocate(self, seq: Sequence):
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
cache_miss = True
if cache_miss:
block_id = self.free_block_ids[0]
block_id = self._pop_free_block()
block = self._allocate_block(block_id)
else:
seq.num_cached_tokens += self.block_size
Expand All @@ -105,55 +136,27 @@ def deallocate(self, seq: Sequence):
seq.num_cached_tokens = 0
seq.block_table.clear()

def can_append(self, seq: Sequence, num_new_tokens: int = 1) -> bool:
    """Return True if enough free blocks exist to append *num_new_tokens*.

    Generalizes the old single-token modulo check (whose leftover diff
    lines are removed here): compute how many blocks the sequence will
    occupy after the append and compare the shortfall against the free
    pool. ``num_new_tokens`` defaults to 1 so existing callers still work.
    """
    seq_len = len(seq)
    current_blocks = len(seq.block_table)
    # Ceil-divide the post-append length by block_size.
    needed_blocks = (
        seq_len + num_new_tokens + self.block_size - 1
    ) // self.block_size
    new_blocks_needed = max(0, needed_blocks - current_blocks)
    return len(self.free_block_ids_set) >= new_blocks_needed

def may_append(self, seq: Sequence, num_new_tokens: int = 1):
block_table = seq.block_table
last_block = self.blocks[block_table[-1]]
seq_len = len(seq)
# Check if we need to allocate a new block
# When len(seq) % block_size == 1, we need a new block for the next token
# When block_size == 1, every token needs a new block
if 0 < seq_len % self.block_size <= num_new_tokens or self.block_size == 1:
needed_blocks = (seq_len + self.block_size - 1) // self.block_size
while len(block_table) < needed_blocks:
# For block_size == 1, we need to update hash for each new block
# For block_size > 1, the previous block should have hash != -1 (unless it's the first block)
if self.block_size == 1:
# Allocate new block and update hash immediately (like allocate does for full blocks)
block_id = self.free_block_ids[0]
block = self._allocate_block(block_id)
block_table.append(block_id)
token_ids = [seq[-1]]
prefix = (
self.blocks[block_table[-2]].hash
if len(block_table) > 1
else -1
)
h = self.compute_hash(token_ids, prefix)
block.update(h, token_ids)
self.hash_to_block_id[h] = block_id
else:
# For block_size > 1, we only allocate new block when needed
# The hash will be updated when the block becomes full
block_id = self.free_block_ids[0]
block = self._allocate_block(block_id)
block_table.append(block_id)
last_block = block
elif seq_len % self.block_size == 0:
# Last block is now full, update its hash (similar to allocate)
# TODO: fix hash
token_ids = seq.block(seq.num_blocks - 1)
if len(token_ids) == self.block_size:
prefix = (
self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
)
h = self.compute_hash(token_ids, prefix)
last_block.update(h, token_ids)
self.hash_to_block_id[h] = last_block.block_id
else:
pass
# Last block is not full and not at the boundary
# Hash remains -1 until block is full (consistent with allocate logic)
# assert last_block.hash == -1, last_block.block_id
# Decode-generated blocks: token not finalized yet (depends on
# sampling / speculative verification), so we cannot compute a
# correct hash here. Just allocate the block without hashing.
block_id = self._pop_free_block()
self._allocate_block(block_id)
block_table.append(block_id)
73 changes: 72 additions & 1 deletion atom/model_engine/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,70 @@ def _log(self) -> None:
)


class CacheStats:
    """Tracks prefix caching hit statistics."""

    __slots__ = (
        "_log_interval",
        "total_requests",
        "total_cached_tokens",
        "total_full_tokens",
        "_interval_requests",
        "_interval_cached_tokens",
        "_interval_full_tokens",
    )

    def __init__(self, log_interval: int = 100):
        # Cumulative counters (lifetime of the scheduler).
        self.total_requests: int = 0
        self.total_cached_tokens: int = 0
        self.total_full_tokens: int = 0
        # Rolling counters, reset after every `log_interval` requests.
        self._interval_requests: int = 0
        self._interval_cached_tokens: int = 0
        self._interval_full_tokens: int = 0
        self._log_interval = log_interval

    def update(self, num_cached_tokens: int, num_full_tokens: int) -> None:
        """Record cache stats for one prefill sequence."""
        self.total_requests += 1
        self._interval_requests += 1
        self.total_cached_tokens += num_cached_tokens
        self._interval_cached_tokens += num_cached_tokens
        self.total_full_tokens += num_full_tokens
        self._interval_full_tokens += num_full_tokens
        # Emit stats once per `_log_interval` completed requests.
        if self.total_requests % self._log_interval == 0:
            self._log()
            self._reset_interval()

    @property
    def hit_rate(self) -> float:
        """Lifetime fraction of prompt tokens served from the prefix cache."""
        total = self.total_full_tokens
        return self.total_cached_tokens / total if total else 0.0

    def _reset_interval(self) -> None:
        # Zero only the rolling window; cumulative totals are kept.
        self._interval_requests = 0
        self._interval_cached_tokens = 0
        self._interval_full_tokens = 0

    def _log(self) -> None:
        if self._interval_full_tokens > 0:
            iv_rate = self._interval_cached_tokens / self._interval_full_tokens
        else:
            iv_rate = 0.0
        logger.info(
            f"[Cache Stats Interval] Reqs: {self._interval_requests}, "
            f"Cached/Total tokens: {self._interval_cached_tokens}/{self._interval_full_tokens}, "
            f"Hit rate: {iv_rate:.2%}"
        )
        logger.info(
            f"[Cache Stats ] Reqs: {self.total_requests}, "
            f"Cached/Total tokens: {self.total_cached_tokens}/{self.total_full_tokens}, "
            f"Hit rate: {self.hit_rate:.2%}"
        )


class ScheduledBatch:
def __init__(
self,
Expand Down Expand Up @@ -233,6 +297,9 @@ def __init__(self, config: Config):
self.spec_stats: Optional[SpecStats] = (
SpecStats(mtp_k=self.mtp_k) if self.use_spec else None
)
self.cache_stats: Optional[CacheStats] = (
CacheStats() if config.enable_prefix_caching else None
)

def is_finished(self):
return not self.waiting and not self.running
Expand Down Expand Up @@ -270,6 +337,10 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]:
break
num_seqs_prefill += 1
self.block_manager.allocate(seq)
# Recompute after allocate: num_cached_tokens may have increased
num_new_tokens = seq.num_tokens - seq.num_cached_tokens
if self.cache_stats:
self.cache_stats.update(seq.num_cached_tokens, seq.num_tokens)
num_batched_tokens += num_new_tokens
seq.status = SequenceStatus.RUNNING
seq.type = SequenceType.PREFILL
Expand Down Expand Up @@ -303,7 +374,7 @@ def schedule(self) -> tuple[ScheduledBatch, dict[int, Sequence]]:
num_seqs_decode = 0
while self.running and num_seqs_decode < self.max_num_seqs:
seq = self.running.popleft()
while not self.block_manager.can_append(seq):
while not self.block_manager.can_append(seq, self.mtp_k + 1):
if self.running:
self.preempt(self.running.pop())
else:
Expand Down
2 changes: 1 addition & 1 deletion atom/model_ops/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def forward(
kv_cache_data = forward_context.kv_cache_data
kv_cache = kv_cache_data[f"layer_{self.layer_num}"].k_cache

if context.is_prefill and not use_prefill_mla:
if context.is_prefill and not use_prefill_mla and not attn_metadata.has_cached:
prefill_q = self.q_proj(q, x_scale=q_scale).view(
-1, self.num_heads, self.qk_head_dim
)
Expand Down
20 changes: 19 additions & 1 deletion atom/model_ops/attentions/aiter_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,30 @@ def prepare_prefill(self, batch: ScheduledBatch):
sum_scheduled_tokens + 1
)

if hasattr(self.model_runner, "drafter"):
if hasattr(self.model_runner, "drafter") or attn_metadata.has_cached:
attn_metadata.kv_indices = var["kv_indices"].gpu
attn_metadata.kv_indptr = var["kv_indptr"].gpu[: bs + 1]
attn_metadata.kv_indptr[0] = 0
attn_metadata.kv_indptr[1 : bs + 1] = torch.cumsum(
attn_metadata.context_lens, 0
)
attn_metadata.kv_last_page_lens = var["kv_last_page_lens"].gpu[:bs]

if attn_metadata.has_cached:
# Ensure raw block_tables (not block_tables_converted) are on GPU.
# kv_indices_generate_triton applies block_ratio internally, so we
# must pass raw model-level block_ids to avoid double conversion.
if not batch.block_tables:
self.prepare_block_tables(batch)
raw_block_tables = var["block_tables"].copy_to_gpu(bs)
max_seqlen_k = int(attn_metadata.context_lens.max().item())
kv_indices_generate_triton(
raw_block_tables,
attn_metadata.kv_indices,
attn_metadata.kv_indptr,
self.block_ratio,
max_seqlen_k,
)

return attn_metadata, positions

Expand Down
31 changes: 27 additions & 4 deletions atom/model_ops/attentions/backends.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Optional, Type, TypeVar

import torch
from atom.model_engine.scheduler import ScheduledBatch

logger = logging.getLogger("atom")
from atom.model_ops.attention_mla import MLAModules
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <E402> reported by reviewdog 🐶
Module level import not at top of file

from atom.utils import CpuGpuBuffer
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <E402> reported by reviewdog 🐶
Module level import not at top of file

from atom.utils.block_convert import block_table_convert_triton
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <E402> reported by reviewdog 🐶
Module level import not at top of file

Expand Down Expand Up @@ -141,18 +144,23 @@ def prepare_prefill(self, batch: ScheduledBatch):
sum_scheduled_tokens = batch.total_tokens_num_prefill
var = self.model_runner.forward_vars
positions = []
cu_seqlens_q = [0]
cu_seqlens_k = [0]
max_seqlen_q = 0
max_seqlen_k = 0
slot_mapping = []
has_cached = False
# seqs = list(batch.seqs.values())
# seqs = seqs[:bs]
for i in range(bs):
seqlen = batch.context_lens[i]
cached_seqlen = batch.num_cached_tokens[i]
if cached_seqlen > 0:
has_cached = True
positions.extend(list(range(cached_seqlen, seqlen)))
seqlen_q = seqlen - cached_seqlen
seqlen_k = seqlen
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
max_seqlen_q = max(seqlen_q, max_seqlen_q)
max_seqlen_k = max(seqlen_k, max_seqlen_k)
Expand All @@ -166,17 +174,29 @@ def prepare_prefill(self, batch: ScheduledBatch):
) // self.model_runner.block_size
last_block_tokens = batch.last_block_num_tokens[i]
block_table = batch.block_tables[i]
for i in range(num_cached_blocks, num_blocks):
start = block_table[i] * self.model_runner.block_size
if i != num_blocks - 1:
for blk_idx in range(num_cached_blocks, num_blocks):
start = block_table[blk_idx] * self.model_runner.block_size
if blk_idx != num_blocks - 1:
end = start + self.model_runner.block_size
else:
end = start + last_block_tokens
slot_mapping.extend(list(range(start, end)))
if cu_seqlens_k[-1] > batch.total_tokens_num: # prefix cache
if has_cached:
self.prepare_block_tables(batch)
# Validate metadata consistency
assert (
len(positions) == sum_scheduled_tokens
), f"positions length {len(positions)} != sum_scheduled_tokens {sum_scheduled_tokens}"
if batch.block_tables:
assert (
len(slot_mapping) == sum_scheduled_tokens
), f"slot_mapping length {len(slot_mapping)} != sum_scheduled_tokens {sum_scheduled_tokens}"
assert (
cu_seqlens_q[-1] == sum_scheduled_tokens
), f"cu_seqlens_q[-1]={cu_seqlens_q[-1]} != sum_scheduled_tokens={sum_scheduled_tokens}"
var["positions"].np[:sum_scheduled_tokens] = positions
var["slot_mapping"].np[: len(slot_mapping)] = slot_mapping
var["cu_seqlens_q"].np[: bs + 1] = cu_seqlens_q
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True)
var["context_lens"].np[:bs] = batch.context_lens[:bs]
min_seqlen_q = 0
Expand All @@ -186,6 +206,8 @@ def prepare_prefill(self, batch: ScheduledBatch):
("slot_mapping", len(slot_mapping)),
("context_lens", bs),
]
if has_cached:
vars_used.append(("block_tables", bs))

ctx = {el: var[el].copy_to_gpu(num) for el, num in vars_used}
if self.block_ratio > 1 and "block_tables" in ctx:
Expand All @@ -202,6 +224,7 @@ def prepare_prefill(self, batch: ScheduledBatch):
max_seqlen_k=max_seqlen_k,
min_seqlen_q=min_seqlen_q,
dropout_p=dropout_p,
has_cached=has_cached,
**ctx,
)
positions = var["positions"].copy_to_gpu(sum_scheduled_tokens)
Expand Down
3 changes: 3 additions & 0 deletions atom/utils/forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ class AttentionMetaData:
reduce_partial_map: Optional[torch.Tensor] = None

block_tables_converted: Optional[torch.Tensor] = None
has_cached: bool = False

def __init__(
self,
Expand Down Expand Up @@ -213,7 +214,9 @@ def __init__(
block_tables_converted: Optional[torch.Tensor] = None,
sparse_cu_seqlens_q: Optional[torch.Tensor] = None,
token_to_seq_idxs: Optional[torch.Tensor] = None,
has_cached: bool = False,
):
self.has_cached = has_cached
self.cu_seqlens_q = cu_seqlens_q
self.cu_seqlens_k = cu_seqlens_k
self.max_seqlen_q = max_seqlen_q
Expand Down
Loading
Loading