# attention_layer.py
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from paged_kv import PagedKVCache
from paged_attention import paged_attention


class ToyAttention(nn.Module):
    """A minimal self-contained attention layer that uses a Paged KV-Cache.

    Parameters
    ----------
    d_model : int, default 128
        Embedding dimension.
    n_heads : int, default 4
        Number of attention heads. Must divide ``d_model``.
    block_size : int, default 8
        How many tokens fit in one physical cache page.
    n_blocks : int, default 1024
        Total pages to pre-allocate (max context length = ``block_size * n_blocks``).
    device : str | torch.device, optional
        Where to place model parameters. The cache follows the input ``x`` device at runtime.
    """

    def __init__(
        self,
        d_model: int = 128,
        n_heads: int = 4,
        block_size: int = 8,
        n_blocks: int = 1024,
        device: str | torch.device | None = None,
    ):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.Wqkv = nn.Linear(d_model, 3 * d_model, bias=False, device=device)
        self.Wo = nn.Linear(d_model, d_model, bias=False, device=device)
        # Each ToyAttention instance owns its own cache.
        self.cache = PagedKVCache(
            block_size=block_size,
            num_blocks=n_blocks,
            num_heads=n_heads,
            head_dim=self.head_dim,
            device=device,
        )

    def forward(self, x: torch.Tensor, seq_id: str | int = "0") -> torch.Tensor:
        assert x.ndim == 3 and x.shape[0] == 1, "ToyAttention currently supports batch size 1"
        B, T, _ = x.shape                                   # B is 1
        qkv = self.Wqkv(x)                                  # [1, T, 3*d_model]
        q, k, v = qkv.chunk(3, dim=-1)                      # each [1, T, d_model]

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            return (
                t.view(B, T, self.n_heads, self.head_dim)
                .transpose(1, 2)                            # [1, H, T, D]
                .contiguous()
            )

        q = split_heads(q)
        k = split_heads(k)
        v = split_heads(v)

        # Append the new tokens' K/V to this sequence's paged cache, one token at a time.
        for t in range(T):
            self.cache.append_kv(seq_id, k[0, :, t], v[0, :, t])  # each is [H, D]

        ctx = paged_attention(q, self.cache, seq_id)        # [1, H, T, D]
        ctx = ctx.transpose(1, 2).reshape(B, T, self.d_model)  # [1, T, d_model]
        return self.Wo(ctx)
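

# Usage sketch (illustrative, not part of the original layer): shows how a caller
# might drive ToyAttention across two chunks of the same sequence. It assumes
# PagedKVCache / paged_attention behave exactly as they are used in
# ToyAttention.forward above; "req-0" is just a hypothetical sequence id.
def _demo_toy_attention() -> None:
    attn = ToyAttention(d_model=128, n_heads=4)
    prefill = torch.randn(1, 16, 128)              # [batch=1, tokens, d_model]
    y = attn(prefill, seq_id="req-0")              # [1, 16, 128]
    # A second call with the same seq_id appends to the same paged cache,
    # so the single new token can attend to all 17 cached positions.
    y_next = attn(torch.randn(1, 1, 128), seq_id="req-0")
    print(y.shape, y_next.shape)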


class ContiguousAttention(nn.Module):
    """Minimal attention layer with a fixed-size contiguous KV cache.

    This simulates the layout used in traditional inference stacks, where each
    request gets a pre-reserved tensor block. The cache grows linearly and
    supports incremental decoding, assuming a single active sequence.

    Parameters
    ----------
    d_model : int
        Embedding dimension.
    n_heads : int
        Number of attention heads. Must divide ``d_model``.
    max_ctx : int, default 2048
        Number of tokens to pre-allocate space for. Acts as a hard cap on context length.
    device : str | torch.device, optional
        Device on which to place parameters and the cache.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        max_ctx: int = 2048,
        device: str | torch.device | None = None,
    ):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.max_ctx = max_ctx
        self.Wqkv = nn.Linear(d_model, 3 * d_model, bias=False, device=device)
        self.Wo = nn.Linear(d_model, d_model, bias=False, device=device)
        # Pre-reserve a single contiguous buffer for K and V.
        # Shape: [max_ctx, n_heads, head_dim]
        self.register_buffer(
            "k_cache",
            torch.empty(max_ctx, n_heads, self.head_dim, device=device),
        )
        self.register_buffer(
            "v_cache",
            torch.empty_like(self.k_cache),
        )
        self.cur_pos: int = 0

    def kv_bytes_allocated(self) -> int:
        """Bytes reserved for the KV cache, independent of how much has been written."""
        bytes_per_token = 2 * self.k_cache.element_size() * self.n_heads * self.head_dim
        return bytes_per_token * self.max_ctx

    def kv_bytes_used(self) -> int:
        """Bytes actually occupied by the tokens written so far."""
        bytes_per_token = 2 * self.k_cache.element_size() * self.n_heads * self.head_dim
        return bytes_per_token * self.cur_pos

    def reset_cache(self) -> None:
        self.cur_pos = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        assert B == 1, "Only batch size 1 supported."
        assert self.cur_pos + T <= self.max_ctx, "context length exceeds max_ctx"
        qkv = self.Wqkv(x)                                  # [1, T, 3 * d_model]
        q, k, v = qkv.chunk(3, dim=-1)

        def split_heads(t: torch.Tensor):
            return t.view(1, T, self.n_heads, self.head_dim).transpose(1, 2).contiguous()

        q = split_heads(q)                                  # [1, H, T, D]
        k = split_heads(k)[0]                               # [H, T, D]
        v = split_heads(v)[0]                               # [H, T, D]

        # Write the new tokens into the pre-reserved contiguous buffer.
        for t in range(T):
            self.k_cache[self.cur_pos] = k[:, t]            # [H, D]
            self.v_cache[self.cur_pos] = v[:, t]
            self.cur_pos += 1

        k_full = self.k_cache[:self.cur_pos].permute(1, 0, 2).unsqueeze(0)  # [1, H, L, D]
        v_full = self.v_cache[:self.cur_pos].permute(1, 0, 2).unsqueeze(0)

        # Build an explicit causal mask: query i sits at absolute position
        # cur_pos - T + i and may attend to every cached key at or before it.
        # (is_causal=True would align the mask to the top-left corner, which is
        # wrong once the cache is longer than the current chunk.)
        q_pos = torch.arange(self.cur_pos - T, self.cur_pos, device=x.device)
        k_pos = torch.arange(self.cur_pos, device=x.device)
        attn_mask = k_pos[None, :] <= q_pos[:, None]        # [T, L] bool
        ctx = F.scaled_dot_product_attention(q, k_full, v_full, attn_mask=attn_mask)
        return self.Wo(ctx.transpose(1, 2).reshape(1, T, self.d_model))
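

# Memory-accounting sketch (illustrative, not part of the original layer): it
# contrasts the fixed pre-reservation of ContiguousAttention with the bytes
# actually written, which is what kv_bytes_allocated() vs kv_bytes_used() is
# meant to expose. The numbers below assume the default fp32 buffers.
def _demo_contiguous_attention() -> None:
    attn = ContiguousAttention(d_model=128, n_heads=4, max_ctx=2048)
    _ = attn(torch.randn(1, 16, 128))              # write 16 tokens
    # Allocation is fixed at max_ctx tokens no matter how few were written:
    # 2 (K and V) * 4 bytes * 128 dims * 2048 tokens = 2 MiB reserved ...
    print(attn.kv_bytes_allocated())               # 2_097_152
    # ... versus 2 * 4 * 128 * 16 = 16 KiB actually used.
    print(attn.kv_bytes_used())                    # 16_384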