A minimal, modular Graph Neural Network (GNN) learning project built with PyTorch and PyTorch Geometric. Designed for Apple Silicon (M1/M2/M3) with a clear separation of concerns and easy extensibility.
- Clean architecture: Separate modules for data, models, training, and evaluation
- Type-safe: Full type hints throughout the codebase
- Extensible: Easy to add new GNN layers, datasets, and experiments
- Apple Silicon optimized: Uses MPS (Metal Performance Shaders) when available
- Ready-to-run: Includes a working example on the Cora citation network dataset
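The MPS fallback mentioned above can be sketched in a few lines of plain PyTorch (this is a generic pattern, not necessarily the exact device-selection code in `main.py`):

```python
import torch

# Prefer the Apple Silicon GPU via Metal Performance Shaders,
# falling back to CPU when MPS is unavailable.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
```

Models and tensors are then moved to the selected device with `.to(device)`.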
```
gnn/
├── gnn_learning/
│   ├── __init__.py     # Package initialization
│   ├── data.py         # Dataset loading utilities
│   ├── model.py        # GNN model definitions (GCN)
│   ├── train.py        # Training logic
│   └── evaluate.py     # Evaluation utilities
├── main.py             # Main entry point
├── pyproject.toml      # Project dependencies and config
└── README.md           # This file
```
This project uses uv for fast, reliable Python package management.
```bash
# Install uv
curl -LsSf https://astral.sh/uv/install.sh | sh

# Create and activate environment
uv venv
source .venv/bin/activate  # On macOS/Linux

# Install all dependencies
uv pip install -e .
```

Run the basic GCN training example on the Cora dataset:

```bash
python main.py
```

This will:
- Load the Cora citation network dataset
- Train a 2-layer Graph Convolutional Network (GCN)
- Evaluate on train/validation/test splits
- Print accuracy results
Expected output:

```
Using MPS (Apple Silicon GPU)
Loading Cora dataset...

Dataset Statistics:
  num_nodes: 2708
  num_edges: 10556
  num_features: 1433
  num_classes: 7
  ...

Training...
Epoch 10 | Loss: 1.9234
Epoch 20 | Loss: 1.7845
...

Evaluation Results:
  train_accuracy: 0.9857
  val_accuracy: 0.7880
  test_accuracy: 0.8130
```
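The three accuracies above are plain masked accuracy: predictions are compared to labels only on the nodes belonging to each split. A minimal sketch with toy tensors (a hypothetical helper for illustration, not the project's `evaluate` code):

```python
import torch


def masked_accuracy(logits, labels, mask):
    # Accuracy computed only over the nodes selected by the boolean mask.
    preds = logits[mask].argmax(dim=-1)
    return (preds == labels[mask]).float().mean().item()


logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [0.9, 0.3]])
labels = torch.tensor([0, 1, 1])
mask = torch.tensor([True, True, False])  # e.g. a validation mask
print(masked_accuracy(logits, labels, mask))  # → 1.0 on this toy example
```

On Cora the masks (`train_mask`, `val_mask`, `test_mask`) ship with the dataset.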
```python
from gnn_learning.data import load_cora, dataset_stats

# Load Cora dataset
data = load_cora()

# Get statistics
stats = dataset_stats(data)
print(f"Number of nodes: {stats['num_nodes']}")
```

```python
from gnn_learning.model import GCN

model = GCN(
    num_features=1433,
    hidden_dim=16,
    num_classes=7,
    dropout=0.5,
)
```

```python
import torch
from gnn_learning.train import train

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
device = torch.device("mps")  # or "cpu"

losses = train(
    model=model,
    data=data,
    optimizer=optimizer,
    device=device,
    epochs=200,
    verbose=True,
)
```

```python
from gnn_learning.evaluate import evaluate_all_splits

results = evaluate_all_splits(model, data, device)
print(f"Test accuracy: {results['test_accuracy']:.4f}")
```

The modular structure makes it easy to extend:
Create new model classes in `gnn_learning/model.py`:

```python
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn import GATConv


class GAT(nn.Module):
    """Graph Attention Network"""

    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()
        self.conv1 = GATConv(num_features, hidden_dim, heads=8)
        self.conv2 = GATConv(hidden_dim * 8, num_classes, heads=1)

    def forward(self, x, edge_index):
        # 8 attention heads are concatenated by conv1, hence the
        # hidden_dim * 8 input size of conv2.
        x = F.elu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)
```

Add loader functions in `gnn_learning/data.py`:
```python
from torch_geometric.datasets import Planetoid


def load_pubmed():
    dataset = Planetoid(root="./data", name="PubMed")
    return dataset[0]
```

The training utilities are composable. Create custom training scripts by mixing and matching:
```python
from gnn_learning.train import train_epoch
from gnn_learning.evaluate import evaluate

for epoch in range(epochs):
    loss = train_epoch(model, data, optimizer, device)
    if epoch % 10 == 0:
        val_acc, _ = evaluate(model, data, device, data.val_mask)
        print(f"Epoch {epoch} | Loss: {loss:.4f} | Val Acc: {val_acc:.4f}")
```

Here are some ideas to explore:
- Message Passing: Implement custom message passing layers using PyTorch Geometric's `MessagePassing` base class
- Different Architectures: Try GAT, GraphSAGE, or GIN layers
- Hyperparameter Tuning: Experiment with learning rates, hidden dimensions, dropout
- Visualization: Add node embedding visualization with t-SNE or UMAP
- New Datasets: Try CiteSeer, PubMed, or other graph datasets
- Advanced Training: Add early stopping, learning rate scheduling, or model checkpointing
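As one sketch of the early-stopping idea, a small framework-agnostic helper (a hypothetical class, not part of `gnn_learning`) can be layered onto the composable training loop shown earlier:

```python
class EarlyStopping:
    """Stop when the monitored metric hasn't improved for `patience` checks."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.bad_checks = 0

    def step(self, metric):
        # Returns True when training should stop.
        if metric > self.best + self.min_delta:
            self.best = metric
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience
```

Usage: create `stopper = EarlyStopping(patience=10)` and `break` out of the epoch loop when `stopper.step(val_acc)` returns `True`.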
- Python ≥ 3.10
- PyTorch ≥ 2.0
- PyTorch Geometric ≥ 2.4
- macOS with Apple Silicon (M1/M2/M3) recommended for MPS acceleration
- PyTorch Geometric Documentation
- GCN Paper - Kipf & Welling (2016)
- GNN Primer - Distill.pub overview
MIT