This codebase is currently being refactored for GitHub and has been used for submissions to ACL, EMNLP, and NeurIPS 2025.
Quick Llama provides an optimized implementation targeting Llama 3.2 1B. It is designed for efficient gradient-checkpoint-free training on 80GB GPUs.
- Optimized Llama 3.2 1B: A Torch-based implementation with some Triton kernels and FlexAttention.
- Efficient Training: Enables gradient-checkpoint-free training to reduce compute: instruction-tune on 1 million training examples within two hours on 4x A100 80GB GPUs with a sequence length of 4096 and a global batch size of 32.
- Distributed Training: Built-in support for distributed data parallelism via HuggingFace Accelerate, defaulting to FP16 communication and overlapping backward-pass computation with communication.
- Efficient Attention Implementation: Leverages FlexAttention with efficient block-sparsity masks and proper sequence packing, which lets us skip computation entirely for tokens we would otherwise mask out, such as padding tokens or tokens from other packed samples (a minimal sketch follows this feature list).
- Optimized Vocab. Head and RoPE: Includes cached Triton-based Rotary Positional Encodings (RoPE) and efficient cross-entropy implementations (adapted from Apple and Unsloth.AI, with some tweaks to fix negative-loss issues); a plain-PyTorch sketch of the cached RoPE tables appears after the acknowledgements below.
- Flexible Data Handling: Included data pre-processor automatically adds necessary auxiliary data for efficient FlexAttention, enabling block-sparsity masks and sequence packing. The processor also allows for proper sequence packing for SSMs and prefix-LMs.
- Utilities Included:
  - Automatic download and conversion of Llama 3.2 1B's weights from HuggingFace to our Torch format.
  - Example training script with configuration management, logging (TensorBoard support via Accelerate), and checkpointing.
  - Basic command-line interface for conversing with the model.
  - Evaluation using EleutherAI's LM Harness.
- Performance: Compatible with `torch.compile` for optimal performance.
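To make the block-sparsity and sequence-packing points above concrete, here is a minimal, self-contained sketch using PyTorch's FlexAttention API (PyTorch >= 2.5, GPU recommended). The `segment_ids` tensor stands in for the auxiliary data the pre-processor attaches to each batch; the names and shapes are illustrative and not this repository's actual interface.

```python
# Illustrative only -- shows the idea behind packed-sequence block-sparse attention,
# not this repository's implementation.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_len, n_heads, head_dim = 256, 8, 64

# Two packed documents: tokens 0-127 belong to sample 0, tokens 128-255 to sample 1.
segment_ids = torch.arange(seq_len, device=device) // 128

def packed_causal_mask(b, h, q_idx, kv_idx):
    # Attend causally, and only to tokens from the same packed sample.
    same_sample = segment_ids[q_idx] == segment_ids[kv_idx]
    return same_sample & (q_idx >= kv_idx)

# create_block_mask records which (query, key) blocks are fully masked, so
# flex_attention can skip them entirely (e.g., cross-sample or padding blocks).
block_mask = create_block_mask(
    packed_causal_mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device
)

q = torch.randn(1, n_heads, seq_len, head_dim, device=device)
k = torch.randn(1, n_heads, seq_len, head_dim, device=device)
v = torch.randn(1, n_heads, seq_len, head_dim, device=device)
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 8, 256, 64])
```

Because blocks containing only masked query/key pairs are skipped outright, padding and cross-sample tokens cost no compute, which is what makes sequence packing efficient here.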
Acknowledgements: this project builds on code and ideas from:
- Splash Attention
- FlexAttention
- Apple's Cross-Entropy Loss
- Unsloth.AI (RoPE Triton Kernels and Cross-Entropy Loss)
- Meta's Llama Reference Implementation (RoPE)
- HuggingFace's Llama 3.2 Implementation
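For background on the RoPE pieces credited above, here is a plain-PyTorch sketch of caching the rotary cos/sin tables once and reusing them each step. The repository's Triton kernels implement the same idea more efficiently; the base frequency and the interleaved-pair convention below are illustrative assumptions.

```python
# Illustrative only -- the cached Triton RoPE kernels in this repo differ in detail.
import torch

def build_rope_cache(seq_len: int, head_dim: int, base: float = 500000.0):
    # Built once at setup and cached, instead of being recomputed every forward pass.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(seq_len).float()
    freqs = torch.outer(positions, inv_freq)      # (seq_len, head_dim // 2)
    return freqs.cos(), freqs.sin()

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: (batch, heads, seq_len, head_dim); rotate interleaved channel pairs.
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = cos[None, None, :, :], sin[None, None, :, :]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

cos, sin = build_rope_cache(seq_len=4096, head_dim=64)
q = torch.randn(1, 8, 4096, 64)
q_rotated = apply_rope(q, cos, sin)
```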
Tested configuration:
- Hardware: 4x A100 80GB GPUs
- Parameters: sequence lengths >= 4096, batch sizes 32-64.
Requirements:
- Python >= 3.8
- PyTorch >= 2.5
- Triton >= 3.1
- Hugging Face libraries: `transformers`, `accelerate`, `datasets`, `huggingface-hub`
- Other libraries: `safetensors`, `numpy`, `tqdm`
Hugging Face Token:
You must set the `HF_TOKEN` environment variable to a HuggingFace token that has been granted access to Meta's Llama 3.2 repository in order to download the weights.
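The example scripts handle this download and the conversion automatically. Purely as an illustration, the download step can be performed with `huggingface_hub` roughly as follows; the conversion to this repository's Torch format is a separate, repo-specific step.

```python
# Illustrative sketch only -- the repository's downloader/converter may differ.
import os
from huggingface_hub import snapshot_download

# Uses the HF_TOKEN environment variable described above; the token must have
# been granted access to the gated Llama 3.2 repository.
local_dir = snapshot_download(
    repo_id="meta-llama/Llama-3.2-1B",
    token=os.environ["HF_TOKEN"],
)
print(f"Weights downloaded to: {local_dir}")
```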
```bash
export HF_TOKEN="your_hf_token_here"
python quick_llama/minimal_training_example.py

## Alternatively:
HF_TOKEN="your_hf_token_here" python quick_llama/minimal_training_example.py

NUM_GPUS=4
HF_TOKEN="your_hf_token_here" accelerate launch --num_processes ${NUM_GPUS} quick_llama/minimal_training_example.py
```

It is highly recommended to install this in a virtual environment (e.g., using venv or conda).
```bash
# Create and activate a virtual environment
python -m venv .venv
source .venv/bin/activate
```

```bash
## Install dependencies
## Choose the correct version based on your CUDA version
### You can see your CUDA version by running `nvcc --version` or `nvidia-smi`

### CUDA 11.8
pip install --upgrade --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118

### CUDA 12.6
pip install --upgrade --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126

### CUDA 12.8
pip install --upgrade --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
```

Install directly from GitHub:

```bash
pip install --upgrade git+https://www.github.com/samblouir/quick_llama
```

Or clone the repository and install it in editable mode:

```bash
git clone https://github.com/samblouir/quick_llama
cd quick_llama
# Install the base package in editable mode
pip install -e .
```

If you are having issues with flash-attn (one reason this can happen is if you update Torch):

```bash
pip uninstall flash-attn -y
pip install flash-attn --no-build-isolation --force-reinstall --no-deps
```

If you get an error like this:

```
RuntimeError: Failed to import transformers.models.llama.modeling_llama because of the following error (look up to see its traceback):
operator torchvision::nms does not exist
```

Try running this:

```bash
pip install --upgrade transformers
```

This project uses HuggingFace Accelerate for handling distributed training and device placement.
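For orientation, below is a minimal, hypothetical Accelerate loop of the kind the training scripts build on; `model`, `optimizer`, and `train_loader` are placeholders, the bf16 and gradient-accumulation settings simply mirror the example arguments shown later, and the repository's actual loop lives in `train.py`.

```python
# Minimal illustrative Accelerate loop -- not this repository's training code.
from accelerate import Accelerator

def train(model, optimizer, train_loader, num_epochs=1):
    accelerator = Accelerator(
        mixed_precision="bf16",
        gradient_accumulation_steps=4,
        log_with="tensorboard",
        project_dir="./training_logs",
    )
    accelerator.init_trackers("quick_llama_example")
    # Wraps the model for DDP and moves everything to the right devices.
    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

    step = 0
    for _ in range(num_epochs):
        for batch in train_loader:
            with accelerator.accumulate(model):
                # Placeholder: assumes the model returns an object with a .loss field.
                loss = model(**batch).loss
                accelerator.backward(loss)  # handles mixed precision and gradient sync
                optimizer.step()
                optimizer.zero_grad()
            accelerator.log({"train/loss": loss.item()}, step=step)
            step += 1
    accelerator.end_training()
```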
```bash
# Minimal Example
# This will download the model weights and run a minimal training example.
# It also prints out a batch produced by the data processor, which adds the auxiliary information needed for FlexAttention's sparsity.
cd quick_llama/src/quick_llama
python minimal_training_example.py
```

An example SLURM script is provided for running on a cluster.
You should modify the `sbatch_quick_llama.sh` script to suit your cluster, especially the partition and qos settings.

```bash
sbatch sbatch_quick_llama.sh
```

```bash
# 1 GPU on a local machine
cd quick_llama/src/quick_llama
python train_wrapper.py \
--dataset_name "teknium/OpenHermes-2.5" \
--batch_size_per_device 32 \
--gradient_accumulation_steps 4 \
--num_epochs 1 \
--lr 5e-5 \
--output_dir "./training_logs" \
--steps_between_evals 200
```

```bash
# 4 GPUs on a local machine
cd quick_llama/src/quick_llama
accelerate launch --num_processes 4 train_wrapper.py \
--dataset_name "teknium/OpenHermes-2.5" \
--batch_size_per_device 8 \
--gradient_accumulation_steps 4 \
--num_epochs 1 \
--lr 5e-5 \
--output_dir "./training_logs" \
--steps_between_evals 200
```

```bash
# Multi-node setup: 2 machines with 4 GPUs each
cd quick_llama/src/quick_llama
accelerate launch --num_processes 8 --num_machines 2 minimal_training_example.py \
--dataset_name "teknium/OpenHermes-2.5" \
--batch_size_per_device 8 \
--gradient_accumulation_steps 4 \
--num_epochs 1 \
--lr 5e-5 \
--output_dir "./training_logs" \
--steps_between_evals 200
```

Training runs are configured via command-line arguments. Key arguments (defined in `config.py`) include:

- `--model_name`: HF model identifier (default: `meta-llama/Llama-3.2-1B`).
- `--dataset_name`: HF dataset identifier (default: `teknium/OpenHermes-2.5`).
- `--output_dir`: Directory to save logs and checkpoints (default: `logs/`).
- `--num_epochs`: Number of training epochs.
- `--batch_size_per_device`: Batch size for each GPU.
- `--lr`: Learning rate.
- `--steps_between_evals`: How often to run validation.
- `--mixed_precision`: `bf16` (default), `fp16`, or `no`.
- ... and others (see `config.py` or run `python train.py --help`).
Configuration files (`config.json`), logs, TensorBoard data, and checkpoints will be saved under `{output_dir}/{config_hash}/`.
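As an illustration of how such a run directory can be keyed by a hash of the configuration, here is a hypothetical sketch; the actual hashing scheme in `config.py` may differ.

```python
# Hypothetical sketch -- get_config() in config.py may compute its hash differently.
import hashlib
import json
import os

config = {"model_name": "meta-llama/Llama-3.2-1B", "lr": 5e-5, "batch_size_per_device": 8}
config_hash = hashlib.md5(json.dumps(config, sort_keys=True).encode()).hexdigest()[:8]

run_dir = os.path.join("./training_logs", config_hash)
os.makedirs(run_dir, exist_ok=True)
with open(os.path.join(run_dir, "config.json"), "w") as f:
    json.dump(config, f, indent=2)
print(run_dir)  # e.g. ./training_logs/3f2a9c1d (hash value is illustrative)
```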
Use `accelerate launch` to start a training run. Accelerate will automatically handle multi-GPU/multi-node setups based on its configuration.
- Configure Accelerate (if needed): Run `accelerate config` and follow the prompts to set up your distributed environment (number of GPUs, etc.).
- Launch Training: Adjust arguments as needed.

  ```bash
  accelerate launch train.py \
    --dataset_name "teknium/OpenHermes-2.5" \
    --batch_size_per_device 8 \
    --gradient_accumulation_steps 4 \
    --num_epochs 1 \
    --lr 5e-5 \
    --output_dir "./training_logs" \
    --steps_between_evals 200
    # Add other arguments from config.py's parse_arguments() function as needed
  ```

`config.py`'s `get_config()` automatically sets `log_dir` and `checkpoint_dir`.
- Logs will be printed to the console and the project directory.
- Checkpoints are saved periodically in the `output_dir`.
- TensorBoard logs can be viewed by running `tensorboard --logdir ./training_logs`.
```bash
# Nuclear option for flash-attn issues:
## Re-installs flash-attn and dependencies
## Warning: This will change the versions of some packages
pip uninstall flash-attn -y
pip install flash-attn --no-build-isolation --force-reinstall
```

Documentation for the evaluation module is coming soon. It lets you run the EleutherAI LM Harness without needing to re-export the model to the HuggingFace format and run `lm_eval` separately.
Planned work:
- Merge in support for Birdie's reward-driven mixture-of-denoisers for second-stage pretraining.
- Merge in support for fine-tuning Llama Vision 11B.
- Merge in support for finetuning LoRA adapters.
- Merge in support for exporting back to vLLM- and HuggingFace-compatible formats.
- Finish sequence parallelism implementation.
- Add optional Gradient Checkpointing for larger models/settings.
- Add detailed instructions for LM Harness evaluation and CLI inference.
Contributions are welcome! Please feel free to open an issue to report bugs or suggest features. Pull requests are also appreciated.
License: Apache 2.0
