A reinforcement learning framework based on graph neural networks for solving combinatorial optimization problems in complex networks, such as network dismantling.
Core Framework
- ✅ Modular architecture (separation of environments, algorithms, and models)
- ✅ Registry mechanism for dynamic component registration
- ✅ Configurable training system
- ✅ Multiple graph neural network backbones (GraphSAGE, GAT, GIN, etc.)
- ✅ Flexible prediction head system (QHead, VHead, LogitHead, etc.)
Reinforcement Learning Algorithms
- ✅ DQN (Deep Q-Network) implementation
- ✅ PPO (Proximal Policy Optimization) implementation
- ✅ Support for experience replay buffers (standard/prioritized)
- ✅ PPO rollout buffer (RolloutBuffer)
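The buffers above follow the standard experience-replay pattern: transitions go into a bounded FIFO store and are sampled uniformly to decorrelate updates. As a minimal illustration of the idea (not CentriLearn's actual `ReplayBuffer` API):

```python
import random
from collections import deque

class SimpleReplayBuffer:
    """Toy experience replay: bounded FIFO storage plus uniform sampling."""

    def __init__(self, capacity):
        # deque(maxlen=...) evicts the oldest transition automatically
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Uniform sampling without replacement
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)

buf = SimpleReplayBuffer(capacity=100)
for t in range(150):           # push more than capacity; oldest 50 are evicted
    buf.push(t, 0, 1.0, t + 1, False)
batch = buf.sample(32)
```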
Complex Network Tasks: Network Dismantling
- ✅ Network dismantling environment (NetworkDismantlingEnv)
- ✅ Synthetic graph generation (BA, ER, etc.)
- ✅ Real-world network dataset support
Performance Optimization
- ✅ Vectorized environment (VectorizedEnv) supporting parallel training
- ✅ Vectorized environment configuration file (dqn_vectorized.yaml)
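The gain from vectorization comes from stepping several independent environment instances in lockstep and auto-resetting finished episodes, so the learner always receives a full batch of transitions. A toy sketch of that pattern (the classes here are illustrative stand-ins, not CentriLearn's `VectorizedEnv` interface):

```python
class CountdownEnv:
    """Toy environment whose episode ends after `horizon` steps."""
    def __init__(self, horizon):
        self.horizon = horizon
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= self.horizon  # obs, reward, done

class ToyVectorizedEnv:
    """Step several environments together and auto-reset finished ones."""
    def __init__(self, envs):
        self.envs = envs

    def reset(self):
        return [env.reset() for env in self.envs]

    def step(self, actions):
        obs, rewards, dones = [], [], []
        for env, a in zip(self.envs, actions):
            o, r, d = env.step(a)
            if d:
                o = env.reset()  # auto-reset so the batch stays full
            obs.append(o)
            rewards.append(r)
            dones.append(d)
        return obs, rewards, dones

venv = ToyVectorizedEnv([CountdownEnv(h) for h in (2, 3, 4, 5)])
venv.reset()
obs, rewards, dones = venv.step([0, 0, 0, 0])
obs, rewards, dones = venv.step([0, 0, 0, 0])  # env 0 finishes and resets
```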
- 🔄 More reinforcement learning algorithms (A3C, SAC, TD3)
- 🔄 More application scenarios
- 🔄 More training tools
- 🔄 Distributed training support
- 🔄 Documentation improvement and performance optimization
- 🔄 Large-scale testing and evaluation
Graph theory contains many combinatorial optimization problems, such as network dismantling and graph partitioning, most of which are NP-hard. Research on these problems has traditionally relied on heuristic algorithms with handcrafted features. In recent years, a growing number of studies have applied deep reinforcement learning to these combinatorial optimization problems with significant results.
Currently, there are many mature frameworks in the fields of graph neural networks and reinforcement learning, such as PyG (PyTorch Geometric) and SB3 (Stable Baselines3), but specialized frameworks for graph reinforcement learning remain absent. Due to the uniqueness of graph data (node connections, graph structure changes, etc.), extending existing reinforcement learning frameworks poses significant challenges. Therefore, this project aims to establish a reinforcement learning framework for graph data to facilitate learning and experimentation for relevant researchers.
I have previously conducted research on complex networks, and my thesis topic is graph reinforcement learning. Therefore, I developed this project to help me complete my thesis. This is also my first open-source project, and I hope to provide valuable tools to the community.
- Graph Data Focused: Reinforcement learning framework for graph data based on PyTorch Geometric
- Modular Design: Clear separation of environment, algorithm, and model components for easy extension and combination
- Registry Mechanism: Flexible component registration and dynamic building, similar to mmcv's configuration style
- Configurable Training: Start training with one click through configuration files without modifying code
- Easy to Extend: Easily register custom components through decorators, easily extendable to different complex network sequential decision-making tasks
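The registry mechanism can be illustrated with a minimal mmcv-style stand-in; the real `Registry` in `centrilearn/utils/registry.py` may differ in detail, and the `GraphSAGE` class below is a placeholder, not the actual backbone:

```python
class Registry:
    """Minimal mmcv-style registry: map class names to classes and
    build instances from config dicts carrying a 'type' key."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self):
        def decorator(cls):
            self._modules[cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg):
        cfg = dict(cfg)                       # don't mutate the caller's config
        cls = self._modules[cfg.pop('type')]  # look up by 'type', pass the rest
        return cls(**cfg)

MODELS = Registry('models')

@MODELS.register_module()
class GraphSAGE:  # placeholder class for demonstration only
    def __init__(self, in_channels, hidden_channels, num_layers):
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers

backbone = MODELS.build({'type': 'GraphSAGE', 'in_channels': 2,
                         'hidden_channels': 64, 'num_layers': 3})
```

This is what lets a YAML `type:` field select a component at runtime without any `if/else` dispatch in the training code.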
- Modules Guide - Comprehensive guide for using different modules (algorithms, environments, models, buffers, metrics)
- API Reference - Detailed API documentation for all public interfaces
- Examples - Example scripts demonstrating various use cases:
- DQN Example - DQN training examples
- PPO Example - PPO training examples
- Python >= 3.11
- CUDA >= 11.8 (GPU training recommended)
# Clone the project
git clone https://github.com/He-JiYe/CentriLearn.git
cd CentriLearn
# Install core dependencies
pip install -e .
# Install all dependencies (recommended)
pip install -e ".[all]"
# Install PyTorch (select according to your CUDA version)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# Install PyTorch Geometric
pip install torch-geometric torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
# Install other dependencies
pip install networkx numpy pyyaml tqdm
⚠️ Note: This project is still under active development. The following content provides basic usage examples. More detailed documentation, tutorials, and API references will be provided in the future.
We provide a convenient command-line tool to start training directly through YAML configuration files:
# Basic training
python tools/train.py configs/network_dismantling/dqn.yaml
# Enable logging
python tools/train.py configs/network_dismantling/dqn.yaml --use_logging --log_dir ./logs/train
# Specify checkpoint save directory
python tools/train.py configs/network_dismantling/dqn.yaml --ckpt_dir ./checkpoints
# Resume training from checkpoint
python tools/train.py configs/network_dismantling/dqn.yaml --resume ./checkpoints/checkpoint_episode_500.pth
# Customize training parameters
python tools/train.py configs/network_dismantling/ppo.yaml --num_episodes 500 --batch_size 64 --save_interval 50

import yaml
from centrilearn.utils import train_from_cfg
# Load configuration file
with open('configs/network_dismantling/dqn.yaml', 'r') as f:
    config = yaml.safe_load(f)
# Start training
results, algorithm = train_from_cfg(config, verbose=True)
# Access training results
print(f"Average reward: {results['avg_reward']:.4f}")
print(f"Total episodes: {results['total_episodes']}")

import networkx as nx
from centrilearn.utils import build_environment, build_algorithm
# Create custom environment
graph = nx.barabasi_albert_graph(n=50, m=2)
env = build_environment({
    'type': 'NetworkDismantlingEnv',
    'graph': graph,
    'node_features': 'combin'
})
# Build algorithm
algo = build_algorithm({
    'type': 'DQN',
    'model': {
        'type': 'Qnet',
        'backbone_cfg': {
            'type': 'GraphSAGE',
            'in_channels': 2,
            'hidden_channels': 64,
            'num_layers': 3
        },
        'q_head_cfg': {
            'type': 'QHead',
            'in_channels': 64
        }
    },
    'optimizer_cfg': {
        'type': 'Adam',
        'lr': 0.0001
    },
    'algo_cfg': {
        'gamma': 0.99,
        'epsilon_decay': 10000
    },
    'device': 'cuda'
})
# Train
results = algo._run_training_loop(env, {
    'num_episodes': 1000,
    'batch_size': 32,
    'log_interval': 10,
    'ckpt_dir': './checkpoints',
    'save_interval': 100
})

CentriLearn uses YAML/JSON configuration files and is highly flexible: the available parameters mirror each component's constructor arguments and can be determined by inspecting the model code. Below is a sample YAML configuration file:
algorithm:
  type: DQN                       # Algorithm type: DQN | PPO
  model:
    type: Qnet                    # Model type
    backbone_cfg:                 # Backbone network config
      type: GraphSAGE             # Supports multiple GNNs
      in_channels: 2
      hidden_channels: 64
      num_layers: 3
    q_head_cfg:                   # Q-value prediction head
      type: QHead
      in_channels: 64
  optimizer_cfg:                  # Optimizer config
    type: Adam
    lr: 0.0001
    weight_decay: 0.0005
  replaybuffer_cfg:               # Experience replay buffer
    type: PrioritizedReplayBuffer
    capacity: 10000
  metric_manager_cfg:             # Metric manager
    save_dir: ./logs/metrics
    log_interval: 10
    metrics:
      - type: AUC                 # Area under the giant-connected-component curve
        record: min
      - type: AttackRate          # Attack rate
        record: min
  algo_cfg:                       # Algorithm hyperparameters
    gamma: 0.99
    epsilon_start: 1.0
    epsilon_end: 0.01
    epsilon_decay: 10000
    tau: 0.005
  device: cuda

environment:
  type: NetworkDismantlingEnv     # Environment type
  synth_type: ba                  # Synthetic graph type
  synth_args:
    min_n: 30
    max_n: 50
    m: 4
  node_features: combin           # Node feature type
  is_undirected: True
  value_type: ar                  # Reward type: ar (attack rate)
  use_gcc: False
  use_component: False
  device: cuda

training:
  num_episodes: 1000              # Number of training episodes
  max_steps: 1000                 # Max steps per episode
  batch_size: 32                  # Batch size
  log_interval: 10                # Log print interval
  eval_interval: 100              # Evaluation interval
  eval_episodes: 5                # Number of evaluation episodes
  ckpt_dir: ./checkpoints         # Checkpoint save directory
  save_interval: 100              # Checkpoint save interval
  resume: null                    # Resume path

- `DQN`: Deep Q-Network
- `PPO`: Proximal Policy Optimization
- `GraphSAGE`: GraphSAGE
- `GAT`: Graph Attention Network
- `GIN`: Graph Isomorphism Network
- `DeepNet`: Deep Graph Neural Network
- `FPNet`: Feature Pyramid Graph Neural Network
- `QHead`: Q-value prediction head
- `VHead`: Value prediction head
- `LogitHead`: Logit (policy) prediction head
- `PolicyHead`: Policy head
- `NetworkDismantlingEnv`: Network dismantling environment
- `VectorizedEnv`: Vectorized environment (parallel training)
- `ReplayBuffer`: Standard experience replay
- `PrioritizedReplayBuffer`: Prioritized experience replay
- `RolloutBuffer`: PPO rollout buffer
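Prioritized replay differs from the standard buffer by sampling each transition with probability proportional to `p_i^alpha`, so high-TD-error transitions are replayed more often. Production implementations use a sum-tree for O(log n) sampling; the linear-scan sketch below (not CentriLearn's `PrioritizedReplayBuffer` API) keeps only the core idea:

```python
import random

class ToyPrioritizedBuffer:
    """Proportional prioritized replay in miniature."""

    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.data, self.priorities = [], []

    def push(self, transition, priority=1.0):
        if len(self.data) >= self.capacity:
            self.data.pop(0)
            self.priorities.pop(0)
        self.data.append(transition)
        self.priorities.append(priority ** self.alpha)

    def sample(self, batch_size):
        # Sample indices with probability proportional to stored priority
        idx = random.choices(range(len(self.data)),
                             weights=self.priorities, k=batch_size)
        return idx, [self.data[i] for i in idx]

    def update_priorities(self, indices, new_priorities):
        # Called after a learning step with fresh TD errors
        for i, p in zip(indices, new_priorities):
            self.priorities[i] = p ** self.alpha

buf = ToyPrioritizedBuffer(capacity=4)
for t in range(4):
    buf.push(('transition', t), priority=0.1)
buf.update_priorities([3], [100.0])  # transition 3 now dominates sampling
idx, batch = buf.sample(1000)
```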
Using vectorized environments can significantly improve training efficiency by running multiple environment instances simultaneously:
from centrilearn.environments import VectorizedEnv
# Create vectorized environment
env = VectorizedEnv({
    'env_kwargs': {
        'type': 'NetworkDismantlingEnv',
        'synth_type': 'ba',
        'synth_args': {'min_n': 30, 'max_n': 50, 'm': 4},
        # ...
    },
    'env_num': 4  # 4 parallel environments
})
# Training automatically detects and uses vectorized mode
results = algo._run_training_loop(env, training_cfg)

Or in the configuration file:
environment:
  type: VectorizedEnv
  env_kwargs:
    type: NetworkDismantlingEnv
    synth_type: ba
    # ...
  env_num: 4

Checkpoints are saved automatically during training and can be resumed from:
# Automatically save during training
python tools/train.py configs/dqn.yaml --ckpt_dir ./checkpoints
# Resume after interruption
python tools/train.py configs/dqn.yaml --resume ./checkpoints/checkpoint_episode_500.pth

Saved checkpoints include:
- Model parameters (`model_state_dict`)
- Optimizer state (`optimizer_state_dict`)
- Learning rate scheduler state (`scheduler_state_dict`)
- Training steps (`training_step`)
- Training progress and statistics
Built-in multiple evaluation metrics automatically record the training process:
metric_manager_cfg:
  save_dir: ./logs/metrics
  log_interval: 10
  metrics:
    - type: AUC            # Area under the giant-connected-component curve
      record: min
    - type: AttackRate     # Attack rate
      record: min
    - type: EpisodeReward  # Cumulative reward
      record: max

Metric history is automatically saved as JSON files for subsequent analysis.
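The AUC metric above measures the area under the giant-connected-component (GCC) size curve as nodes are removed: a good dismantling policy collapses the GCC early, giving a small area. A self-contained sketch of that computation (a plain-Python illustration, not the actual metric classes; the function names are assumptions):

```python
def giant_component_size(adjacency, removed):
    """Size of the largest connected component after removing `removed`
    nodes, via DFS over a plain adjacency dict."""
    alive = set(adjacency) - removed
    seen, best = set(), 0
    for start in alive:
        if start in seen:
            continue
        stack, comp = [start], 0
        seen.add(start)
        while stack:
            u = stack.pop()
            comp += 1
            for v in adjacency[u]:
                if v in alive and v not in seen:
                    seen.add(v)
                    stack.append(v)
        best = max(best, comp)
    return best

def dismantling_auc(adjacency, removal_order):
    """Normalized area under the GCC-size curve over the removal sequence
    (lower is better for the dismantler)."""
    n = len(adjacency)
    removed, total = set(), 0
    for node in removal_order:
        removed.add(node)
        total += giant_component_size(adjacency, removed)
    return total / (n * len(removal_order))

# A 5-node path graph: removing the middle node first splits it quickly
path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
auc = dismantling_auc(path, removal_order=[2, 1, 3, 0, 4])
```

Attacking the center first (order `[2, 1, 3, 0, 4]`) yields a lower AUC than peeling off the endpoints, which is exactly the signal the `record: min` setting tracks.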
CentriLearn/
├── configs/ # Configuration files
│ └── network_dismantling/ # Network dismantling configs
│ ├── dqn.yaml
│ ├── ppo.yaml
│ └── dqn_vectorized.yaml
├── ckpt/ # Model weights
├── data/ # Datasets
│ ├── small/ # Small-scale networks
│ └── large/ # Large-scale networks
├── docs/ # Documentation
├── logs/ # Logs
├── notebooks/ # Jupyter notebooks
├── centrilearn/ # Source code
│ ├── algorithms/ # RL algorithms
│ │ ├── base.py # Algorithm base class
│ │ ├── dqn.py # DQN implementation
│ │ └── ppo.py # PPO implementation
│ ├── buffer/ # Experience buffers
│ │ ├── base.py
│ │ ├── replaybuffer.py
│ │ └── rolloutbuffer.py
│ ├── environments/ # Environment implementations
│ │ ├── base.py
│ │ ├── network_dismantling.py
│ │ └── vectorized_env.py
│ ├── metrics/ # Evaluation metrics
│ │ ├── base.py
│ │ ├── manager.py
│ │ └── network_dismantling_metrics.py
│ ├── models/ # Model components
│ │ ├── backbones/ # Backbone networks
│ │ │ ├── GraphSAGE.py
│ │ │ ├── GAT.py
│ │ │ ├── GIN.py
│ │ │ ├── DeepNet.py
│ │ │ └── FPNet.py
│ │ ├── heads/ # Prediction heads
│ │ │ ├── q_head.py
│ │ │ ├── v_head.py
│ │ │ ├── logit_head.py
│ │ │ └── policy_head.py
│ │ ├── network_dismantler/ # Complete models
│ │ │ ├── Qnet.py
│ │ │ └── ActorCritic.py
│ │ └── loss/ # Loss functions
│ │ └── restruct_loss.py
│ └── utils/ # Utilities
│ ├── builder.py # Component builder
│ ├── registry.py # Registry
│ └── train.py # Training entry
├── tests/ # Tests
├── tools/ # Tools
│ └── train.py # Training script
├── pyproject.toml # Project configuration
├── README.md # English documentation
└── README_CN.md # Chinese documentation
Contributions are welcome! Please follow these steps:
- Fork this project
- Create a feature branch (`git checkout -b feature/AmazingFeature`)
- Commit your changes (`git commit -m 'Add some AmazingFeature'`)
- Push to the branch (`git push origin feature/AmazingFeature`)
- Submit a Pull Request
- Format code with Black: `black centrilearn/`
- Sort imports with isort: `isort centrilearn/`
- Run tests: `pytest`
- Check types: `mypy centrilearn/`
A: You can load real-world network data and create an environment:
import networkx as nx
from centrilearn.utils import build_environment
# Load network data
graph = nx.read_edgelist('data/my_network.edgelist')
# Create environment
env = build_environment({
    'type': 'NetworkDismantlingEnv',
    'graph': graph,
    'node_features': 'combin'
})

A: Try the following methods to improve training speed:
- Use vectorized environments for parallel training
- Increase `batch_size`
- Use GPU training (`device: cuda`)
- Reduce model complexity
We will further optimize project performance in the future.
A: Use the registry decorator to register your algorithm:
from centrilearn.utils import ALGORITHMS
@ALGORITHMS.register_module()
class MyAlgorithm(BaseAlgorithm):
    def __init__(self, **kwargs):
        # Implement your algorithm
        pass

Then use it in the configuration file:
algorithm:
  type: MyAlgorithm
  # ...

A: Load a checkpoint and evaluate on the test set:
from centrilearn.utils import build_algorithm
# Build algorithm
algo = build_algorithm(algorithm_cfg)
# Load checkpoint
algo.load_checkpoint('checkpoints/model_best.pth')
# Set to evaluation mode
algo.set_eval_mode()
# Evaluate in test environment
# ...

This project is licensed under the MIT License. See the LICENSE file for details.
- Project Homepage: https://github.com/He-JiYe/CentriLearn
- Issue Reporting: Issues
- Email: 202200820169@mail.sdu.edu.cn
If this project helps you, please give us a ⭐️!