This repository provides a clear and minimal implementation for training a Gemma-like language model from scratch using PyTorch. The project is structured to be easily understandable, with a clear separation between the core model logic and the training/data preparation scripts.
The implementation is heavily inspired by and builds upon the foundational work from Andrej Karpathy's nanoGPT.
This repository supports two primary use cases:
- **Inference with Official Gemma Weights**: The `inference_google_gemma.py` script uses the official Gemma tokenizer and loads the pre-trained 270M model from Hugging Face. This demonstrates the architectural compatibility of our implementation.
- **Training a Model From Scratch**: The `prepare_dataset.py` and `train.py` scripts allow you to train a model on your own data (e.g., TinyStories). For this workflow, we use the simpler and faster GPT-2 tokenizer (`tiktoken`). This is a self-contained training pipeline, and the resulting model should be used with `inference_custom.py`, which also uses the GPT-2 tokenizer.
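For a quick feel of the GPT-2 tokenizer used by the from-scratch pipeline, here is a minimal round trip with `tiktoken` (the sample text is illustrative only):

```python
import tiktoken

enc = tiktoken.get_encoding("gpt2")  # the GPT-2 BPE tokenizer
ids = enc.encode("Once upon a time")
print(ids)              # a short list of integer token ids
print(enc.decode(ids))  # -> "Once upon a time"
```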
The Gemma 3 model is a decoder-only transformer, similar to other popular language models. It is composed of the following key components:
- Token Embeddings: The model starts by converting input tokens into dense vectors of a specified dimension (`emb_dim`).
- Transformer Blocks: The core of the model is a series of transformer blocks (`n_layers`), each containing:
  - Attention Mechanism: Gemma 3 uses a sophisticated attention mechanism that combines global and sliding-window attention. This allows the model to handle long contexts efficiently while still capturing global dependencies.
  - Feed-Forward Network: A position-wise feed-forward network with a SwiGLU activation function.
  - Layer Normalization: RMSNorm is used for layer normalization.
  - Rotary Positional Embeddings (RoPE): RoPE is used to inject positional information into the model.
- Output Head: A final linear layer that projects the output of the transformer blocks back to the vocabulary space to produce logits.
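To make two of these components concrete, here is a minimal PyTorch sketch of RMSNorm and a SwiGLU feed-forward block. It is illustrative only; the repository's actual implementations live in `normalization.py` and `layers.py` and may differ in detail.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Root Mean Square normalization: rescale by the RMS of the features."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the root mean square instead of mean/variance (LayerNorm).
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight

class SwiGLUFeedForward(nn.Module):
    """Position-wise FFN with a SwiGLU gate: W2(silu(W1 x) * W3 x)."""
    def __init__(self, emb_dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(emb_dim, hidden_dim, bias=False)  # gate projection
        self.w3 = nn.Linear(emb_dim, hidden_dim, bias=False)  # up projection
        self.w2 = nn.Linear(hidden_dim, emb_dim, bias=False)  # down projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
```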
This implementation supports both a custom configuration and the Gemma 3 270M configuration, as defined in the `config.py` file.
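As an illustration, a configuration for this kind of model might look like the dictionary below. Apart from `emb_dim` and `n_layers`, the field names and all values here are assumptions for the sketch, not the actual contents of `config.py`.

```python
# Illustrative configuration only; see config.py for the real definitions.
CUSTOM_CONFIG = {
    "vocab_size": 50257,     # GPT-2 tokenizer vocabulary size (tiktoken)
    "emb_dim": 512,          # token embedding dimension
    "n_layers": 8,           # number of transformer blocks
    "n_heads": 8,            # attention heads per block (assumed field)
    "context_length": 1024,  # maximum sequence length (assumed field)
    "sliding_window": 256,   # window for sliding-window layers (assumed field)
}
```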
The model takes a sequence of tokens as input and processes them through the embedding layer. The resulting embeddings are then passed through a series of transformer blocks. Each block applies self-attention and a feed-forward network to the input. The attention mechanism is masked to prevent tokens from attending to future tokens, and it can be either global or sliding-window, depending on the layer's configuration.
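The two mask variants can be built in a few lines of PyTorch. This sketch (function names are illustrative) returns boolean matrices where `True` marks positions a token may attend to:

```python
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # Global attention: each token sees itself and every earlier token.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

def sliding_window_mask(seq_len: int, window: int) -> torch.Tensor:
    # Sliding-window attention: still causal, but each token only sees
    # the `window` most recent positions (i - window, i].
    too_old = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool),
                         diagonal=-window)
    return causal_mask(seq_len) & ~too_old
```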
The final output of the transformer blocks is then passed through a normalization layer and a linear layer to produce the logits for the next token in the sequence.
During inference, the model generates text autoregressively. It takes a starting sequence of tokens, predicts the next token, appends it to the sequence, and repeats the process until a specified number of new tokens have been generated.
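A minimal greedy-decoding sketch of that loop is shown below; it assumes `model` returns logits of shape `(batch, seq_len, vocab_size)`, which may differ from the repository's exact API.

```python
import torch

@torch.no_grad()
def generate(model, tokens: torch.Tensor, max_new_tokens: int,
             context_length: int) -> torch.Tensor:
    """Greedy autoregressive decoding: predict, append, repeat."""
    model.eval()
    for _ in range(max_new_tokens):
        context = tokens[:, -context_length:]  # crop to the model's window
        logits = model(context)                # (batch, seq_len, vocab_size)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokens
```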
For better organization and modularity, the project files have been arranged as follows. This separates the user-facing scripts from the core, importable Python package.
```
gemma_from_scratch/
├── README.md
├── requirements.txt
│
├── prepare_dataset.py           # User-facing script to download and process data
├── train.py                     # User-facing script to train the model
├── inference_custom.py          # Test the inference capabilities of your custom model
├── inference_google_gemma.py    # Test the inference capabilities of the original Gemma 270M
│
├── gemma_scratch/               # Core source code as a Python package
│   ├── __init__.py              # Makes 'gemma_scratch' a package
│   ├── config.py                # Model hyperparameters
│   ├── layers.py                # The TransformerBlock implementation
│   ├── model.py                 # The Gemma model definition
│   ├── normalization.py         # RMSNorm implementation
│   ├── rope.py                  # RoPE (Rotary Positional Embeddings) implementation
│   └── tokenizer.py             # Tokenizer utilities
├── evals/                       # Scripts and utilities for model evaluation
│   ├── EVAL.md                  # Documentation for evaluation
│   └── evaluate_model.py        # Script to evaluate the trained model
└── utils/                       # Utility scripts for various tasks
    ├── count_parameters.py      # Script to count model parameters
    ├── count_tokens.py          # Script to count tokens in a dataset
    ├── generate_model_summary.py    # Script to generate a model summary
    ├── GPU_performance_benchmark.py # Script for GPU performance benchmarking
    └── model_summary.txt        # Placeholder for model summary output
```
- Clone the repository:

  ```bash
  git clone https://github.com/lmassaron/gemma_from_scratch.git
  cd gemma_from_scratch
  ```

- Create a virtual environment and install dependencies:

  ```bash
  python -m venv venv
  source venv/bin/activate  # On Windows, use `venv\Scripts\activate`
  pip install -r requirements.txt
  ```
The process is divided into two main steps: data preparation and model training.
This script downloads a dataset from the Hugging Face Hub, tokenizes it using a GPT-2 tokenizer, and saves the token sequences into efficient binary files (.bin) for rapid loading during training.
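Conceptually, the script boils down to the following single-process sketch (the real script parallelizes tokenization across CPU cores; the sample texts, output file name, and `uint16` dtype here are assumptions):

```python
import numpy as np
import tiktoken

enc = tiktoken.get_encoding("gpt2")
texts = ["Once upon a time...", "The end."]  # stand-in for the HF dataset

# Tokenize every document and append the end-of-text token after each.
ids: list[int] = []
for text in texts:
    ids.extend(enc.encode_ordinary(text))
    ids.append(enc.eot_token)

# GPT-2 ids fit in uint16 (vocab_size = 50257 < 65536), halving disk usage.
arr = np.memmap("train.bin", dtype=np.uint16, mode="w+", shape=(len(ids),))
arr[:] = np.asarray(ids, dtype=np.uint16)
arr.flush()
```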
- To run with the default dataset (`roneneldan/TinyStories`):

  ```bash
  python prepare_dataset.py
  ```

  This command will create a directory named `tinystories_data/` containing `train.bin` and `validation.bin`.

- To use a different dataset from the Hub and a custom output directory:

  ```bash
  python prepare_dataset.py <huggingface_dataset_name> --output-dir ./data
  ```

  For example:

  ```bash
  python prepare_dataset.py "c4" --data-subset "en.noblocklist" --output-dir ./c4_data
  ```
Once the dataset is prepared, you can start training the model. This script handles the entire training loop, including optimization, learning rate scheduling, and periodic evaluation.
- To train using the default dataset location (this assumes you ran the default `prepare_dataset.py` command in the previous step):

  ```bash
  python train.py
  ```

- To point the training script to a custom data directory (required if you used the `--output-dir` option when preparing the data):

  ```bash
  python train.py --data-dir ./data
  ```
The training script will save the following outputs in the root directory:
- `best_model_params.pt`: The state dictionary of the model that achieved the lowest validation loss.
- `loss_plot.png`: A plot showing the training and validation loss curves over time.
- `prepare_dataset.py`: A flexible data processing script. It parallelizes the tokenization step across all available CPU cores for maximum efficiency and uses memory-mapped NumPy arrays to handle datasets larger than RAM.
- `train.py`: The main training loop. It implements modern training best practices (a minimal sketch follows this list):
  - Mixed-Precision Training: Uses `torch.amp.autocast` with `bfloat16` for faster training and reduced memory usage on supported hardware.
  - Optimizer: Employs the AdamW optimizer, which adds weight decay for better regularization.
  - Learning Rate Scheduler: Uses a `SequentialLR` scheduler that combines a linear warmup phase with a cosine decay, helping to stabilize training.
  - Gradient Accumulation: Allows for training with large effective batch sizes even on memory-constrained GPUs.
  - Gradient Clipping: Prevents exploding gradients by clipping the norm of the gradients before the optimizer step.
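A condensed sketch of how these practices fit together is shown below. The hyperparameter values and the stand-in model and data are illustrative, not taken from `train.py`.

```python
import torch
import torch.nn.functional as F

# Stand-ins so the sketch runs; replace with Gemma3Model and a real loader.
vocab_size = 64
model = torch.nn.Sequential(torch.nn.Embedding(vocab_size, 32),
                            torch.nn.Linear(32, vocab_size))
loader = [(torch.randint(0, vocab_size, (4, 16)),
           torch.randint(0, vocab_size, (4, 16))) for _ in range(32)]

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
# Linear warmup for the first 2 optimizer steps, cosine decay afterwards.
warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01,
                                           total_iters=2)
decay = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=6)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, decay], milestones=[2])

accum_steps = 4  # effective batch = micro-batch size * accum_steps
device_type = "cuda" if torch.cuda.is_available() else "cpu"
for step, (x, y) in enumerate(loader):
    with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):
        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, vocab_size), y.view(-1))
    (loss / accum_steps).backward()  # scale so accumulated grads average
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad(set_to_none=True)
```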
- `gemma_scratch/` (The Core Package):
  - `model.py`: Defines the `Gemma3Model` class, a PyTorch `nn.Module` that assembles the complete transformer architecture.
  - `layers.py`: Contains the `TransformerBlock`, the core repeating unit of the model, which includes multi-head attention and the MLP layers.
  - `rope.py`: Implements Rotary Positional Embeddings (RoPE), a modern technique for injecting positional information into the self-attention mechanism.
  - `normalization.py`: Provides an efficient `RMSNorm` (Root Mean Square Normalization) layer, which is used throughout the Gemma architecture instead of traditional LayerNorm.
  - `config.py`: A simple file to store the model's hyperparameters (e.g., number of layers, heads, embedding dimensions).
  - `tokenizer.py`: A wrapper for the GPT-2 tokenizer used for encoding the text data.
This repository provides two main scripts for running inference:
- `inference_google_gemma.py`: This script uses the pre-trained Gemma 3 270M model from the Hugging Face Hub. It downloads the weights, loads them into the model, and generates text from a list of predefined sentences. This is the recommended script for most users.

  ```bash
  python inference_google_gemma.py
  ```

- `inference_custom.py`: This script is for running inference with a custom model. It requires a `.pth` file with the model weights. You can use this script to test your own trained models.

  ```bash
  python inference_custom.py --model-path /path/to/your/model.pth
  ```
Both scripts will output the generated text to the console.
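If you want to load a trained checkpoint programmatically rather than through `inference_custom.py`, a minimal sketch might look like this. The `Gemma3Model` constructor signature and the `CUSTOM_CONFIG` import are assumptions; adapt them to how `config.py` and `model.py` are actually laid out.

```python
import torch
from gemma_scratch.config import CUSTOM_CONFIG  # hypothetical config name
from gemma_scratch.model import Gemma3Model

# Assumes Gemma3Model takes the training configuration as its only argument.
model = Gemma3Model(CUSTOM_CONFIG)
model.load_state_dict(torch.load("best_model_params.pt", map_location="cpu"))
model.eval()  # disable dropout etc. before generating
```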
- Original Notebook: `LLMs-from-scratch/ch05/12_gemma3/standalone-gemma3.ipynb` by Sebastian Raschka.
- Hugging Face Model Card: `google/gemma-3-270m`
This project is licensed under the Apache 2.0 License. See the LICENSE file for more details.