# -*- coding: utf-8 -*-
"""LLMInference.ipynb

Benchmarks a paged KV cache (ToyAttention) against a contiguous,
pre-reserved one (ContiguousAttention): per-token decode latency, peak
CUDA memory, and KV-cache bytes allocated vs. actually used, followed by
an incremental correctness check that the two produce matching outputs.

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1M7EssQUFrSYw_epjB0V4jxRTgkOfLNSO
"""
from __future__ import annotations

import time

import torch

from attention_layer import ToyAttention, ContiguousAttention

SEQ_LEN = 256     # tokens generated in the benchmark
D_MODEL = 128     # model width
N_HEADS = 4       # attention heads
MAX_CTX = 2048    # contiguous cache reservation (tokens)
BLOCK_SZ = 8      # paged cache block size (tokens)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
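
# Interface assumed of the two modules imported above (inferred from their
# use in this script; the actual definitions live in attention_layer.py):
#   * ToyAttention(d_model, n_heads, block_size): paged KV cache exposed as
#     .cache, with .reset(), .total_allocated_bytes(), .total_used_bytes().
#   * ContiguousAttention(d_model, n_heads, max_ctx): one pre-reserved
#     buffer, with .reset_cache(), .kv_bytes_allocated(), .kv_bytes_used().
# Both are nn.Module-style callables fed one (1, 1, D_MODEL) token at a time.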

def correctness_test() -> None:
    print("\n=== Incremental correctness test (512 tokens) ===")
    x = torch.randn(1, 512, D_MODEL, device=DEVICE)
    paged = ToyAttention(d_model=D_MODEL, n_heads=N_HEADS,
                         block_size=BLOCK_SZ).to(DEVICE)
    contig = ContiguousAttention(d_model=D_MODEL, n_heads=N_HEADS,
                                 max_ctx=MAX_CTX).to(DEVICE)
    # share weights for a fair numerical comparison
    contig.load_state_dict(paged.state_dict(), strict=False)
    paged.cache.reset()
    contig.reset_cache()
    out_paged, out_contig = [], []
    for t in range(x.shape[1]):  # feed the sequence token by token
        tok = x[:, t : t + 1]
        out_paged.append(paged(tok))
        out_contig.append(contig(tok))
    diff = (torch.cat(out_paged, 1) -
            torch.cat(out_contig, 1)).abs().max().item()
    print(f"max |Δ| = {diff:.2e}")
    assert diff < 1e-5, "outputs diverge"

def run_one(model, label: str, seq_len: int = SEQ_LEN) -> None:
    on_cuda = DEVICE.type == "cuda"
    if on_cuda:  # CUDA-only bookkeeping; skipped so the script also runs on CPU
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(DEVICE)
        torch.cuda.synchronize()  # flush pending work before starting the clock
    t0 = time.time()
    for _ in range(seq_len):
        tok = torch.randn(1, 1, D_MODEL, device=DEVICE)
        model(tok)
    if on_cuda:
        torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    latency = (time.time() - t0) * 1e3 / seq_len  # ms / token
    peak = torch.cuda.max_memory_allocated(DEVICE) if on_cuda else 0
    if label == "paged":
        kv_alloc = model.cache.total_allocated_bytes()
        kv_used = model.cache.total_used_bytes()
    else:
        kv_alloc = model.kv_bytes_allocated()
        kv_used = model.kv_bytes_used()
    waste = kv_alloc - kv_used
    waste_pct = 100 * waste / kv_alloc if kv_alloc else 0.0
    print(f"\n{label.upper():>10s} | {latency:6.2f} ms/tok")
    print(f"{'':>10s} | peak CUDA mem : {peak / 1e6:8.2f} MB")
    print(f"{'':>10s} | KV allocated  : {kv_alloc / 1e6:8.2f} MB")
    print(f"{'':>10s} | KV used       : {kv_used / 1e6:8.2f} MB")
    print(f"{'':>10s} | KV waste      : {waste / 1e6:8.2f} MB "
          f"({waste_pct:4.1f} %)")

def main() -> None:
    # build the paged model first and reuse its weights in the contiguous one
    paged = ToyAttention(d_model=D_MODEL, n_heads=N_HEADS,
                         block_size=BLOCK_SZ).to(DEVICE)
    paged.cache.reset()
    contig = ContiguousAttention(d_model=D_MODEL, n_heads=N_HEADS,
                                 max_ctx=MAX_CTX).to(DEVICE)
    contig.load_state_dict(paged.state_dict(), strict=False)
    contig.reset_cache()
    print(f"Running on {DEVICE}\n")
    run_one(paged, "paged")
    run_one(contig, "contiguous")
    correctness_test()
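
# ---------------------------------------------------------------------------
# Illustrative sketch only (hypothetical; not used above, and not the
# attention_layer implementation): the minimal block-based bookkeeping a
# paged KV cache needs for total_allocated_bytes() and total_used_bytes()
# to diverge. Storage grows in fixed blocks of `block_size` tokens, so
# "allocated" exceeds "used" by at most one partially filled block per
# sequence, analogous to the waste run_one() reports.
class SketchPagedKV:
    """Hypothetical per-sequence K/V store that allocates in fixed blocks."""

    def __init__(self, block_size: int, d_model: int,
                 dtype: torch.dtype = torch.float32) -> None:
        self.block_size = block_size
        self.d_model = d_model
        self.dtype = dtype
        self.blocks: list[torch.Tensor] = []  # each block: (block_size, d_model)
        self.used = 0                         # tokens written so far

    def append(self, kv_token: torch.Tensor) -> None:
        # Allocate a fresh block only when the current one is full.
        if self.used % self.block_size == 0:
            self.blocks.append(torch.empty(self.block_size, self.d_model,
                                           dtype=self.dtype))
        self.blocks[-1][self.used % self.block_size] = kv_token
        self.used += 1

    def total_allocated_bytes(self) -> int:
        return sum(b.numel() * b.element_size() for b in self.blocks)

    def total_used_bytes(self) -> int:
        elem = torch.empty((), dtype=self.dtype).element_size()
        return self.used * self.d_model * elem

# For example, after 10 appended tokens with block_size = 8, two blocks
# (16 slots) are allocated but only 10 are used: the same allocated-vs-used
# gap the benchmark prints as "KV waste".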

if __name__ == "__main__":
    main()