C++ transformer inference engine. Loads pretrained GPT-Neo weights, tokenizes with GPT-2 BPE, generates text autoregressively with KV-cached multi-head attention. No external dependencies — weights loaded via fread, math done with hand-written NEON intrinsics, memory managed with malloc/free.
Tested with TinyStories-33M (768d, 4 layers, 16 heads, 50257 vocab).
$ ./engine tinystories.bin --tokens "7454 2402 257 640" --tokenizer tokenizer.bin --temperature 0.8
Generated: , there was a little family lived in a big garden
in a house in a far far far, far, lived.. sunny world. [EOS]
124.53 tokens/sec | 8.03 ms/tok | Apple M2, single thread
Apple M2, 8 GB, single-threaded, -O2:
| Config | Params | tok/s | ms/tok |
|---|---|---|---|
| 128d / 2L / 4H | 460K | 17,627 | 0.06 |
| 256d / 4L / 8H | 3.28M | 3,156 | 0.32 |
| 512d / 6L / 8H | 26.7M | 519 | 1.93 |
| TinyStories-33M | 33M | 125 | 8.0 |
The small-model numbers are inflated because those weights fit entirely in the M2's L2 cache. The 33M number is the honest one: every generated token has to stream the full layer and lm_head weights from DRAM, so generation is memory-bandwidth-bound, which is typical for real inference.
g++ -O2 -std=c++17 -I include src/*.cpp -o engine

Tests:
g++ -O2 -std=c++17 -I include src/matmul.cpp src/utils.cpp src/transformer.cpp \
src/loader.cpp src/tokenizer.cpp tests/test_engine.cpp -o test_engine
./test_engine   # 49 tests

Requires transformers and torch for the one-time weight export. The engine itself has no Python dependency.
python3 scripts/export_tinystories.py
./engine models/tinystories.bin \
--tokens "7454 2402 257 640" \
--tokenizer models/tokenizer.bin \
--temperature 0.8 \
--gen_tokens 100

The --tokens flag takes pre-tokenized IDs from the GPT-2 tokenizer. This bypasses the C++ BPE encoder, which doesn't implement GPT-2's regex pre-tokenization and can produce slightly different splits. For demo accuracy, tokenize in Python first:
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("roneneldan/TinyStories-33M")
print(tok.encode("Once upon a time"))  # [7454, 2402, 257, 640]

Pre-norm GPT-2/Neo. Each token goes through:
tok_embed[token] + pos_embed[pos]
-> N x { LN -> MHA(Q,K,V, KV cache) + residual -> LN -> FFN(GELU) + residual }
-> LN -> lm_head -> logits -> argmax or temperature sampling
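The final sampling step above can be sketched as follows. This is not the engine's actual code; the function name and the caller-supplied random `coin` are illustrative. Temperature divides the logits before the softmax, and temperature = 0 is assumed to be handled by the caller as plain argmax:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Sketch: sample a token index from logits at a given temperature.
// `coin` is a uniform random number in [0, 1) supplied by the caller.
int sample_logits(const std::vector<float>& logits, float temperature,
                  float coin) {
    std::vector<float> p(logits.size());
    float maxv = logits[0];
    for (float l : logits) maxv = std::max(maxv, l);
    // softmax of logits / temperature, shifted by the max for stability
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i) {
        p[i] = std::exp((logits[i] - maxv) / temperature);
        sum += p[i];
    }
    // walk the CDF until the coin falls inside a bucket
    float cum = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i) {
        cum += p[i] / sum;
        if (coin < cum) return (int)i;
    }
    return (int)logits.size() - 1; // guard against float rounding
}
```

Lower temperatures sharpen the distribution toward argmax; higher ones flatten it, which is why the demo's `--temperature 0.8` output stays mostly coherent but still varies.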
Attention uses implicit causal masking — the inner loop only iterates [0..pos], so no explicit mask tensor is allocated.
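A minimal single-head sketch of that loop (names and layout illustrative, not the engine's actual code): the score and value loops only visit cached positions `0..pos`, so causality falls out of the loop bounds with no mask tensor:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Sketch: attention output for the token at position `pos`, one head.
// k_cache / v_cache hold one row per cached position.
std::vector<float> attend_causal(const std::vector<float>& q,        // [head_dim]
                                 const std::vector<float>& k_cache,  // [seq, head_dim]
                                 const std::vector<float>& v_cache,  // [seq, head_dim]
                                 int pos, int head_dim) {
    std::vector<float> scores(pos + 1);
    float scale = 1.0f / std::sqrt((float)head_dim);
    for (int t = 0; t <= pos; ++t) {                  // only [0..pos]: implicit mask
        float dot = 0.0f;
        for (int d = 0; d < head_dim; ++d)
            dot += q[d] * k_cache[t * head_dim + d];
        scores[t] = dot * scale;
    }
    // softmax over the visible prefix
    float maxv = scores[0];
    for (float s : scores) maxv = std::max(maxv, s);
    float sum = 0.0f;
    for (float& s : scores) { s = std::exp(s - maxv); sum += s; }
    for (float& s : scores) s /= sum;
    // weighted sum of cached values
    std::vector<float> out(head_dim, 0.0f);
    for (int t = 0; t <= pos; ++t)
        for (int d = 0; d < head_dim; ++d)
            out[d] += scores[t] * v_cache[t * head_dim + d];
    return out;
}
```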
KV cache layout is [n_layers, seq_len, dim] contiguous. Each new token writes one row via memcpy. This avoids the quadratic recompute that makes naive attention unusable for generation.
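The cache write can be sketched like this (function name illustrative): with a contiguous [n_layers, seq_len, dim] layout, the row for (layer, pos) is a single offset computation plus one memcpy:

```cpp
#include <cassert>
#include <cstring>
#include <vector>

// Sketch: append one token's K (or V) row into a contiguous
// [n_layers, seq_len, dim] float cache. Earlier rows are never touched
// again, which is what avoids the quadratic recompute.
void kv_cache_write(float* cache, const float* row,
                    int layer, int pos, int seq_len, int dim) {
    float* dst = cache + ((size_t)layer * seq_len + pos) * dim;
    std::memcpy(dst, row, dim * sizeof(float));
}
```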
Header (28 bytes, 7 x int32):
0x454E4731 magic
vocab_size 50257 for GPT-2
dim 768 for TinyStories-33M
hidden_dim 3072
n_layers 4
n_heads 16
seq_len 512
Body (float32, row-major, sequential):
token_embedding [vocab_size, dim]
pos_embedding [seq_len, dim]
per layer:
wq, wk, wv, wo [dim, dim] each
bo [dim]
w1, b1 [dim, hidden_dim], [hidden_dim]
w2, b2 [hidden_dim, dim], [dim]
ln1 gamma/beta [dim] each
ln2 gamma/beta [dim] each
ln_final gamma/beta [dim] each
lm_head [dim, vocab_size]
The v1 format (no magic, no pos embeddings, no biases) is auto-detected and still loads.
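The header above maps naturally onto a packed struct of seven int32s; a sketch of the layout and the magic check that distinguishes v2 from v1 (struct and function names are illustrative, not the loader's actual code):

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Sketch: the 28-byte v2 header, field order matching the spec above.
struct EngineHeader {
    int32_t magic;       // 0x454E4731
    int32_t vocab_size;  // 50257 for GPT-2
    int32_t dim;
    int32_t hidden_dim;
    int32_t n_layers;
    int32_t n_heads;
    int32_t seq_len;
};

// A v1 file has no magic, so the first word is vocab_size instead;
// failing this check would route the loader to the legacy path.
bool is_v2_header(const unsigned char* buf) {
    int32_t magic;
    std::memcpy(&magic, buf, sizeof(magic));
    return magic == 0x454E4731;
}
```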
- BPE encoder doesn't implement GPT-2's regex-based word splitting. Use --tokens for exact results.
- No quantization. Weights are float32. The 33M model is 404 MB on disk.
- Single-threaded. No batching.
- GELU uses the tanh approximation (gelu_new), not exact GELU.
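For reference, the gelu_new approximation is a one-liner. This is the standard tanh formula GPT-2/Neo use, sketched here rather than copied from the engine's source:

```cpp
#include <cassert>
#include <cmath>

// Tanh-based "gelu_new" approximation:
//   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
float gelu_new(float x) {
    const float c = 0.7978845608f; // sqrt(2/pi)
    return 0.5f * x * (1.0f + std::tanh(c * (x + 0.044715f * x * x * x)));
}
```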
MIT
