sinan-cakmak/sample-gcn-model

GNN Learning Project

A minimal, modular Graph Neural Network (GNN) learning project built with PyTorch and PyTorch Geometric, designed for Apple Silicon (M1/M2/M3) with a clear separation of concerns and easy extensibility.

Features

  • Clean architecture: Separate modules for data, models, training, and evaluation
  • Type-safe: Full type hints throughout the codebase
  • Extensible: Easy to add new GNN layers, datasets, and experiments
  • Apple Silicon optimized: Uses MPS (Metal Performance Shaders) when available
  • Ready-to-run: Includes a working example on the Cora citation network dataset
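
The device selection described above presumably reduces to a small helper like the following (a sketch; the project's actual function name and location are assumptions, not taken from the codebase):

```python
import torch

def get_device() -> torch.device:
    """Prefer Apple's MPS (Metal) backend when available, else fall back to CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```

On an M1/M2/M3 Mac with a recent PyTorch build this returns `mps`; everywhere else it returns `cpu`, so the same script runs unchanged on both.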

Project Structure

gnn/
├── gnn_learning/
│   ├── __init__.py       # Package initialization
│   ├── data.py           # Dataset loading utilities
│   ├── model.py          # GNN model definitions (GCN)
│   ├── train.py          # Training logic
│   └── evaluate.py       # Evaluation utilities
├── main.py               # Main entry point
├── pyproject.toml        # Project dependencies and config
└── README.md             # This file

Installation

This project uses uv for fast, reliable Python package management.

1. Install uv (if not already installed)

curl -LsSf https://astral.sh/uv/install.sh | sh

2. Create virtual environment and install dependencies

# Create and activate environment
uv venv
source .venv/bin/activate  # On macOS/Linux

# Install all dependencies
uv pip install -e .

Quick Start

Run the basic GCN training example on the Cora dataset:

python main.py

This will:

  1. Load the Cora citation network dataset
  2. Train a 2-layer Graph Convolutional Network (GCN)
  3. Evaluate on train/validation/test splits
  4. Print accuracy results

Expected output:

Using MPS (Apple Silicon GPU)

Loading Cora dataset...

Dataset Statistics:
  num_nodes: 2708
  num_edges: 10556
  num_features: 1433
  num_classes: 7
  ...

Training...
Epoch  10 | Loss: 1.9234
Epoch  20 | Loss: 1.7845
...

Evaluation Results:
  train_accuracy: 0.9857
  val_accuracy: 0.7880
  test_accuracy: 0.8130
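
Note that the 10,556 edges reported for Cora are directed entries: PyTorch Geometric stores graphs as a 2 × num_edges `edge_index` tensor in COO format, so each of Cora's 5,278 undirected citation links appears twice, once per direction. A minimal illustration with a made-up toy graph:

```python
import torch

# Toy undirected path graph 0 - 1 - 2, stored as directed pairs in COO
# format, the same layout PyTorch Geometric's `edge_index` uses
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])

num_nodes = int(edge_index.max()) + 1   # 3
num_edges = edge_index.size(1)          # 4: each undirected link counted twice
```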

Usage Examples

Loading Data

from gnn_learning.data import load_cora, dataset_stats

# Load Cora dataset
data = load_cora()

# Get statistics
stats = dataset_stats(data)
print(f"Number of nodes: {stats['num_nodes']}")

Creating a Model

from gnn_learning.model import GCN

model = GCN(
    num_features=1433,
    hidden_dim=16,
    num_classes=7,
    dropout=0.5
)

Training

import torch
from gnn_learning.train import train

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
device = torch.device("mps")  # or "cpu"

losses = train(
    model=model,
    data=data,
    optimizer=optimizer,
    device=device,
    epochs=200,
    verbose=True
)
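
Under the hood, `train` presumably runs a standard full-batch step each epoch. A simplified, self-contained sketch of such a step, using a plain linear classifier instead of the GCN and omitting the graph structure and `train_mask`:

```python
import torch
import torch.nn.functional as F

def train_step(model, x, y, optimizer):
    """One full-batch training step: forward, loss, backward, update."""
    model.train()
    optimizer.zero_grad()
    out = model(x)
    loss = F.cross_entropy(out, y)
    loss.backward()
    optimizer.step()
    return loss.item()

# Tiny usage check with dummy data
model = torch.nn.Linear(4, 3)
x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))
opt = torch.optim.Adam(model.parameters(), lr=0.01)
loss = train_step(model, x, y, opt)
```

The real loop differs mainly in that the model takes `(data.x, data.edge_index)` and the loss is computed only over `data.train_mask`.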

Evaluation

from gnn_learning.evaluate import evaluate_all_splits

results = evaluate_all_splits(model, data, device)
print(f"Test accuracy: {results['test_accuracy']:.4f}")
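
Per-split accuracy is presumably computed by comparing argmax predictions against labels under each boolean split mask; a minimal sketch (helper name is illustrative, not the project's API):

```python
import torch

def masked_accuracy(logits, labels, mask):
    """Fraction of correctly classified nodes within a boolean mask."""
    preds = logits.argmax(dim=-1)
    correct = (preds[mask] == labels[mask]).sum().item()
    return correct / int(mask.sum())

# Toy check: 3 nodes, 2 classes, all nodes in the mask
logits = torch.tensor([[2.0, 0.0], [0.0, 2.0], [2.0, 0.0]])
labels = torch.tensor([0, 1, 1])
mask = torch.tensor([True, True, True])
acc = masked_accuracy(logits, labels, mask)  # 2 of 3 predictions correct
```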

Extending the Project

The modular structure makes it easy to extend:

Adding New GNN Layers

Create new model classes in gnn_learning/model.py:

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GAT(nn.Module):
    """Graph Attention Network"""
    def __init__(self, num_features, hidden_dim, num_classes):
        super().__init__()
        self.conv1 = GATConv(num_features, hidden_dim, heads=8)
        self.conv2 = GATConv(hidden_dim * 8, num_classes, heads=1)

    def forward(self, x, edge_index):
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.5, training=self.training)
        return self.conv2(x, edge_index)

Adding New Datasets

Add loader functions in gnn_learning/data.py:

from torch_geometric.datasets import Planetoid

def load_pubmed():
    dataset = Planetoid(root="./data", name="PubMed")
    return dataset[0]

Custom Training Loops

The training utilities are composable. Create custom training scripts by mixing and matching:

from gnn_learning.train import train_epoch
from gnn_learning.evaluate import evaluate

for epoch in range(epochs):
    loss = train_epoch(model, data, optimizer, device)
    
    if epoch % 10 == 0:
        val_acc, _ = evaluate(model, data, device, data.val_mask)
        print(f"Epoch {epoch} | Loss: {loss:.4f} | Val Acc: {val_acc:.4f}")

Next Steps

Here are some ideas to explore:

  1. Message Passing: Implement custom message passing layers using PyTorch Geometric's MessagePassing base class
  2. Different Architectures: Try GAT, GraphSAGE, or GIN layers
  3. Hyperparameter Tuning: Experiment with learning rates, hidden dimensions, dropout
  4. Visualization: Add node embedding visualization with t-SNE or UMAP
  5. New Datasets: Try CiteSeer, PubMed, or other graph datasets
  6. Advanced Training: Add early stopping, learning rate scheduling, or model checkpointing
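
Early stopping (idea 6) needs no framework support. A minimal tracker, as a sketch rather than part of the current codebase, that signals a stop once the validation metric has failed to improve for `patience` consecutive checks:

```python
class EarlyStopping:
    """Stop training when the validation metric stops improving."""

    def __init__(self, patience: int = 10):
        self.patience = patience
        self.best = float("-inf")
        self.counter = 0

    def step(self, val_metric: float) -> bool:
        """Record one validation check; return True when training should stop."""
        if val_metric > self.best:
            self.best = val_metric
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience
```

In a training loop, call `stopper.step(val_acc)` after each evaluation and `break` when it returns True, restoring the best checkpoint afterwards.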

Requirements

  • Python ≥ 3.10
  • PyTorch ≥ 2.0
  • PyTorch Geometric ≥ 2.4
  • macOS with Apple Silicon (M1/M2/M3) recommended for MPS acceleration

License

MIT
