A high-performance JAX/Flax implementation of NanoGPT optimized for TPU training.
This project reimplements Andrej Karpathy's NanoGPT in JAX, focusing on performance and scalability. It leverages JAX's automatic differentiation and compilation along with Flax's neural network modules to provide an efficient, maintainable codebase that trains distributed across TPU cores.
- 🚀 Full JAX/Flax implementation optimized for TPUs
- 📈 Distributed training with @pmap
- 🔄 Gradient accumulation for larger effective batch sizes
- 📊 Integrated Weights & Biases logging
- 💾 Support for inference straight from pretrained weights
- 🎯 Cosine learning rate schedule with warmup
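As a point of reference, here is a minimal Optax sketch of how a warmup-plus-cosine-decay schedule can be combined with AdamW; the concrete numbers (peak learning rate, warmup/decay steps, weight decay) are illustrative placeholders rather than the settings used in train.py:

```python
import optax

# Linear warmup followed by cosine decay; all values below are placeholders.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,       # start from zero during warmup
    peak_value=6e-4,      # peak learning rate (illustrative)
    warmup_steps=2_000,
    decay_steps=300_000,
    end_value=6e-5,       # learning rate at the end of the decay
)

# AdamW driven by the schedule; global-norm clipping is a common companion.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule, weight_decay=0.1),
)
```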
We reached a validation loss of 3.17 after 270k steps, at which point the model had converged. Training took roughly 18 hours on a TPU v3-8.
During training, the TPU reported an average duty cycle of 77%, indicating good accelerator utilization.
- Clone the repository:

git clone https://github.com/plippmann/nanogpt-jax && cd nanogpt-jax

- Set up the environment:

wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash Miniconda3-latest-Linux-x86_64.sh
source ~/miniconda3/bin/activate
conda create -n myenv python=3.10
conda activate myenv

- Install JAX (TPU version):

pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

- Install dependencies:

pip install -r requirements.txt

nanogpt-jax/
├── nanogpt/
│   ├── model.py        # Core GPT-2 implementation
│   ├── train.py        # Training loop and configuration
│   ├── inference.py    # Text generation utilities
│   ├── pretrained.py   # Load pretrained weights
│   └── tests.py        # Sanity checks for model.py
├── data/
│   ├── openwebtext/
│   │   └── prepare.py  # Get data
│   └── shakespeare/
│       └── prepare.py  # Get data
└── requirements.txt
- Implements the GPT-2 architecture using Flax's neural network modules
- Supports configurable model sizes
- Provides both a hand-rolled causal self-attention implementation and Flax's built-in attention module
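For illustration, a minimal hand-rolled causal self-attention block in Flax Linen might look like the following; the class and field names are ours and may not match model.py exactly:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class CausalSelfAttention(nn.Module):
    """Sketch of multi-head causal self-attention (illustrative, not model.py)."""
    n_head: int
    n_embd: int

    @nn.compact
    def __call__(self, x):                      # x: [batch, seq_len, n_embd]
        B, T, C = x.shape
        head_dim = self.n_embd // self.n_head
        # Single projection to queries, keys and values.
        q, k, v = jnp.split(nn.Dense(3 * self.n_embd)(x), 3, axis=-1)
        split = lambda t: t.reshape(B, T, self.n_head, head_dim).transpose(0, 2, 1, 3)
        q, k, v = split(q), split(k), split(v)
        # Scaled dot-product attention with a lower-triangular (causal) mask.
        att = (q @ k.transpose(0, 1, 3, 2)) / jnp.sqrt(head_dim)
        mask = jnp.tril(jnp.ones((T, T), dtype=bool))
        att = jax.nn.softmax(jnp.where(mask, att, -jnp.inf), axis=-1)
        y = (att @ v).transpose(0, 2, 1, 3).reshape(B, T, C)
        return nn.Dense(self.n_embd)(y)         # output projection
```

Flax's built-in `nn.SelfAttention` together with `nn.make_causal_mask` can serve as the drop-in alternative mentioned above.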
- Distributed training across TPU cores using @pmap (see the sketch after this list)
- Gradient accumulation for higher effective batch sizes
- Learning rate scheduling with warmup and cosine decay
- AdamW optimizer for now
- Integrated W&B logging for training metrics
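To make the training-side bullets concrete, here is a condensed sketch of a pmap-ed train step that accumulates gradients over micro-batches with `lax.scan` and averages them across TPU cores with `lax.pmean`. It assumes `state` is a `flax.training.train_state.TrainState` replicated across devices and that `loss_fn(params, tokens)` is defined from the model elsewhere; names and shapes are illustrative rather than the exact signatures in train.py:

```python
from functools import partial
import jax
import jax.numpy as jnp

# Assumed elsewhere: loss_fn(params, tokens) -> scalar language-modeling loss.

@partial(jax.pmap, axis_name="devices")
def train_step(state, micro_batches):
    """micro_batches: int32 tokens of shape [accum_steps, per_device_batch, seq_len]."""

    def accumulate(carry, micro_batch):
        grads_acc, loss_acc = carry
        loss, grads = jax.value_and_grad(loss_fn)(state.params, micro_batch)
        grads_acc = jax.tree_util.tree_map(jnp.add, grads_acc, grads)
        return (grads_acc, loss_acc + loss), None

    zero_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params)
    init = (zero_grads, jnp.zeros((), dtype=jnp.float32))
    (grads, loss), _ = jax.lax.scan(accumulate, init, micro_batches)

    # Average over accumulation steps, then across TPU cores.
    accum_steps = micro_batches.shape[0]
    grads = jax.tree_util.tree_map(lambda g: g / accum_steps, grads)
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss / accum_steps, axis_name="devices")

    return state.apply_gradients(grads=grads), loss
```

In this pattern the train state is replicated across devices (e.g. with `flax.jax_utils.replicate`) and each global batch is shaped as `[num_devices, accum_steps, per_device_batch, seq_len]` before the call.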
- Implement the model in JAX
- Write tests
- Load pretrained weights
- Perform inference from pretrained weights
- Train the model on TPUs
- Make it fast with @pmap/@jit
- Run inference on the trained model
- Post-training fun
  - Implement RoPE, Muon optimizer, and other improvements
- Prepare training data:
python data/openwebtext/prepare.py

Upload the resulting train.bin and val.bin to your GCP storage bucket.
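If prepare.py follows the original nanoGPT and writes the tokens as raw uint16 GPT-2 BPE ids, training batches can be sampled directly from the memory-mapped files. The sketch below rests on that assumption (adjust the dtype and paths if this repository's prepare.py differs):

```python
import numpy as np

# Assumption: prepare.py wrote flat arrays of uint16 token ids, as in the
# original nanoGPT. Adjust dtype/path if this repo's prepare.py differs.
data = np.memmap("data/openwebtext/train.bin", dtype=np.uint16, mode="r")

def get_batch(batch_size: int, block_size: int, rng: np.random.Generator):
    """Sample (input, target) windows of length block_size."""
    starts = rng.integers(0, len(data) - block_size - 1, size=batch_size)
    x = np.stack([data[s : s + block_size] for s in starts]).astype(np.int32)
    y = np.stack([data[s + 1 : s + 1 + block_size] for s in starts]).astype(np.int32)
    return x, y

x, y = get_batch(batch_size=8, block_size=1024, rng=np.random.default_rng(0))
```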
- Create a TPU VM:
ZONE=europe-west4-a
TPU_TYPE=v3-8
VM_NAME=jax-gpt-v3-8
gcloud alpha compute tpus tpu-vm create $VM_NAME \
--zone=$ZONE \
--accelerator-type=$TPU_TYPE \
--version=tpu-ubuntu2204-base \
--preemptible

- SSH into the VM:

gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone=$ZONE

- Start training:

python train.py

- Generate text from best checkpoint:

python inference.py --init_from resume --checkpoint_type best

- Clean up:

gcloud alpha compute tpus tpu-vm delete $VM_NAME --zone=$ZONE

- NanoGPT by Karpathy
- gpt-jax by Penn Jenks
- PyTorch to JAX blog by Douglas Jia
Feel free to reach out at p.lippmann@tudelft.nl.
