本文档提供 CentriLearn 所有公共 API 的详细参考。
- 算法 API (algorithms)
- 环境 API (environments)
- 模型 API (models)
- 缓冲区 API (buffer)
- 指标 API (metrics)
- 工具 API (utils)
所有强化学习算法的基类。
from centrilearn.algorithms import BaseAlgorithm
class BaseAlgorithm__init__(model, optimizer_cfg, scheduler_cfg=None, replaybuffer_cfg=None, metric_manager_cfg=None, device='cpu')
初始化算法。
Parameters:
model(nn.Module): 模型实例optimizer_cfg(dict): 优化器配置scheduler_cfg(dict, optional): 学习率调度器配置replaybuffer_cfg(dict, optional): 经验缓冲区配置metric_manager_cfg(dict, optional): 指标管理器配置device(str): 运行设备 ('cpu' 或 'cuda')
Raises:
TypeError: 如果 optimizer_cfg 不是字典
将模型设置为训练模式。
Returns: None
将模型设置为评估模式。
Returns: None
保存训练检查点。
Parameters:
path(str): 保存路径**kwargs: 额外保存的信息(如 episode, best_reward 等)
Returns: None
加载训练检查点。
Parameters:
path(str): 检查点文件路径
Returns:
dict: 检查点数据
Raises:
FileNotFoundError: 如果文件不存在
更新学习率调度器。
Parameters:
metrics(dict, optional): 评估指标(用于 ReduceLROnPlateau)
Returns: None
获取当前学习率。
Returns:
float: 当前学习率
获取模型。
Returns:
nn.Module: 模型实例
从配置构建模型(抽象方法,子类必须实现)。
Parameters:
model_cfg(dict): 模型配置
Returns:
nn.Module: 模型实例
执行一步训练(抽象方法,子类必须实现)。
Parameters:
batch(dict): 训练批次数据
Returns:
dict: 训练损失信息
更新模型参数(抽象方法,子类必须实现)。
Returns:
dict: 更新信息
选择动作(抽象方法,子类必须实现)。
Parameters:
state(dict): 当前状态**kwargs: 额外参数
Returns:
any: 选择的动作
收集经验到缓冲区(抽象方法,子类必须实现)。
Parameters:
state(dict): 当前状态*args,**kwargs: 经验数据
Returns: None
Deep Q-Network 算法。
class DQN(BaseAlgorithm)__init__(model_cfg, optimizer_cfg, scheduler_cfg=None, replaybuffer_cfg=None, metric_manager_cfg=None, algo_cfg=None, device='cpu')
初始化 DQN 算法。
Parameters:
model_cfg(dict): 模型配置optimizer_cfg(dict): 优化器配置scheduler_cfg(dict, optional): 学习率调度器配置replaybuffer_cfg(dict, optional): 经验缓冲区配置metric_manager_cfg(dict, optional): 指标管理器配置algo_cfg(dict, optional): DQN 特定配置gamma(float): 折扣因子,默认 0.99epsilon_start(float): 初始探索率,默认 1.0epsilon_end(float): 最终探索率,默认 0.01epsilon_decay(int): 探索率衰减步数,默认 10000tau(float): 软更新系数,默认 0.005rcst_coef(float): 重建损失系数,默认 0.0001
device(str): 运行设备
计算当前探索率。
Returns:
float: 当前 epsilon 值
使用 epsilon-greedy 策略选择动作。
Parameters:
states(dict): 当前状态epsilon(float, optional): 指定 epsilon,None 则自动计算
Returns:
tuple: (actions, epsilon)
从缓冲区采样并更新模型。
Parameters:
batch_size(int): 批次大小
Returns:
dict: 包含以下键:q_loss: Q 损失reconstruction_loss: 重建损失total_loss: 总损失
获取 Q 值。
Parameters:
state(dict): 当前状态
Returns:
torch.Tensor: Q 值张量
收集经验到缓冲区。
Parameters:
states(dict): 状态actions(torch.Tensor): 动作rewards(torch.Tensor): 奖励next_states(dict): 下一状态dones(torch.Tensor): 终止标志
Returns: None
Proximal Policy Optimization 算法。
class PPO(BaseAlgorithm)__init__(model_cfg, optimizer_cfg, scheduler_cfg=None, replaybuffer_cfg=None, metric_manager_cfg=None, algo_cfg=None, device='cpu')
初始化 PPO 算法。
Parameters:
model_cfg(dict): 模型配置optimizer_cfg(dict): 优化器配置scheduler_cfg(dict, optional): 学习率调度器配置replaybuffer_cfg(dict, optional): 经验缓冲区配置metric_manager_cfg(dict, optional): 指标管理器配置algo_cfg(dict, optional): PPO 特定配置gamma(float): 折扣因子,默认 0.99gae_lambda(float): GAE lambda 参数,默认 0.95clip_epsilon(float): PPO 裁剪参数,默认 0.2entropy_coef(float): 熵正则化系数,默认 0.01value_coef(float): 价值损失系数,默认 0.5max_grad_norm(float): 最大梯度裁剪,默认 0.5num_epochs(int): 每次更新的 epoch 数,默认 1rcst_coef(float): 重建损失系数,默认 0.0001
device(str): 运行设备
选择动作。
Parameters:
state(dict): 当前状态deterministic(bool): 是否确定性选择
Returns:
tuple: (action, log_prob, value)
更新模型。
Parameters:
batch_size(int): 批次大小
Returns:
dict: 包含以下键:policy_loss: 策略损失value_loss: 价值损失entropy_loss: 熵损失total_loss: 总损失
获取动作和价值(推理模式)。
Parameters:
state(dict): 当前状态
Returns:
tuple: (action, value)
收集经验到缓冲区。
Parameters:
state(dict): 状态action(torch.Tensor): 动作log_prob(torch.Tensor): 对数概率reward(torch.Tensor): 奖励done(torch.Tensor): 终止标志value(torch.Tensor): 价值
Returns: None
所有环境的基类。
class BaseEnv__init__(graph=None, synth_type='ba', synth_args=None, node_features='ones', use_component=False, is_undirected=True, device='cpu')
初始化环境。
Parameters:
graph(nx.Graph, optional): 网络图对象synth_type(str): 合成图类型 ('ba', 'er', 'ws')synth_args(dict, optional): 合成图参数node_features(str): 节点特征类型 ('ones', 'degree', 'combin')use_component(bool): 是否使用连通分量is_undirected(bool): 是否无向图device(str): 计算设备
重置环境。
Parameters:
graph(nx.Graph, optional): 新的图对象
Returns:
dict: 初始状态,包含:edge_index: 边索引张量node_features: 节点特征node_mask: 节点掩码num_nodes: 节点数reward_info: 奖励信息
获取当前状态。
Returns:
dict: 当前状态
获取 PyG 格式数据。
Parameters:
mask(torch.Tensor, optional): 节点掩码
Returns:
torch_geometric.data.Data: PyG 数据对象
计算连通分量。
Parameters:
edge_index(torch.Tensor): 边索引num_nodes(int): 节点数
Returns:
list: 连通分量列表
检查图是否为空。
Returns:
bool: 是否为空
执行一步动作(抽象方法)。
Parameters:
action(int): 动作mapping(dict): 节点映射
Returns:
tuple: (next_state, reward, done, info)
网络瓦解环境。
class NetworkDismantlingEnv(BaseEnv)__init__(graph=None, synth_type='ba', synth_args=None, node_features='ones', value_type='auc', use_gcc=False, use_component=False, is_undirected=True, device='cpu')
初始化网络瓦解环境。
Parameters:
graph(nx.Graph, optional): 网络图对象synth_type(str): 合成图类型synth_args(dict, optional): 合成图参数node_features(str): 节点特征类型value_type(str): 奖励类型 ('auc', 'ar')use_gcc(bool): 只与最大连通分支交互use_component(bool): 是否使用连通分量is_undirected(bool): 是否无向图device(str): 计算设备
执行一步动作。
Parameters:
action(int): 要移除的节点索引mapping(dict): 节点映射
Returns:
tuple: (next_state, reward, done, info)next_state: 下一状态reward: 奖励done: 终止标志info: 额外信息lcc_size: 当前最大连通分量大小attack_rate: 攻击率remaining_nodes: 剩余节点数
返回剩余图的最大连通分量大小。
Returns:
int: 最大连通分量大小
返回最大连通分量的节点索引。
Returns:
list: 节点索引列表
移除节点。
Parameters:
node(int): 节点索引mapping(dict): 节点映射
Returns: None
向量化环境。
class VectorizedEnv初始化向量化环境。
Parameters:
env_class: 环境类env_kwargs(list): 环境参数列表env_num(int, optional): 环境数量
从单个配置创建多个副本。
Parameters:
env_class: 环境类env_kwargs(dict): 环境参数env_num(int): 环境数量
Returns:
VectorizedEnv: 向量化环境实例
从图列表创建向量化环境。
Parameters:
env_class: 环境类graph_list(list): 图列表common_kwargs(dict, optional): 通用参数
Returns:
VectorizedEnv: 向量化环境实例
重置环境。
Parameters:
indices(list, optional): 要重置的环境索引
Returns:
list: 观测列表
批量执行动作。
Parameters:
actions(list): 动作列表
Returns:
tuple: (observations, rewards, dones, infos)
返回环境数量。
Returns:
int: 环境数量
获取单个环境。
Parameters:
index(int): 环境索引
Returns:
BaseEnv: 环境实例
GraphSAGE 主干网络。
class GraphSAGE__init__(in_channels, hidden_channels, num_layers, output_dim=None, aggr='mean', graph_aggr='add', norm=None, dropout=0.0)
初始化 GraphSAGE。
Parameters:
in_channels(int): 输入特征维度hidden_channels(int): 隐藏特征维度num_layers(int): GNN 层数output_dim(int, optional): 输出维度aggr(str): 聚合方式 ('mean', 'max', 'sum')graph_aggr(str): 图池化方式 ('add', 'mean', 'max')norm(str, optional): 归一化方式 ('batch', 'layer')dropout(float): Dropout 概率
前向传播。
Parameters:
data(dict): 输入数据,包含:node_features: 节点特征edge_index: 边索引
Returns:
dict: 输出,包含:node_embed: 节点嵌入graph_embed: 图嵌入
Graph Attention Network。
class GAT__init__(in_channels, hidden_channels, num_layers, output_dim=None, aggr='mean', graph_aggr='add', norm=None, dropout=0.0, v2=False, heads=1, concat=True)
初始化 GAT。
Parameters:
in_channels(int): 输入特征维度hidden_channels(int): 隐藏特征维度num_layers(int): GNN 层数output_dim(int, optional): 输出维度aggr(str): 聚合方式graph_aggr(str): 图池化方式norm(str, optional): 归一化方式dropout(float): Dropout 概率v2(bool): 使用 GATv2heads(int): 注意力头数concat(bool): 是否拼接多头
Graph Isomorphism Network。
class GIN__init__(in_channels, hidden_channels, num_layers, output_dim=None, aggr='mean', graph_aggr='add', norm=None, dropout=0.0)
初始化 GIN。
Parameters: 与 GraphSAGE 相同
ResNet 风格深度网络。
class DeepNet__init__(in_channels, hidden_channels=64, num_blocks=3, block_config=None, aggr='mean', graph_aggr='add', norm='layer', dropout=0.0, use_residual=True, output_dim=None, nn='GraphSAGE')
初始化 DeepNet。
Parameters:
in_channels(int): 输入特征维度hidden_channels(int): 隐藏特征维度num_blocks(int): Block 数量block_config(dict, optional): Block 配置aggr(str): GraphSAGE 聚合方式graph_aggr(str): 图池化方式norm(str): 归一化类型dropout(float): Dropout 概率use_residual(bool): 是否使用残差连接output_dim(int, optional): 输出维度nn(str): 基础 GNN 类型
Feature Pyramid Network。
class FPNet__init__(in_channels, hidden_channels_list=[64, 128, 256], num_layers_list=None, aggr='mean', graph_aggr='add', norm='layer', dropout=0.0, fusion_mode='add', output_dim=None, nn='GraphSAGE')
初始化 FPNet。
Parameters:
in_channels(int): 输入特征维度hidden_channels_list(list): 各层隐藏维度列表num_layers_list(list, optional): 各层 GNN 层数列表aggr(str): GraphSAGE 聚合方式graph_aggr(str): 图池化方式norm(str): 归一化类型dropout(float): Dropout 概率fusion_mode(str): 特征融合方式 ('add', 'concat', 'attention')output_dim(int, optional): 输出维度nn(str): 基础 GNN 类型
Q 值预测头。
class QHead__init__(in_channels, hidden_layers=None, activation='leaky_relu', dropout=0.0)
初始化 QHead。
Parameters:
in_channels(int): 输入特征维度hidden_layers(list, optional): 隐藏层维度列表activation(str): 激活函数dropout(float): Dropout 概率
前向传播。
Parameters:
node_embed(torch.Tensor): 节点嵌入
Returns:
torch.Tensor: Q 值
策略预测头。
class PolicyHead__init__(in_channels, hidden_layers=None, activation='leaky_relu', dropout=0.0)
初始化 PolicyHead。
Parameters: 与 QHead 相同
前向传播。
Parameters:
node_embed(torch.Tensor): 节点嵌入
Returns:
torch.Tensor: 动作 logits
价值预测头。
class VHead__init__(in_channels, hidden_layers=None, activation='leaky_relu', dropout=0.0)
初始化 VHead。
Parameters: 与 QHead 相同
前向传播。
Parameters:
node_embed(torch.Tensor): 节点嵌入
Returns:
torch.Tensor: 价值估计
Q 网络。
class Qnet初始化 Qnet。
Parameters:
backbone_cfg(dict): 主干网络配置q_head_cfg(dict, optional): Q 值头配置
前向传播。
Parameters:
data(dict): 输入数据
Returns:
dict: 包含以下键:q_values: Q 值node_embed: 节点嵌入graph_embed: 图嵌入
Actor-Critic 网络。
class ActorCritic初始化 ActorCritic。
Parameters:
backbone_cfg(dict): 主干网络配置actor_head_cfg(dict, optional): Actor 头配置critic_head_cfg(dict, optional): Critic 头配置num_critics(int): Critic 数量
前向传播。
Parameters:
data(dict): 输入数据
Returns:
dict: 包含以下键:logit: 动作 logitsv_values: 价值估计
经验回放缓冲区。
class ReplayBuffer__init__(capacity, n_step=1, gamma=0.99, alpha=0.6, beta_start=0.4, beta_frames=100000, epsilon=1e-6, prioritized=False)
初始化 ReplayBuffer。
Parameters:
capacity(int): 缓冲区容量n_step(int): N 步回报步数gamma(float): 折扣因子alpha(float): 优先度指数beta_start(float): 重要性采样初始 betabeta_frames(int): beta 衰减帧数epsilon(float): 最小优先度prioritized(bool): 是否使用优先级采样
添加经验。
Parameters:
state(dict): 状态action(torch.Tensor): 动作reward(torch.Tensor): 奖励next_state(dict): 下一状态done(torch.Tensor): 终止标志
Returns: None
采样一批数据。
Parameters:
batch_size(int): 批次大小
Returns:
tuple: (batch, indices, weights)batch: 批次数据indices: 采样索引weights: 重要性采样权重
更新优先级。
Parameters:
indices(list): 索引列表priorities(list): 优先度列表
Returns: None
获取当前 beta 值。
Returns:
float: 当前 beta
获取缓冲区大小。
Returns:
int: 当前经验数量
清空缓冲区。
Returns: None
轨迹缓冲区。
class RolloutBuffer初始化 RolloutBuffer。
Parameters:
capacity(int): 缓冲区容量
添加经验。
Parameters:
state(dict): 状态action(torch.Tensor): 动作log_prob(torch.Tensor): 对数概率reward(torch.Tensor): 奖励done(torch.Tensor): 终止标志value(torch.Tensor): 价值
Returns: None
获取训练批次(计算 GAE 优势)。
Parameters:
batch_size(int): 批次大小gamma(float): 折扣因子gae_lambda(float): GAE lambda 参数
Returns:
list: 批次列表,每个批次包含:states,actions,log_probs,rewards,dones,valuesadvantages,returns
清空缓冲区。
Returns: None
获取缓冲区大小。
Returns:
int: 当前经验数量
所有指标的基类。
class BaseMetric初始化指标。
Parameters:
name(str, optional): 指标名称record(str): 记录方式 ('max' 或 'min')
更新指标累积值。
Parameters:
value(float): 指标值
Returns: None
重置指标状态。
Returns: None
获取指标结果。
Returns:
dict: 包含以下键:value: 当前值max: 最大值min: 最小值count: 计数history: 历史值
处理单个步骤(抽象方法)。
Parameters:
state(dict): 状态action: 动作reward(float): 奖励next_state(dict): 下一状态done(bool): 终止标志info(dict, optional): 额外信息
Returns:
float: 指标值
在完整 episode 上评估(抽象方法)。
Parameters:
env: 环境model: 模型num_episodes(int): Episode 数量
Returns:
dict: 评估结果
计算当前累积值(抽象方法)。
Returns:
float: 当前值
Attack Curve 面积指标。
class AUC(BaseMetric)初始化 AUC 指标。
处理单个步骤。
Returns:
float: 累积奖励
攻击率指标。
class AttackRate(BaseMetric)初始化 AttackRate 指标。
指标管理器。
class MetricManager初始化指标管理器。
Parameters:
metrics(list, optional): 指标列表save_dir(str, optional): 保存目录log_interval(int): 日志打印间隔
添加单个指标。
Parameters:
metric(BaseMetric): 指标实例
Returns: None
添加多个指标。
Parameters:
metrics(list): 指标列表
Returns: None
移除指标。
Parameters:
name(str): 指标名称
Returns:
bool: 是否成功移除
获取指标实例。
Parameters:
name(str): 指标名称
Returns:
BaseMetric: 指标实例
更新所有指标。
Parameters:
state(dict): 状态action: 动作reward(float): 奖励next_state(dict): 下一状态done(bool): 终止标志info(dict, optional): 额外信息
Returns:
dict: 更新结果
评估所有指标。
Parameters:
env: 环境model: 模型num_episodes(int): Episode 数量
Returns:
dict: 评估结果
获取所有指标结果。
Returns:
dict: 指标结果字典
获取摘要(仅当前值)。
Returns:
dict: 摘要字典
重置所有指标。
Returns: None
重置指定指标。
Parameters:
name(str): 指标名称
Returns: None
保存指标结果。
Parameters:
path(str, optional): 保存路径
Returns: None
加载指标结果。
Parameters:
path(str): 文件路径
Returns: None
打印指标日志。
Parameters:
step(int, optional): 当前步数prefix(str): 日志前缀
Returns: None
返回指标数量。
Returns:
int: 指标数量
注册器类。
class Registry初始化注册器。
Parameters:
name(str): 注册器名称
获取注册的类。
Parameters:
key(str): 类名字符串
Returns:
type: 对应的类
Raises:
KeyError: 如果 key 不存在
注册模块(装饰器或函数)。
Parameters:
name(str, optional): 注册名称force(bool): 是否覆盖已存在的类module(type, optional): 要注册的类
Returns:
- 装饰器函数或注册后的类
返回注册数量。
Returns:
int: 注册数量
构建优化器。
Parameters:
model(nn.Module): 模型cfg(dict): 优化器配置type: 优化器类型 ('Adam', 'AdamW', 'SGD', 'RMSprop', 'Adagrad', 'Adadelta')lr: 学习率weight_decay: 权重衰减- 其他优化器特定参数
Returns:
torch.optim.Optimizer: 优化器实例
构建学习率调度器。
Parameters:
optimizer(torch.optim.Optimizer): 优化器cfg(dict, optional): 调度器配置type: 调度器类型- 调度器特定参数
Returns:
torch.optim.lr_scheduler._LRScheduler: 调度器实例
支持的调度器类型:
StepLR,MultiStepLR,ExponentialLRCosineAnnealingLR,CosineAnnealingWarmRestartsReduceLROnPlateau,LinearLR,CyclicLROneCycleLR,LambdaLR,MultiplicativeLRConstantLR,SequentialLR,ChainedScheduler
从配置构建模块。
Parameters:
cfg(dict): 配置字典type: 类名字符串或类- 其他初始化参数
registry(Registry): 注册器default_args(dict, optional): 默认参数
Returns:
- 构建的对象实例
构建主干网络。
Parameters:
cfg(dict or list): 主干网络配置default_args(dict, optional): 默认参数
Returns:
- 主干网络实例
构建预测头。
Parameters:
cfg(dict or list): 预测头配置default_args(dict, optional): 默认参数
Returns:
- 预测头实例
构建网络瓦解模型。
Parameters:
cfg(dict or list): 模型配置default_args(dict, optional): 默认参数
Returns:
- 模型实例
构建环境。
Parameters:
cfg(dict): 环境配置type: 环境类型- 其他环境参数
env_num(optional): 向量化环境数量graph_list(optional): 图列表env_kwargs_list(optional): 配置列表
default_args(dict, optional): 默认参数
Returns:
- 环境实例(单环境或向量化环境)
构建算法。
Parameters:
cfg(dict or list): 算法配置default_args(dict, optional): 默认参数
Returns:
- 算法实例
构建经验缓冲区。
Parameters:
cfg(dict or list): 缓冲区配置default_args(dict, optional): 默认参数
Returns:
- 缓冲区实例
构建指标。
Parameters:
cfg(dict or list): 指标配置default_args(dict, optional): 默认参数
Returns:
- 指标实例或实例列表
构建指标管理器。
Parameters:
cfg(dict, optional): 指标管理器配置metrics: 指标配置列表save_dir: 保存目录log_interval: 日志间隔
Returns:
MetricManager: 指标管理器实例
从配置训练。
Parameters:
config(dict): 配置字典algorithm: 算法配置environment: 环境配置training: 训练参数
verbose(bool): 是否打印日志**kwargs: 额外参数
Returns:
tuple: (results, algorithm)results(dict): 训练结果algorithm: 训练完成的算法实例
以下注册器在 src.utils.registry 中预定义:
NN # 基础图神经网络层
BACKBONES # 主干网络
HEADS # 预测头
NETWORK_DISMANTLER # 网络瓦解模型
ENVIRONMENTS # 环境
ALGORITHMS # 算法
REPLAYBUFFERS # 经验缓冲区
METRICS # 指标GATGINGraphSAGE
GATGINGraphSAGEDeepNetFPNet
MLPHeadQHeadPolicyHeadVHeadDuelingHeadComponentValueHead
QnetActorCritic
NetworkDismantlingEnv
DQNPPO
ReplayBufferRolloutBuffer
AUCAttackRate