- Attention Is All You Need
- Linear attention
- Gumbel-Softmax + straight-through trick
- Clustered attention paper and sparse attention paper
Self-attention can be viewed as message passing on a fully connected graph G with self-edges.
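A minimal sketch of that view, assuming a single head and toy weights (not the repo's API): each node's output is a softmax-weighted aggregation of value "messages" from every other node and from itself.

```python
# Sketch: softmax self-attention as message passing on the complete graph with
# self-edges (single head; names and shapes are illustrative, not the repo's API).
import torch

def attention_as_message_passing(x, Wq, Wk, Wv):
    # x: (T, d) node features; every node receives a message from every node,
    # including itself, so the underlying graph is fully connected with self-edges.
    q, k, v = x @ Wq, x @ Wk, x @ Wv              # per-node queries/keys/values
    scores = q @ k.T / k.shape[-1] ** 0.5         # edge weights for all T*T edges
    alpha = torch.softmax(scores, dim=-1)         # normalize incoming messages per node
    return alpha @ v                              # aggregate value messages per node

T, d = 8, 16
x = torch.randn(T, d)
Wq, Wk, Wv = (torch.randn(d, d) * d ** -0.5 for _ in range(3))
out = attention_as_message_passing(x, Wq, Wk, Wv)   # (T, d)
```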
Non-causal two-level attention with supernodes, block design (see the sketch after the table):
| Step | Operation | Runtime |
|---|---|---|
| 1 | Partition nodes invariantly into cliques via 1D projection + sort | O(n log n) |
| 2 | Create a supernode in each clique | O(√n) |
| 3 | Each supernode attends to its clique | O(n) |
| 4 | Inter-clique (supernode) attention | O(n) |
| 5 | Intra-clique attention | O(n√n) |
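A minimal non-causal sketch of the block design above. The random 1D projection, mean-pooled supernode initialization, dense softmax helper, and the broadcast of each supernode into its clique's context in step 5 are illustrative assumptions, not the repo's implementation.

```python
# Sketch of the non-causal two-level supernode design (steps 1-5 in the table).
# Projection vector, mean-pooled supernodes, and plain softmax attention are
# illustrative choices, not the repo's implementation.
import math
import torch

def softmax_attn(q, k, v):
    w = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1]), dim=-1)
    return w @ v

def two_level_attention(x, proj):
    n, d = x.shape
    c = math.isqrt(n)                            # ~sqrt(n) cliques of ~sqrt(n) tokens
    # 1) invariant partition: 1D projection + sort, then chunk into cliques  O(n log n)
    order = torch.argsort(x @ proj)
    cliques = order.view(c, n // c)              # (num_cliques, clique_size)
    xc = x[cliques]                              # (c, m, d) clique members
    # 2) one supernode per clique (here: mean of its members)
    s = xc.mean(dim=1, keepdim=True)             # (c, 1, d)
    # 3) each supernode attends to its own clique                             O(n)
    s = softmax_attn(s, xc, xc)
    # 4) inter-clique attention among supernodes                              O(n)
    s = softmax_attn(s.squeeze(1), s.squeeze(1), s.squeeze(1)).unsqueeze(1)
    # 5) intra-clique attention, with the supernode broadcast into each
    #    clique's context so global information reaches every token           O(n sqrt(n))
    ctx = torch.cat([s, xc], dim=1)              # (c, m+1, d)
    out = softmax_attn(xc, ctx, ctx)             # (c, m, d)
    # scatter back to the original node order
    y = torch.empty_like(x)
    y[cliques.reshape(-1)] = out.reshape(n, -1)
    return y

n, d = 64, 32                                    # n is a perfect square in this toy example
y = two_level_attention(torch.randn(n, d), torch.randn(d))
```

With clique size ≈ √n, the steps match the runtimes in the table: √n cliques of √n tokens give O(n) supernode-clique attention and O(n√n) intra-clique attention.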
Small Transformer LM comparing attention mechanisms:
- MHA: standard causal softmax attention
- LinearAttention (LA): causal kernelized attention
- ClusterAttention (CA): block-based √T clustered attention
- LearnedClusterAttention (LCA): learned cluster assignments via Gumbel softmax (see the straight-through sketch after this list)
- ClusterKernelAttention (CKA): cluster-based attention with low-rank cluster mixing (T√T scaling, O(Ck) mixing via low-rank decomposition)
- SuperClusterAttention (SCA): variant of LCA that includes supernodes that mix and broadcast info to tokens before intra-cluster attention
- FastCKA: fast variant of ClusterKernelAttention (use for training)
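For LCA, cluster membership has to stay differentiable even though each token ends up assigned to a single cluster. A minimal sketch of that assignment with the Gumbel-softmax straight-through trick; the linear assignment head and cluster count are illustrative assumptions, not the repo's module.

```python
# Sketch: learned, differentiable cluster assignment via Gumbel-softmax with the
# straight-through trick (hard one-hot forward pass, soft gradients backward).
# The linear assignment head and cluster count are illustrative assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ClusterAssigner(nn.Module):
    def __init__(self, d_model, num_clusters):
        super().__init__()
        self.to_logits = nn.Linear(d_model, num_clusters)

    def forward(self, x, tau=1.0):
        # x: (B, T, d) token features -> (B, T, C) one-hot cluster assignments
        logits = self.to_logits(x)
        # hard=True returns one-hot samples in the forward pass while gradients
        # flow through the soft Gumbel-softmax probabilities (straight-through).
        return F.gumbel_softmax(logits, tau=tau, hard=True, dim=-1)

assigner = ClusterAssigner(d_model=64, num_clusters=8)
x = torch.randn(2, 128, 64)
onehot = assigner(x)                 # (2, 128, 8), each row is one-hot
tokens_per_cluster = onehot.sum(dim=1)
```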
| Mechanism | Compute vs T | Memory vs T | Notes |
|---|---|---|---|
| Multi-Head Attention | O(T²) | O(T²) | Exact softmax, most expressive |
| Linear Attention | O(T) | O(1) in T | Kernel prefix sums, single global summary (see sketch below) |
| Clustered Kernel Attn | O(T^(3/2)) | O(T^(3/2)) | Clusters + low-rank mixing, mid expressiveness |
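A minimal sketch of the Linear Attention row, assuming the common elu(x)+1 feature map (not necessarily the one used by this repo's LinearAttention): a running prefix sum over keys and values replaces the T×T score matrix, so compute is O(T) and the state size is independent of T.

```python
# Sketch: causal linear attention via prefix sums (O(T) compute, O(1)-in-T state).
# The elu(x)+1 feature map is a common choice assumed here for illustration.
import torch
import torch.nn.functional as F

def causal_linear_attention(q, k, v, eps=1e-6):
    # q, k, v: (T, d). phi keeps features positive so attention weights stay >= 0.
    phi_q, phi_k = F.elu(q) + 1, F.elu(k) + 1
    d_v = v.shape[-1]
    S = q.new_zeros(phi_k.shape[-1], d_v)   # running sum of phi(k_j) v_j^T
    z = q.new_zeros(phi_k.shape[-1])        # running sum of phi(k_j)
    out = []
    for t in range(q.shape[0]):             # single pass; state independent of T
        S = S + phi_k[t].unsqueeze(-1) * v[t].unsqueeze(0)   # rank-1 update
        z = z + phi_k[t]
        num = phi_q[t] @ S                  # (d_v,)
        den = phi_q[t] @ z + eps
        out.append(num / den)
    return torch.stack(out)

T, d = 16, 8
q, k, v = torch.randn(T, d), torch.randn(T, d), torch.randn(T, d)
y = causal_linear_attention(q, k, v)        # (T, d)
```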
- Dataset: enwik8 (first 100M bytes of Wikipedia). Download into the project directory:

  ```
  curl -O https://data.deepai.org/enwik8.zip
  unzip enwik8.zip
  ```

- Run `python train.py`. This runs all three variants (MHA, LinearAttention, ClusterAttention) and reports bits-per-byte.
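Bits-per-byte is the mean next-byte cross-entropy converted from nats to bits (divide by ln 2). A minimal sketch of that conversion; shapes and names are illustrative, not train.py's exact reporting code.

```python
# Sketch: converting mean cross-entropy (nats per byte) to bits-per-byte for a
# byte-level LM. Illustrative only; train.py's exact reporting code may differ.
import math
import torch
import torch.nn.functional as F

def bits_per_byte(logits, targets):
    # logits: (B, T, 256) next-byte predictions, targets: (B, T) byte values
    nats = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
    return nats / math.log(2)

logits = torch.randn(2, 32, 256)
targets = torch.randint(0, 256, (2, 32))
print(float(bits_per_byte(logits, targets)))
```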
1. Upload enwik8 into the Modal volume:

   ```
   modal run modal_app.py::upload_enwik8_from_local
   ```

2. Deploy the training endpoint:

   ```
   modal deploy modal_app.py
   ```

   Open the URL shown in the deploy output and click the run_training web endpoint.
Repository layout:
- `transformer/` – Transformer layers and language model wrapper
- `variants/` – attention implementations (MHA, LinearAttention, ClusterAttention, LearnedClusterAttention, ClusterKernelAttention, FastCKA)
- `train.py` – training + evaluation loop over enwik8
- `modal_app.py` – Modal endpoints for remote training