This project is a Rust-based implementation of the AlphaZero algorithm, specifically tailored for the game of Connect 4. It uses the `burn` deep learning framework for its neural network components and is fully generic over the compute backend (by default it runs on NVIDIA GPUs using the CUDA backend).
The application is controlled via a command-line interface (CLI) that allows users to train new models, play against existing agents, and calculate Elo ratings for a pool of different players.
- AlphaZero Training: A complete training pipeline based on self-play, Monte Carlo Tree Search (MCTS), and a residual neural network.
- Interactive Play: Play a game of Connect 4 directly in your terminal against a trained model, a Minimax agent, or a random-move agent.
- Elo Rating System: An automated system to benchmark different agents against each other to measure their relative strength.
The core of this project is the `train` command, which kicks off a highly asynchronous and parallelized implementation of the AlphaZero training algorithm. The process is divided into distinct phases within a main loop, where each iteration aims to produce a stronger version of the model.
The entire training lifecycle is monitored in real time through a terminal-based dashboard built with the `ratatui` crate, providing insights into performance, loss, and system metrics.

Instead of generating games one by one, the system uses an asynchronous architecture built on `tokio` to play multiple games in parallel.
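A simplified, single-machine sketch of this pattern, using std threads and channels as stand-ins for `tokio` tasks and channels. Every name and type below is illustrative, not the project's actual API:

```rust
use std::collections::HashMap;
use std::sync::mpsc;
use std::thread;

// Each "game task" sends (board, reply channel) to a central loop,
// which drains whatever is queued, evaluates it as one batch, and
// replies. This mimics the centralized inference batching described
// in this section, with std threads in place of tokio tasks.
type Board = u64; // stand-in for a hashed board position
type Eval = f32;  // stand-in for the network's value output

fn fake_batched_inference(batch: &[Board]) -> Vec<Eval> {
    // Placeholder for the GPU forward pass over the whole batch.
    batch.iter().map(|b| (*b % 100) as f32 / 100.0).collect()
}

fn main() {
    let (tx, rx) = mpsc::channel::<(Board, mpsc::Sender<Eval>)>();

    // Central inference loop: collect requests, run them as one batch.
    let server = thread::spawn(move || {
        let mut cache: HashMap<Board, Eval> = HashMap::new();
        while let Ok(first) = rx.recv() {
            // Dynamic batching: take the first request, then drain
            // everything else that is already waiting.
            let mut pending: Vec<_> =
                std::iter::once(first).chain(rx.try_iter()).collect();
            let boards: Vec<Board> = pending.iter().map(|(b, _)| *b).collect();
            let evals = fake_batched_inference(&boards);
            for ((board, reply), eval) in pending.drain(..).zip(evals) {
                cache.insert(board, eval); // shared inference cache
                let _ = reply.send(eval);
            }
        }
    });

    // Several "games" requesting evaluations in parallel.
    let handles: Vec<_> = (0..4u64)
        .map(|i| {
            let tx = tx.clone();
            thread::spawn(move || {
                let (reply_tx, reply_rx) = mpsc::channel();
                tx.send((i * 17, reply_tx)).unwrap();
                reply_rx.recv().unwrap()
            })
        })
        .collect();
    let results: Vec<Eval> = handles.into_iter().map(|h| h.join().unwrap()).collect();
    drop(tx); // close the channel so the server loop exits
    server.join().unwrap();
    assert_eq!(results.len(), 4);
    println!("evals: {results:?}");
}
```

In the real implementation the channels are asynchronous and the batch is a tensor, but the flow of requests is the same.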
- Parallel Game Simulation: At the start of each iteration, `NUM_EPISODES` games are spawned as independent asynchronous tasks. Each task simulates a full game of Connect 4.
- Centralized Inference Batching: During a game, the Monte Carlo Tree Search (MCTS) algorithm needs to evaluate board positions using the neural network. To maximize GPU utilization, these inference requests from all parallel games are sent through `tokio` channels to a central processing loop.
- Dynamic Batching: The central loop collects these requests and dynamically forms them into a single large batch, which is then sent to the GPU for a highly efficient forward pass through the `AlphaZero` model.
- Inference Caching: To avoid redundant GPU work within the same search tree, results from the neural network are cached in a shared `RwLock<HashMap>`. If a game task requests an evaluation for a state that was just evaluated by another parallel game, the result is served directly from the cache.
- Exploration vs. Exploitation: In the early moves of a game (controlled by `TEMPERATURE_ANNEALING`), the agent chooses its moves probabilistically based on the MCTS policy to encourage exploration. In later moves, it acts greedily, choosing the best-evaluated move to ensure strong play.
- Data Augmentation: After each game concludes, every step (board state, MCTS policy, and final game outcome) is added to a `ReplayBuffer`. To increase the diversity of the training data, a symmetric version of each game step is also generated and added to the buffer.
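The symmetry augmentation exploits the fact that Connect 4 is invariant under a left-right flip of the board. A minimal sketch, assuming a row-major 6x7 board representation; the function names are illustrative, not the project's actual API:

```rust
// Connect 4 board: 6 rows x 7 columns, stored row-major.
// Mirroring the board left-right and reversing the 7-entry MCTS
// policy yields a second, equally valid training example per step.
const ROWS: usize = 6;
const COLS: usize = 7;

fn mirror_board(board: &[i8; ROWS * COLS]) -> [i8; ROWS * COLS] {
    let mut out = [0i8; ROWS * COLS];
    for r in 0..ROWS {
        for c in 0..COLS {
            out[r * COLS + c] = board[r * COLS + (COLS - 1 - c)];
        }
    }
    out
}

fn mirror_policy(policy: &[f32; COLS]) -> [f32; COLS] {
    let mut out = *policy;
    out.reverse();
    out
}

fn main() {
    let mut board = [0i8; ROWS * COLS];
    board[5 * COLS] = 1; // a stone in the bottom-left corner
    let mirrored = mirror_board(&board);
    assert_eq!(mirrored[5 * COLS + 6], 1); // now in the bottom-right corner

    let policy = [0.7, 0.1, 0.05, 0.05, 0.05, 0.03, 0.02];
    assert_eq!(mirror_policy(&policy)[6], 0.7); // column preference flips too
    println!("augmentation ok");
}
```

The game outcome is unchanged by the flip, so only the board and policy need to be mirrored.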
Once enough new data has been generated, the training phase begins.
- Candidate Model: A copy of the current best model is created to serve as the "candidate" for this iteration.
- Batch Sampling: For a fixed number of steps (`NUM_TRAIN_STEPS`), random batches of game data are sampled from the `ReplayBuffer`.
- Loss Calculation: The model's predictions are compared against the "ground truth" data from the buffer:
  - Policy Loss: A cross-entropy loss measures how well the model's predicted policy matches the improved policy generated by MCTS.
  - Value Loss: A mean squared error loss measures how accurately the model predicted the final outcome of the game.
  - The total loss is a weighted sum of these two components, controlled by `VALUE_LOSS_WEIGHT`.
- Cyclical Learning Rate: Instead of a fixed learning rate, training uses a Cyclical Learning Rate (CLR) schedule with decay. The learning rate oscillates between a base and a max value over a cycle of iterations, which helps the optimizer traverse the loss landscape more effectively and avoid getting stuck in local minima. The overall magnitude of the learning rate also decays over long periods to allow for fine-tuning.
- Optimization: The model's weights are updated using the `AdamW` optimizer, which incorporates weight decay and gradient clipping to ensure stable training.
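The shape of a cyclical learning rate with decay can be sketched as follows. All constants here are illustrative, not the project's actual hyperparameters from `parameters.rs`:

```rust
// Triangular cyclical learning rate with an exponentially decaying
// cycle amplitude: the LR rises from BASE_LR toward MAX_LR over the
// first half of each cycle and falls back over the second half,
// while the oscillation shrinks over time toward BASE_LR.
fn cyclical_lr(iteration: usize) -> f64 {
    const BASE_LR: f64 = 1e-4;   // illustrative constants, not the
    const MAX_LR: f64 = 1e-3;    // project's real hyperparameters
    const CYCLE_LEN: usize = 20; // iterations per full cycle
    const DECAY: f64 = 0.99;     // per-iteration amplitude decay

    let pos = (iteration % CYCLE_LEN) as f64 / CYCLE_LEN as f64;
    // Triangular wave in [0, 1]: 0 at the cycle boundaries, 1 mid-cycle.
    let tri = 1.0 - (2.0 * pos - 1.0).abs();
    let amplitude = (MAX_LR - BASE_LR) * DECAY.powi(iteration as i32);
    BASE_LR + amplitude * tri
}

fn main() {
    assert!((cyclical_lr(0) - 1e-4).abs() < 1e-12); // cycle start: base LR
    assert!(cyclical_lr(10) > cyclical_lr(0));      // mid-cycle: near max LR
    assert!(cyclical_lr(210) < cyclical_lr(10));    // amplitude decays over time
    println!("lr at iteration 10: {:.6}", cyclical_lr(10));
}
```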
After training, the new candidate model must prove it is stronger than the current champion.
- Validation: If the candidate's win rate against the current best model exceeds a predefined `WIN_RATE_THRESHOLD`, it is promoted to become the new best model (deactivated by default).
- Elo Benchmarking: If a model is promoted, it then plays another series of games against a fixed-strength `evaluator` model and other baseline agents (like MiniMax). This calculates a more objective Elo rating for the new model, measuring its progress against a stable benchmark.
- Checkpointing: The newly promoted model is saved to a file (e.g., `iteration_123_elo_1450.mpk`).
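Elo benchmarking follows the standard Elo formulas: an expected score derived from the rating difference, and a rating update proportional to how surprising the actual result was. A minimal sketch (the K-factor of 32 is an assumption for illustration, not necessarily the project's setting):

```rust
// Expected score of player A against player B under the Elo model:
// a 400-point rating advantage corresponds to ~10:1 expected odds.
fn expected_score(rating_a: f64, rating_b: f64) -> f64 {
    1.0 / (1.0 + 10f64.powf((rating_b - rating_a) / 400.0))
}

/// `score_a` is 1.0 for a win by A, 0.5 for a draw, 0.0 for a loss.
/// Returns the updated ratings (a, b). K = 32 is illustrative.
fn update_elo(rating_a: f64, rating_b: f64, score_a: f64) -> (f64, f64) {
    const K: f64 = 32.0;
    let ea = expected_score(rating_a, rating_b);
    (
        rating_a + K * (score_a - ea),
        rating_b + K * ((1.0 - score_a) - (1.0 - ea)),
    )
}

fn main() {
    // Equal ratings: each side is expected to score 0.5.
    assert!((expected_score(1200.0, 1200.0) - 0.5).abs() < 1e-12);
    // A win moves the winner up and the loser down by the same amount.
    let (a, b) = update_elo(1200.0, 1200.0, 1.0);
    assert!((a - 1216.0).abs() < 1e-9 && (b - 1184.0).abs() < 1e-9);
    println!("after one win: {a:.0} vs {b:.0}");
}
```

Playing many games against a fixed-strength evaluator and accumulating these updates yields the rating that ends up in the checkpoint filename.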
The training progression of the AlphaZero-based Connect 4 agent can be visualized in the following Elo score plot:
As shown above, the model improves rapidly in strength during the initial training iterations, quickly surpassing the 500 Elo baseline of the Minimax agent (with a fixed depth of 4) while using 256 MCTS simulations per move (`NUM_SIMULATIONS`). By iteration 70, the model approaches the Elo rating of the static evaluator (1674 Elo), the strongest model previously obtained with this AlphaZero implementation.
Below is a screenshot of the real-time training interface used to monitor system metrics, losses, and agent statistics:
Before you begin, ensure you have the following installed:
- Rust Toolchain: Install Rust via rustup.
- NVIDIA GPU: The project is configured to use the `Cuda` backend from the `burn` framework, so an NVIDIA graphics card is required. It is possible to run the project with the `WGPU` backend instead by modifying the `main.rs` file.
- CUDA Toolkit & cuDNN: You must have the appropriate NVIDIA CUDA Toolkit and cuDNN libraries installed on your system for `burn` to communicate with the GPU when using the default `Cuda` backend.
- Clone the Repository

  ```sh
  git clone <your-repository-url>
  cd <repository-name>
  ```

- Build the Project

  For optimal performance, especially for training and inference, build the project in release mode:

  ```sh
  cargo build --release
  ```

  The final executable will be located at `target/release/<your-executable-name>`.
The application is controlled via subcommands. The general structure for running a command is:
```sh
cargo run --release -- <COMMAND> [ARGUMENTS]
```

The `--` is important: it separates arguments for Cargo from arguments for your application.
This command starts the AlphaZero self-play training process. The training loop will run, generating games, training the neural network, and periodically saving model checkpoints.
Command:

```sh
cargo run --release -- train
```

This command allows you to play a game of Connect 4 against a specified opponent in your terminal.

Command:

```sh
cargo run --release -- play <OPPONENT>
```

Arguments:

- `<OPPONENT>`: A string identifying the agent you want to play against. This can be one of four types:
  - A saved model: The file path to a trained model (e.g., an `.mpk` file).
  - `"Human"`: To play against another human on the same machine.
  - `"MiniMax"`: To play against a built-in Minimax agent.
  - `"Random"`: To play against an agent that makes random moves.
Examples:
- Play against a saved model:

  ```sh
  cargo run --release -- play "artifacts/models/iteration_100.mpk"
  ```

- Play against the Minimax agent:

  ```sh
  cargo run --release -- play "MiniMax"
  ```
This command runs a tournament between a list of specified players to calculate and compare their relative strength using the Elo rating system.
Command:

```sh
cargo run --release -- elo [PLAYERS]... --initial-elo <ELO>
```

Arguments:

- `[PLAYERS]...`: A required, space-separated list of one or more player identifiers. Valid identifiers are:
  - A saved model: The file path to a trained model.
  - `"MiniMax"`: The built-in Minimax agent.
  - `"Random"`: The agent that makes random moves.
  - Note: The `"Human"` player cannot be used in the automated Elo computation.
- `--initial-elo <ELO>`: A required flag specifying the starting Elo for all players in the tournament.
Example:
To rate two different trained models against each other, as well as against the Random and MiniMax agents, with a starting Elo of 1200:
```sh
cargo run --release -- elo \
    "artifacts/models/iteration_90_elo_662.mpk" \
    "artifacts/models/iteration_100_elo_604.mpk" \
    "Random" \
    "MiniMax" \
    --initial-elo 1200
```

The codebase is organized into several modules:
| Module | Description |
|---|---|
| `main.rs` | Handles command-line argument parsing and program entry. |
| `agent.rs` | Defines the AlphaZero model architecture (residual convolutional network). |
| `training.rs` | Contains the main AlphaZero training loop: asynchronous self-play, network training, and evaluation. |
| `ratings.rs` | Implements the Elo rating calculation and tournament logic. |
| `inference.rs` | Contains the logic for the `play` command. |
| `validation.rs` | Defines the `Player` enum and game-playing logic used for evaluation and validation. |
| `tree.rs` | Implements the Monte Carlo Tree Search (MCTS) algorithm. |
| `connect4.rs` | Contains the game logic, rules, and board representation for Connect 4. |
| `memory.rs` | Manages the `ReplayBuffer` for storing and sampling self-play game data. |
| `logger.rs` | Implements the `TuiLogger`, a terminal-based dashboard built with `ratatui` for real-time training monitoring. |
| `parameters.rs` | Centralizes hyperparameters (learning rates, batch sizes, etc.) and constants for the project. |

