A simple, clean, readable, and shape-annotated implementation of Attention Is All You Need in PyTorch. A sample ONNX file can be found in assets/transformer.onnx for visualization purposes.
It was tested on synthetic data; try to use the attention plots to figure out the transformation used to create the data!
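For context, an ONNX file like the one in assets/ can be produced with `torch.onnx.export`. The snippet below is a minimal sketch using `torch.nn.Transformer` as a stand-in model; the model, its constructor arguments, the dummy input sizes, and the output path are illustrative assumptions, not this repository's exact API or export script.

```python
import torch
import torch.nn as nn

# Stand-in model: nn.Transformer is used here only to demonstrate the export call;
# the file in assets/transformer.onnx was presumably produced from this repo's own model.
model = nn.Transformer(d_model=64, nhead=4, num_encoder_layers=2,
                       num_decoder_layers=2, batch_first=True)
model.eval()

# Dummy batch_first inputs: (batch, seq_len, d_model)
src = torch.randn(1, 10, 64)
tgt = torch.randn(1, 10, 64)

torch.onnx.export(
    model,
    (src, tgt),
    "transformer.onnx",          # illustrative output path
    input_names=["src", "tgt"],
    output_names=["output"],
)
```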
- Positional embeddings are not included, similar to `nn.Transformer`, but you can find an implementation in `usage.ipynb` (see the sketch after this list).
- The parallel `MultiHeadAttention` outperforms the for-loop implementation significantly, as expected (a simplified sketch follows the list).
- Assumes `batch_first=True` input by default; this can't be changed.
- Uses `einsum` for the attention computation rather than `bmm` for readability, which might impact performance.
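Since positional embeddings are left out (mirroring `nn.Transformer`), here is a minimal sketch of the sinusoidal encoding from the paper. The class name and the assumption of `batch_first` inputs with an even `d_model` are illustrative choices, not necessarily the exact code in `usage.ipynb`.

```python
import math
import torch
import torch.nn as nn

class SinusoidalPositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (Vaswani et al., 2017).

    Assumes batch_first inputs of shape (batch, seq_len, d_model) and an even d_model.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)                  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)                   # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)                   # odd dimensions
        self.register_buffer("pe", pe)                                 # (max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model) -> add the encoding for the first seq_len positions
        return x + self.pe[: x.size(1)]
```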
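To illustrate the parallel, `einsum`-based attention described above, the following is a simplified, shape-annotated sketch: all heads are projected in one batched linear layer and the attention scores are computed with `torch.einsum` instead of `bmm`. It illustrates the idea rather than reproducing this repository's exact `MultiHeadAttention` (no masking or dropout, for example).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ParallelMultiHeadAttention(nn.Module):
    """All heads computed at once; attention scores via einsum for readability.

    Shapes follow batch_first=True: inputs are (batch, seq_len, d_model).
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)   # one projection for Q, K, V
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, _ = x.shape
        # (b, s, 3, h, d) -> three tensors of shape (b, h, s, d)
        qkv = self.qkv_proj(x).view(b, s, 3, self.num_heads, self.d_head)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)

        # Scaled dot-product scores: (b, h, s_query, s_key)
        scores = torch.einsum("bhqd,bhkd->bhqk", q, k) / self.d_head ** 0.5
        weights = F.softmax(scores, dim=-1)

        # Weighted sum of values: (b, h, s, d) -> (b, s, h * d)
        out = torch.einsum("bhqk,bhkd->bhqd", weights, v)
        out = out.permute(0, 2, 1, 3).reshape(b, s, -1)
        return self.out_proj(out)
```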