111 changes: 84 additions & 27 deletions applications/llama_3.2_1b/src/block/gqa.py
@@ -97,11 +97,25 @@ def __init__(

# Initialize AIE RoPE operators
if self.cfg["use_aie_rope"]:
self.aie_rope = AIERope(
num_aie_columns=1,
num_channels=1,
size=self.prompt_length * self.head_dim,
last_dim=self.head_dim,
self.aie_rope_prefill_k = AIERope(
rows=self.prompt_length * self.num_kv_groups,
cols=self.head_dim,
angle_rows=self.prompt_length,
)
self.aie_rope_prefill_q = AIERope(
rows=self.prompt_length * self.num_heads,
cols=self.head_dim,
angle_rows=self.prompt_length,
)
self.aie_rope_decode_k = AIERope(
rows=self.num_kv_groups,
cols=self.head_dim,
angle_rows=1,
)
self.aie_rope_decode_q = AIERope(
rows=self.num_heads,
cols=self.head_dim,
angle_rows=1,
)
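# Illustrative sizing of the four operators, assuming Llama-3.2-1B-like
# dimensions (num_heads=32, num_kv_groups=8, head_dim=64) and a hypothetical
# prompt_length of 16; the real values come from the model config:
#   prefill K: rows = 16 * 8  = 128,  cols = 64, angle_rows = 16
#   prefill Q: rows = 16 * 32 = 512,  cols = 64, angle_rows = 16
#   decode  K: rows = 8,              cols = 64, angle_rows = 1
#   decode  Q: rows = 32,             cols = 64, angle_rows = 1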

# Initialize fused AIE MHA operator
@@ -182,6 +196,10 @@ def forward(self, x, mask, angles, input_pos=None):
is_prefill = input_pos is None
is_decode = input_pos is not None

# Step 1.
# ---
# Linear projections -- compute queries, keys and values by multiplying the embedding vector (in decode) or matrix (in prefill) with the weight matrices
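# A minimal standalone sketch of these projections, assuming Llama-3.2-1B-like
# sizes (emb_dim=2048, 32 query heads, 8 KV heads, head_dim=64); the actual
# layers self.W_query / self.W_key / self.W_value are configured elsewhere:
#   import torch
#   x = torch.randn(1, 16, 2048)                          # (b, num_tokens, emb_dim)
#   W_query = torch.nn.Linear(2048, 32 * 64, bias=False)
#   W_key = torch.nn.Linear(2048, 8 * 64, bias=False)
#   W_value = torch.nn.Linear(2048, 8 * 64, bias=False)
#   queries, keys, values = W_query(x), W_key(x), W_value(x)
#   # queries: (1, 16, 2048), keys/values: (1, 16, 512)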

# Choose between GEMM (prefill) and GEMV (decode) based on KV cache usage
if self.cfg["use_kv_cache"] and is_decode and self.cfg["use_aie_gqa_gemv"]:
# Decode phase with KV cache - use GEMV for single token
@@ -219,10 +237,21 @@ def forward(self, x, mask, angles, input_pos=None):
keys = self.W_key(x)
values = self.W_value(x)

# Each attention head gets its own slice of the embedding dimension.
# For each head, we have a query, a key and a value.
# In grouped-query attention, the keys and values are shared across groups of heads.
# Therefore, we have self.num_heads queries, but only self.num_kv_groups keys and values
# (num_kv_groups == num_heads in the case of regular multi-head attention).
# Each head can be applied independently to its slice of the embedding dimension.
keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)
values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
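# With the illustrative sizes above (b=1, num_tokens=16, 32 query heads,
# 8 KV groups, head_dim=64), these views split the flat projections per head:
#   keys:    (1, 16, 512)  -> (1, 16, 8, 64)
#   values:  (1, 16, 512)  -> (1, 16, 8, 64)
#   queries: (1, 16, 2048) -> (1, 16, 32, 64)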

# Step 2.
# ---
# Apply rotary positional encoding (RoPE) to the keys and queries.
# The encoding is applied independently to each head.
# It modifies the embedding vectors to encode where in the sequence each token is located.
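# A minimal sketch of the rotation RoPE applies to one feature pair (x1, x2)
# of a head at a given position, where theta is that pair's entry in the
# precomputed angle table (one common formulation; the AIE and CPU kernels
# below implement the actual variant used here):
#   x1_rot = x1 * cos(theta) - x2 * sin(theta)
#   x2_rot = x1 * sin(theta) + x2 * cos(theta)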

# Determine angle slice based on KV cache usage and phase
if self.cfg["use_kv_cache"] and is_decode:
# Decode phase with KV cache: use single position
@@ -232,30 +261,50 @@
# Prefill phase or no KV cache: use all tokens
angle_slice = angles[:num_tokens, :]

# Apply RoPE with AIE or CPU fallback
def apply_rope_and_transpose(tensor, num_heads_dim, angle_slice):
expected_seq_len = (
1 if (self.cfg["use_kv_cache"] and is_decode) else self.prompt_length
)
can_use_aie = (
self.cfg["use_aie_rope"]
and tensor.shape[-1] == self.head_dim
and tensor.shape[-2] == expected_seq_len
)

if can_use_aie:
# AIE RoPE path: flatten -> apply -> reshape -> transpose
tensor = self.aie_rope(tensor.view(b, num_tokens, -1), angle_slice)
return tensor.view(
# Apply RoPE with AIE
def apply_rope_and_transpose(aie_op, tensor, num_heads_dim, angle_slice):
angle_slice = angle_slice.to(dtype=tensor.dtype)
if self.cfg["use_aie_rope"]:
result = aie_op(
tensor.view(num_tokens * num_heads_dim, self.head_dim), angle_slice
)
result = result.view(
b, num_tokens, num_heads_dim, self.head_dim
).transpose(1, 2)
else:
# CPU RoPE path: transpose -> apply
tensor = tensor.transpose(1, 2)
return apply_rope(tensor, angle_slice)

keys = apply_rope_and_transpose(keys, self.num_kv_groups, angle_slice)
queries = apply_rope_and_transpose(queries, self.num_heads, angle_slice)
transposed = (
tensor.view(num_tokens, num_heads_dim, self.head_dim)
.transpose(0, 1)
.contiguous()
)
result = apply_rope(
transposed.view(1, num_heads_dim, num_tokens, self.head_dim),
angle_slice,
)
# ref = apply_rope(transposed.view(1, num_heads_dim, num_tokens, self.head_dim), angle_slice)
# assert torch.allclose(ref, result, atol=0.7, rtol=0.07), "AIE RoPE result does not match reference"
return result

keys = apply_rope_and_transpose(
(
(self.aie_rope_prefill_k if is_prefill else self.aie_rope_decode_k)
if self.cfg["use_aie_rope"]
else None
),
keys,
self.num_kv_groups,
angle_slice,
)
queries = apply_rope_and_transpose(
(
(self.aie_rope_prefill_q if is_prefill else self.aie_rope_decode_q)
if self.cfg["use_aie_rope"]
else None
),
queries,
self.num_heads,
angle_slice,
)
values = values.transpose(1, 2)

if self.cfg["use_kv_cache"]:
Expand All @@ -272,10 +321,18 @@ def apply_rope_and_transpose(tensor, num_heads_dim, angle_slice):
keys = cached_keys
values = cached_values

# Expand keys and values to match query heads for all cases (grouped query attention)
# Step 3.
# ---
# Since the keys and values are shared across groups of heads in grouped-query attention,
# we now expand (repeat) the same keys and values so that each head has its own keys and values.
keys = keys.repeat_interleave(self.group_size, dim=1)
values = values.repeat_interleave(self.group_size, dim=1)
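# With the illustrative sizes above (8 KV groups, group_size = 32 / 8 = 4),
# repeat_interleave expands keys/values from (b, 8, num_tokens, head_dim) to
# (b, 32, num_tokens, head_dim), so every query head sees a copy of its
# group's key/value head:
#   KV head order after expansion: [0, 0, 0, 0, 1, 1, 1, 1, ..., 7, 7, 7, 7]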

# Step 4.
# ---
# Compute attention scores (independently for each head), apply softmax to get attention weights, then apply those weights to the values to get the output.
# Attention scores are the dot-product of queries and keys.
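# A minimal sketch of standard scaled dot-product attention per head
# (illustrative shapes as above, assuming a boolean mask where True marks
# disallowed positions; the fused AIE MHA kernel and the CPU fallback below
# implement the actual computation):
#   import math, torch
#   scores = queries @ keys.transpose(-2, -1) / math.sqrt(head_dim)  # (b, H, T, T)
#   scores = scores.masked_fill(mask, float("-inf"))
#   weights = torch.softmax(scores, dim=-1)
#   context = weights @ values                                        # (b, H, T, head_dim)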

# Use fused AIE MHA if enabled and conditions are met
if is_prefill or not self.cfg["use_kv_cache"]:
if (