Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Keep the Docker build context minimal: exclude VCS history, Python
# bytecode caches, and local packaging/build artifacts.
.git
__pycache__
*.pyc
*.egg-info
build/
dist/
42 changes: 42 additions & 0 deletions .github/workflows/sync-upstream.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Keeps this fork's `main` branch in sync with the upstream repository.
# Runs nightly and can also be triggered by hand from the Actions tab.
name: Sync upstream main

on:
  schedule:
    # Run nightly at 06:00 UTC (midnight CST)
    - cron: '0 6 * * *'
  workflow_dispatch: # Allow manual trigger

# The push step writes to the repository; request write access explicitly so
# the job also works when the repo's default GITHUB_TOKEN permission is
# read-only.
permissions:
  contents: write

jobs:
  sync:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout fork
        uses: actions/checkout@v4
        with:
          ref: main
          # Full history is required for `git rev-list` and a real merge.
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Add upstream remote
        run: git remote add upstream https://github.com/ROCm/ATOM.git

      - name: Fetch upstream
        run: git fetch upstream main

      - name: Check for new commits
        id: check
        run: |
          BEHIND=$(git rev-list --count HEAD..upstream/main)
          echo "behind=$BEHIND" >> "$GITHUB_OUTPUT"
          echo "Fork is $BEHIND commit(s) behind upstream"

      - name: Merge upstream
        if: steps.check.outputs.behind != '0'
        run: |
          git config user.name "github-actions[bot]"
          git config user.email "github-actions[bot]@users.noreply.github.com"
          # A conflicting merge fails this step, which surfaces the conflict
          # in the run log instead of pushing a broken tree.
          git merge upstream/main --no-edit

      - name: Push
        if: steps.check.outputs.behind != '0'
        # NOTE(review): pushes made with GITHUB_TOKEN do not trigger other
        # workflows (e.g. CI on main) — confirm that is acceptable here.
        run: git push origin main
1 change: 1 addition & 0 deletions atom/model_engine/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def postprocess(
continue
token_ids = prev_token_ids[seq.id]
num_new_token = len(token_ids)
num_rejected = 0
self.update_spec_stats(num_new_token)
idx = fwd_output.req_ids.index(seq.id)
if is_deferred_out or self.use_spec:
Expand Down
17 changes: 15 additions & 2 deletions atom/model_ops/attention_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,19 @@ def prefill_attention_triton(
if ctx.is_prefill:
k_cache = k.unsqueeze(1)
v_cache = v.unsqueeze(1)
block_tables = attn_metadata.fake_block_tables
# Create fake block_tables for prefill: each token is its own
# "block" (block_size=1). Shape [num_seqs, max_seqlen_k].
batch_size = attn_metadata.cu_seqlens_k.shape[0] - 1
max_len = attn_metadata.max_seqlen_k
block_tables = torch.zeros(
batch_size, max_len, dtype=torch.int32, device=q.device
)
for i in range(batch_size):
s = attn_metadata.cu_seqlens_k[i].item()
e = attn_metadata.cu_seqlens_k[i + 1].item()
block_tables[i, : e - s] = torch.arange(
s, e, dtype=torch.int32, device=q.device
)

o = torch.empty_like(q)
descale_shape = (attn_metadata.cu_seqlens_q.shape[0] - 1, k.shape[1])
Expand Down Expand Up @@ -407,7 +419,8 @@ def dispatch_backend(self, fwd_ctx: ForwardContext):
ctx = fwd_ctx.context

if ctx.is_prefill:
return self.prefill_attention
# Always use Triton prefill (no CK/flash_attn_varlen_func dependency)
return self.prefill_attention_triton
else:
if self.use_triton_attn:
return self.paged_attention_triton
Expand Down
43 changes: 27 additions & 16 deletions atom/model_ops/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,19 +396,22 @@ def _forward_prefill_mha(

k = torch.cat((k_nope, k_rope.expand((*k_nope.shape[:-1], -1))), dim=-1)

output = flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=attn_metadata.cu_seqlens_q,
cu_seqlens_k=attn_metadata.cu_seqlens_k,
max_seqlen_q=attn_metadata.max_seqlen_q,
max_seqlen_k=attn_metadata.max_seqlen_k,
min_seqlen_q=attn_metadata.min_seqlen_q,
dropout_p=attn_metadata.dropout_p,
softmax_scale=self.scale,
causal=True,
)
# Use PyTorch SDPA for MLA prefill attention (no CK dependency)
import torch.nn.functional as F

cu_q = attn_metadata.cu_seqlens_q
cu_k = attn_metadata.cu_seqlens_k
num_seqs = cu_q.shape[0] - 1
outputs = []
for i in range(num_seqs):
qi = q[cu_q[i] : cu_q[i + 1]].transpose(0, 1).unsqueeze(0)
ki = k[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
vi = v[cu_k[i] : cu_k[i + 1]].transpose(0, 1).unsqueeze(0)
oi = F.scaled_dot_product_attention(
qi, ki, vi, is_causal=True, scale=self.scale
)
outputs.append(oi.squeeze(0).transpose(0, 1))
output = torch.cat(outputs, dim=0)

return self.o_proj(output.flatten(start_dim=-2))

Expand Down Expand Up @@ -446,7 +449,8 @@ def _forward_prefill_mla(
max_q_len = 1

if kv_c_and_k_pe_cache.numel() > 0:
if self.kv_cache_dtype.startswith("fp8"):
if self.kv_cache_dtype.startswith("fp8") and max_q_len == 1:
# mla_decode_fwd supports fp8 scales but only max_seqlen_q=1
mla_decode_fwd(
q,
kv_c_and_k_pe_cache.view(-1, 1, 1, q.shape[-1]),
Expand All @@ -463,9 +467,16 @@ def _forward_prefill_mla(
kv_scale=self._k_scale,
)
else:
# mla_prefill_fwd supports arbitrary max_seqlen_q but no fp8 scales
q_for_prefill = q.to(self.dtype) if q.dtype != self.dtype else q
kv_for_prefill = (
kv_c_and_k_pe_cache.to(self.dtype)
if kv_c_and_k_pe_cache.dtype != self.dtype
else kv_c_and_k_pe_cache
)
mla_prefill_fwd(
q,
kv_c_and_k_pe_cache.view(-1, 1, 1, q.shape[-1]),
q_for_prefill,
kv_for_prefill.view(-1, 1, 1, q.shape[-1]),
o,
paged_cu_seqlens_q,
paged_kv_indptr,
Expand Down
50 changes: 50 additions & 0 deletions atom/model_ops/attentions/aiter_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,56 @@ def prepare_prefill(self, batch: ScheduledBatch):
bs = batch.total_seqs_num_prefill
sum_scheduled_tokens = batch.total_tokens_num_prefill
var = self.model_runner.forward_vars

# Prepare paged KV metadata for MLA prefill paths
# (needed by mla_prefill_fwd for bf16, unified_attention for fp8)
if batch.block_tables:
context_lens = np.asarray(batch.context_lens[:bs], dtype=np.int32)
num_blocks_per_seq = cdiv(context_lens, self.block_size)
kv_indptr = np.cumsum(num_blocks_per_seq)
sum_blocks = kv_indptr[-1]

dst = var["kv_indices"].np
offset = 0
for i in range(bs):
bt = batch.block_tables[i]
n = len(bt)
dst[offset : offset + n] = bt
offset += n
sum_blocks_before_converted = offset

var["kv_indptr"].np[0] = 0
var["kv_indptr"].np[1 : bs + 1] = kv_indptr

attn_metadata.kv_indptr = var["kv_indptr"].copy_to_gpu(bs + 1)
attn_metadata.kv_indices = var["kv_indices"].copy_to_gpu(
sum_blocks_before_converted
)
attn_metadata.kv_last_page_lens = var["kv_last_page_lens"].gpu[:bs]

if self.block_ratio > 1:
kv_indices_convert_triton(
var["kv_indices"].gpu[:sum_blocks_before_converted],
var["kv_indices_converted"].gpu[:sum_blocks],
var["kv_indptr"].gpu[: bs + 1],
self.block_ratio,
self.block_size,
)
attn_metadata.kv_indices = var["kv_indices_converted"].gpu[:sum_blocks]

# Prepare block_tables for unified_attention (fp8 prefill)
if attn_metadata.block_tables is None:
self.prepare_block_tables(batch)
attn_metadata.block_tables = var["block_tables"].copy_to_gpu(bs)
if self.block_ratio > 1:
block_table_convert_triton(
var["block_tables"].gpu[:bs],
var["block_tables_converted"].gpu[:bs],
var["context_lens"].gpu[:bs],
self.block_ratio,
)
attn_metadata.block_tables = var["block_tables_converted"].gpu[:bs]

if self.is_sparse and attn_metadata.max_seqlen_k > self.index_topk:
if attn_metadata.block_tables is None:
self.prepare_block_tables(batch)
Expand Down
Loading
Loading