From 5996add783211379bb8bf1ce8f3dd800b918f8ab Mon Sep 17 00:00:00 2001 From: Temp 2 Date: Fri, 26 Sep 2025 22:40:14 +0530 Subject: [PATCH 1/2] Add EGSteal implementation --- examples/attack/EGSteal.py | 14 + pygip/datasets/__init__.py | 4 + pygip/datasets/datasets.py | 31 + pygip/models/attack/EGSteal/EGSteal.py | 470 +++++++++++ pygip/models/attack/EGSteal/utils.py | 1040 ++++++++++++++++++++++++ pygip/models/nn/__init__.py | 1 + pygip/models/nn/backbones.py | 308 +++++++ 7 files changed, 1868 insertions(+) create mode 100644 examples/attack/EGSteal.py create mode 100644 pygip/models/attack/EGSteal/EGSteal.py create mode 100644 pygip/models/attack/EGSteal/utils.py diff --git a/examples/attack/EGSteal.py b/examples/attack/EGSteal.py new file mode 100644 index 0000000..0eed82b --- /dev/null +++ b/examples/attack/EGSteal.py @@ -0,0 +1,14 @@ +from pygip.datasets import * +from pygip.models.attack.EGSteal.EGSteal import EGSteal +from pygip.utils.hardware import set_device + +set_device("cuda:0") # cpu, cuda:0 + + +def egsteal(): + dataset = MUTAGGraphClassification(api_type='pyg') + egsteal = EGSteal(dataset,query_shadow_ratio=0.3) + egsteal.attack() + +if __name__ == '__main__': + egsteal() diff --git a/pygip/datasets/__init__.py b/pygip/datasets/__init__.py index 4decc35..24c13a3 100644 --- a/pygip/datasets/__init__.py +++ b/pygip/datasets/__init__.py @@ -7,6 +7,8 @@ Photo, CoauthorCS, CoauthorPhysics, + MUTAG, + MUTAGGraphClassification ) __all__ = [ @@ -18,4 +20,6 @@ 'Photo', 'CoauthorCS', 'CoauthorPhysics', + 'MUTAG', + 'MUTAGGraphClassification' ] diff --git a/pygip/datasets/datasets.py b/pygip/datasets/datasets.py index a5776e3..e7213aa 100644 --- a/pygip/datasets/datasets.py +++ b/pygip/datasets/datasets.py @@ -24,6 +24,8 @@ from torch_geometric.datasets import Reddit from torch_geometric.datasets import TUDataset # ENZYMES +## Added for EGSteal +from torch_geometric.transforms import Constant def dgl_to_tg(dgl_graph): edge_index = torch.stack(dgl_graph.edges()) @@ -565,3 +567,32 @@ def load_dgl_data(self): dataset = YelpDataset(raw_dir=self.path) self.graph_dataset = dataset self.graph_data = dataset[0] + +class MUTAGGraphClassification(Dataset): + def __init__(self, api_type='pyg', path='./data'): + super().__init__(api_type, path) + + def _load_meta_data(self): + if self.api_type == 'pyg': + ds = self.graph_dataset + self.num_features = int(ds.num_node_features) + self.num_classes = int(ds.num_classes) + self.num_graphs = int(len(ds)) + self.num_edge_features = int(ds.num_edge_features) if ds.num_edge_features is not None else 0 + else: + super()._load_meta_data() + + def load_pyg_data(self): + self.dataset_name = 'Mutagenicity' + temp_dataset = TUDataset(root=self.path,name='Mutagenicity') + data_transform = None + if temp_dataset.num_node_features == 0: + print("\nNo node features found. 
Adding constant node features (all ones).") + data_transform = Constant(value=1, cat=False) + self.graph_dataset = TUDataset(root=self.path,name='Mutagenicity',transform=data_transform) + num_graphs = len(self.graph_dataset) + print(f"\nTotal number of graphs in {self.dataset_name}: {num_graphs}") + print(f"Node features dimension: {self.graph_dataset.num_node_features}") + print(f"Edge features dimension: {self.graph_dataset.num_edge_features if hasattr(self.graph_dataset, 'num_edge_features') else 'N/A'}") + print(f"Number of classes: {self.graph_dataset.num_classes if hasattr(self.graph_dataset, 'num_classes') else 'N/A'}") + return self.graph_dataset diff --git a/pygip/models/attack/EGSteal/EGSteal.py b/pygip/models/attack/EGSteal/EGSteal.py new file mode 100644 index 0000000..fdec560 --- /dev/null +++ b/pygip/models/attack/EGSteal/EGSteal.py @@ -0,0 +1,470 @@ +import torch +import torch.nn.functional as F +from torch.optim import Adam +from torch_geometric.data import DataLoader +from tqdm import tqdm +from torch.utils.data import random_split +from .utils import * +from pygip.models.attack.base import BaseAttack +from pygip.models.nn import SurrogateModelGraphClassification,TargetModelGraphClassification,GCNGraphClassification,GraphSAGEGraphClassification,GATGraphClassification,GINGraphClassification,Classifier,CAM,GradCAM,GradientExplainer +import numpy as np +from torch_geometric.explain import Explainer, GNNExplainer +import os.path as osp +import os +import random +import math +from scipy.stats import kendalltau +from torch_geometric.data import Batch, Data +from collections import defaultdict + +class EGSteal(BaseAttack): + supported_api_types = {"pyg"} + + def __init__(self, dataset, query_shadow_ratio=0.3,gnn_backbone = 'GIN',explanation_mode = 'CAM'): + self.dataset = dataset + self.graph_dataset = dataset.graph_dataset + self.graph_data = dataset.graph_data + self.query_shadow_ratio = query_shadow_ratio + + self.num_graphs = self.dataset.num_graphs + self.num_features = self.dataset.num_features + self.num_classes = self.dataset.num_classes + self.num_edge_features = self.dataset.num_edge_features if hasattr(dataset, 'num_edge_features') else 0 + self.gnn_backbone = gnn_backbone + self.explanation_mode = explanation_mode + + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # Fixed parameters + self.seed = 42 + self.batch_size = 64 + self.learning_rate = 0.001 + self.epochs = 200 + self.gat_heads = 4 + self.gnn_layer = 3 + self.gnn_hidden_dim = 128 + self.gnnexplainer_epochs = 100 + self.pgexplainer_epochs = 100 + self.augmentation_ratio = 0.2 + self.operation_ratio = 0.05 + self.align_weight = 1.0 + self.augmentation_type = 'combined' # ['drop_node', 'drop_edge', 'add_edge', 'combined'] + self.shadow_val_ratio = 0.2 + + def prepare_data(self): + set_seed(self.seed) # For reproducibility + # Define split ratios + target_ratio = 0.4 + target_val_ratio = 0.2 + test_ratio = 0.2 + shadow_ratio = 0.4 + target_num = int(self.num_graphs * target_ratio) + test_num = int(self.num_graphs * test_ratio) + shadow_num = self.num_graphs - target_num - test_num # Ensure total consistency + + # Randomly split dataset + target_dataset, test_dataset, shadow_dataset = random_split( + self.graph_dataset, + [target_num, test_num, shadow_num] + ) + + # Further split target_dataset into train and val + target_train_num = int(target_num * (1 - target_val_ratio)) + target_val_num = target_num - target_train_num + + target_train_dataset, target_val_dataset = random_split( 
+ target_dataset, + [target_train_num, target_val_num] + ) + + print("\nDataset split sizes:") + print(f"Target train set size: {len(target_train_dataset)} ({len(target_train_dataset) / self.num_graphs:.1%})") + print(f"Target val set size: {len(target_val_dataset)} ({len(target_val_dataset) / self.num_graphs:.1%})") + print(f"Test set size: {len(test_dataset)} ({len(test_dataset) / self.num_graphs:.1%})") + print(f"Shadow dataset size: {len(shadow_dataset)} ({len(shadow_dataset) / self.num_graphs:.1%})") + + return target_train_dataset, target_val_dataset, shadow_dataset, test_dataset + + def _train_target_model(self,target_train_dataset,target_val_dataset,test_dataset): + # build paths + save_root = './saved_models/EGSteal' + os.makedirs(save_root, exist_ok=True) + model_path = osp.join( + save_root, + f"{self.dataset.__class__.__name__}_{self.gnn_backbone}_{self.explanation_mode}_model.pth" + ) + target_train_loader = DataLoader(target_train_dataset, batch_size=self.batch_size, shuffle=True) + target_val_loader = DataLoader(target_val_dataset, batch_size=self.batch_size, shuffle=False) + test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False) + + #Initialize model + if self.gnn_backbone == 'GIN': + encoder = GINGraphClassification( + input_dim=self.num_features, + hidden_dim=self.gnn_hidden_dim, + num_layers=self.gnn_layer + ).to(self.device) + elif self.gnn_backbone == 'GCN': + encoder = GCNGraphClassification( + input_dim=self.num_features, + hidden_dim=self.gnn_hidden_dim, + num_layers=self.gnn_layer + ).to(self.device) + elif self.gnn_backbone == 'GAT': + encoder = GATGraphClassification( + input_dim=self.num_features, + hidden_dim=self.gnn_hidden_dim, + num_layers=self.gnn_layer, + heads=self.gat_heads + ).to(self.device) + elif self.gnn_backbone == 'GraphSAGE': + encoder = GraphSAGEGraphClassification( + input_dim=self.num_features, + hidden_dim=self.gnn_hidden_dim, + num_layers=self.gnn_layer + ).to(self.device) + else: + raise ValueError(f"Invalid GNN backbone specified: {self.gnn_backbone}. 
Expected 'GIN', 'GCN', or 'GAT', or 'GraphSAGE'.") + + predictor = Classifier( + input_dim=self.gnn_hidden_dim, + output_dim=self.num_classes + ).to(self.device) + + model = TargetModelGraphClassification(encoder=encoder, predictor=predictor, explanation_mode=self.explanation_mode).to(self.device) + optimizer = Adam(model.parameters(), lr=self.learning_rate) + epochs = self.epochs + best_val_auc = 0.0 + best_model_state = None + + # ---- LOAD if exists (state_dict) ---- + if osp.exists(model_path): + print(f"Loading pre-trained target model weights from {model_path}") + state_dict = torch.load(model_path, map_location=self.device) + model.load_state_dict(state_dict) + model.eval() + return model + + with tqdm(total=epochs, desc='Epochs') as epoch_pbar: + for epoch in range(1, epochs + 1): + train_loss, train_acc, train_auc = train_loop_target_model(model, target_train_loader, optimizer, self.device,self) + val_loss, val_acc, val_auc = evaluate_loop_target_model(model, target_val_loader, self.device,self) + + epoch_pbar.set_postfix({ + 'Train Loss': f'{train_loss:.4f}', + 'Train Acc': f'{train_acc:.4f}', + 'Train AUC': f'{train_auc:.4f}', + 'Val Loss': f'{val_loss:.4f}', + 'Val Acc': f'{val_acc:.4f}', + 'Val AUC': f'{val_auc:.4f}' + }) + epoch_pbar.update(1) + + # Update best model if validation AUC is higher + if val_auc > best_val_auc: + best_val_auc = val_auc + best_model_state = model.state_dict() + # Evaluate the best model on the test set + model.load_state_dict(best_model_state) + test_loss, test_acc, test_auc = evaluate_loop_target_model(model, test_loader, self.device, self) + print(f"Test Accuracy of the best model: {test_acc:.4f}") + print(f"Test AUC of the best model: {test_auc:.4f}") + + # ---- SAVE (ensure directory exists; save file, not directory) ---- + # (we already created save_root above) + torch.save(best_model_state, model_path) + print(f"Saved best target state_dict to {model_path}") + + return model + + def prepare_shadow_data(self,model,shadow_dataset,test_dataset): + if self.explanation_mode == 'GNNExplainer': + gnnexplainer = Explainer( + model=model, + algorithm=GNNExplainer(epochs=self.gnnexplainer_epochs), + explanation_type='model', + model_config=dict( + mode='binary_classification', + task_level='graph', + return_type='raw' + ), + node_mask_type='object', + edge_mask_type=None + ) + + if self.explanation_mode == 'PGExplainer': + pgexplainer = Explainer( + model=model, + algorithm=PGExplainer(epochs=self.pgexplainer_epochs, lr=0.003), + explanation_type='phenomenon', + model_config=dict( + mode='binary_classification', + task_level='graph', + return_type='raw' + ), + node_mask_type=None, + edge_mask_type='object', + ) + + if self.explanation_mode == 'GradCAM': + gradcam = GradCAM(model=model) + + if self.explanation_mode == 'CAM': + cam = CAM(model=model) + + if self.explanation_mode == 'Grad': + grad = GradientExplainer(model=model) + for i in range(2): + if i == 0: + dataset = shadow_dataset + else: + dataset = test_dataset + dataloader = DataLoader(dataset, batch_size=1024, shuffle=False) + + results = [] + total_graph_idx = 0 + + # Iterate through the dataset + for batch_idx, batch_data in enumerate(dataloader): + batch_data = batch_data.to(self.device) + if batch_data.x is None: + batch_data.x = torch.ones((batch_data.num_nodes, 1)).to(self.device) + + batch = batch_data.batch + + # Get model predictions + with torch.no_grad(): + if self.explanation_mode in ['GNNExplainer','PGExplainer']: + out = model(batch_data.x, batch_data.edge_index, batch) + else: 
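+                        # CAM / GradCAM / Grad modes: the target model returns (node_embeddings, logits), so keep only the logits here.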
+ _, out = model(batch_data.x, batch_data.edge_index, batch) + + preds = out.argmax(dim=1) + + # Generate explanations + if self.explanation_mode == 'GNNExplainer': + explanations = gnnexplainer(batch_data.x, batch_data.edge_index, batch=batch) + node_mask = explanations.node_mask.view(-1) + + if self.explanation_mode == 'PGExplainer': + for epoch in range(self.pgexplainer_epochs): + pgexplainer.algorithm = pgexplainer.algorithm.to(self.device) + loss = pgexplainer.algorithm.train(epoch, model, batch_data.x, batch_data.edge_index, target=preds, batch=batch) + explanations = pgexplainer(batch_data.x, batch_data.edge_index, target=preds, batch=batch) + + edge_mask = explanations.edge_mask + edge_index = explanations.edge_index + num_nodes = batch_data.x.shape[0] + + # edge score -> node score + node_mask = convert_edge_scores_to_node_scores(edge_mask, edge_index, num_nodes) + + if self.explanation_mode == 'GradCAM': + explanations = gradcam.get_gradcam_scores(batch_data, preds) + node_mask = explanations + + if self.explanation_mode == 'CAM': + explanations = cam.get_cam_scores(preds, batch) + node_mask = explanations + + if self.explanation_mode == 'Grad': + explanations = grad.get_gradient_scores(batch_data, preds) + node_mask = explanations + + # Split batch data + original_graphs = batch_data.to_data_list() + batch_preds = preds.tolist() + + # Get the number of nodes per graph + num_nodes_per_graph = batch_data.ptr[1:] - batch_data.ptr[:-1] + node_masks_list = torch.split(node_mask, num_nodes_per_graph.tolist()) + + # Process each graph + for idx_in_batch, (original_data, pred, node_m) in enumerate(zip(original_graphs, batch_preds, node_masks_list)): + # Move data back to CPU + original_data = original_data.to('cpu') + node_m = node_m.to('cpu') + + results.append({ + 'original_data': original_data, + 'pred': pred, + 'node_mask': node_m + }) + total_graph_idx += 1 + + print(f"Processed {total_graph_idx}/{len(dataset)} graphs.") + + if i == 0: + queried_shadow_dataset = results + else: + queried_test_dataset = results + return queried_shadow_dataset, queried_test_dataset + + def _train_attack_model(self,queried_shadow,queried_dataset_test): + n = len(queried_shadow) + n_query = max(1, int(round(n * self.query_shadow_ratio))) + + idx = list(range(n)) + rng = random.Random(self.seed) + rng.shuffle(idx) + + qidx = idx[:n_query] # the queried subset + n_val = 0 + if n_query > 1: + n_val = min(max(1, int(round(n_query * self.shadow_val_ratio))), n_query - 1) + + val_idx = set(qidx[:n_val]) + train_idx = qidx[n_val:] + + # If edge case leaves train empty, move one from val → train + if len(train_idx) == 0: + train_idx = [qidx[-1]] + val_idx = set(qidx[:-1]) + + queried_dataset_val = [queried_shadow[i] for i in val_idx] + queried_dataset_train = [queried_shadow[i] for i in train_idx] + + print(f"Shadow Train Dataset Size: {len(queried_dataset_train)}") + print(f"Shadow Val Dataset Size: {len(queried_dataset_val)}") + print(f"Shadow Test Dataset Size: {len(queried_dataset_test)}") + + if self.gnn_backbone == 'GIN': + encoder = GINGraphClassification( + input_dim=self.num_features, + hidden_dim=self.gnn_hidden_dim, + num_layers=self.gnn_layer + ).to(self.device) + elif self.gnn_backbone == 'GCN': + encoder = GCNGraphClassification( + input_dim=self.num_features, + hidden_dim=self.gnn_hidden_dim, + num_layers=self.gnn_layer + ).to(self.device) + elif self.gnn_backbone == 'GAT': + encoder = GATGraphClassification( + input_dim=self.num_features, + hidden_dim=self.gnn_hidden_dim, + 
num_layers=self.gnn_layer, + heads=self.gat_heads + ).to(self.device) + elif self.gnn_backbone == 'GraphSAGE': + encoder = GraphSAGEGraphClassification( + input_dim=self.num_features, + hidden_dim=self.gnn_hidden_dim, + num_layers=self.gnn_layer + ).to(self.device) + else: + raise ValueError(f"Invalid GNN backbone specified: {self.gnn_backbone}. Expected 'GIN', 'GCN', or 'GAT', or 'GraphSAGE'.") + + predictor = Classifier( + input_dim=self.gnn_hidden_dim, + output_dim=self.num_classes + ).to(self.device) + + model = SurrogateModelGraphClassification(encoder=encoder, predictor=predictor).to(self.device) + + criterion = torch.nn.CrossEntropyLoss() + ranknet_loss_fn = RankNetLoss().to(self.device) + + optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=self.learning_rate) + + processed_train_dataset = process_query_dataset(queried_dataset_train) + processed_val_dataset = process_query_dataset(queried_dataset_val) + processed_test_dataset = process_query_dataset(queried_dataset_test) + + augmentor = DataAugmentor() + + val_loader = DataLoader(processed_val_dataset, batch_size=self.batch_size, shuffle=False, collate_fn=custom_collate) + test_loader = DataLoader(processed_test_dataset, batch_size=self.batch_size, shuffle=False, collate_fn=custom_collate) + + best_val_auc = -math.inf + best_model_state = None + + print("Starting Surrogate Model Training...") + with tqdm(total=self.epochs, desc=f'Training (seed={self.seed})') as epoch_pbar: + for epoch_num in range(1, self.epochs + 1): + # --- 1. Data augmentation --- + augmented_data = augment( + dataset=processed_train_dataset, + augmentor=augmentor, + augmentation_ratio=self.augmentation_ratio, + operation_ratio=self.operation_ratio, + augmentation_type=self.augmentation_type + ) + + combined_train_dataset = processed_train_dataset + augmented_data + + combined_train_loader = DataLoader(combined_train_dataset, + batch_size=self.batch_size, + shuffle=True, + collate_fn=custom_collate) + + # --- 2. Training --- + train_loss_pred, train_ranknet_loss = train( + model, combined_train_loader, optimizer, self.device, + align_weight=self.align_weight, + criterion=criterion, + ranknet_loss_fn=ranknet_loss_fn + ) + + # --- 3. Validation --- + val_acc, val_auc = eval( + model, val_loader, self.device + ) + + epoch_pbar.set_postfix({ + 'Train Pred Loss': f'{train_loss_pred:.4f}', + 'Train RankNet Loss': f'{train_ranknet_loss:.4f}', + 'Val Acc': f'{val_acc:.4f}', + 'Val AUC': f'{val_auc:.4f}' + }) + epoch_pbar.update(1) + + # --- 4. 
Save best model (based on validation AUC) --- + if not math.isnan(val_auc) and val_auc >= best_val_auc: + best_val_auc = val_auc + best_model_state = model.state_dict() + + # Prepare defaults so we always record something + run_metrics = { + 'seed': self.seed, + 'best_val_auc': float(best_val_auc) if best_val_auc != -math.inf else float('nan'), + 'test_acc': float('nan'), + 'test_auc': float('nan'), + 'fidelity_score': float('nan'), + 'order_accuracy': float('nan'), + 'rank_correlation': float('nan'), + } + + # Evaluate best model + if best_model_state is not None: + model.load_state_dict(best_model_state) + + test_acc, test_auc, fidelity_score, order_accuracy, rank_correlation = test(model, test_loader, self.device) + + print(f"\n[seed={self.seed}] Best Validation AUC: {best_val_auc:.4f}") + print(f"[seed={self.seed}] Test Accuracy: {test_acc:.4f}") + print(f"[seed={self.seed}] Test AUC: {test_auc:.4f}") + print(f"[seed={self.seed}] Fidelity Score: {fidelity_score:.4f}") + print(f"[seed={self.seed}] Order Accuracy: {order_accuracy:.4f}") + print(f"[seed={self.seed}] Rank Correlation: {rank_correlation:.4f}") + + run_metrics.update({ + 'test_acc': float(test_acc), + 'test_auc': float(test_auc), + 'fidelity_score': float(fidelity_score), + 'order_accuracy': float(order_accuracy), + 'rank_correlation': float(rank_correlation), + }) + else: + print(f"[seed={self.seed}] No improvement in validation AUC during training.") + return model + + def attack(self): + target_train_dataset,target_val_dataset,shadow_dataset,test_dataset = self.prepare_data() + target_model = self._train_target_model(target_train_dataset,target_val_dataset,test_dataset) + queried_shadow_dataset,queried_test_dataset = self.prepare_shadow_data(target_model,shadow_dataset,test_dataset) + surrogate_model = self._train_attack_model(queried_shadow_dataset,queried_test_dataset) + + pass + \ No newline at end of file diff --git a/pygip/models/attack/EGSteal/utils.py b/pygip/models/attack/EGSteal/utils.py new file mode 100644 index 0000000..072e2da --- /dev/null +++ b/pygip/models/attack/EGSteal/utils.py @@ -0,0 +1,1040 @@ +# utils.py + +import numpy as np +import random +import os +from typing import Optional +from sklearn.metrics import roc_auc_score +import logging +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.nn import ReLU, Sequential +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.data import Data +from torch_geometric.utils import subgraph +from torch_geometric.explain import Explanation +from torch_geometric.explain.algorithm import ExplainerAlgorithm +from torch_geometric.explain.algorithm.utils import clear_masks, set_masks +from torch_geometric.explain.config import ExplanationType, ModelMode, ModelTaskLevel +from torch_geometric.nn import Linear +from torch_geometric.nn.inits import reset +from torch_geometric.utils import get_embeddings +from torch_geometric.data import Batch, Data +from pygip.models.nn import CAM +from scipy.stats import kendalltau +import math +from scipy.stats import kendalltau +from torch_geometric.data import Batch, Data +from collections import defaultdict + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True) + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" + +def safe_auc(y_true, y_pred): + 
""" + Safely compute AUC, avoiding NaN values. + + Args: + y_true: True labels + y_pred: Predicted probabilities + + Returns: + float: AUC value, returns 0.5 if data does not meet computation conditions + """ + y_true = np.array(y_true) + y_pred = np.array(y_pred) + + if len(np.unique(y_true)) == 1: + return 0.5 + + if len(np.unique(y_pred)) == 1: + return 0.5 + + return roc_auc_score(y_true, y_pred) + +def train_loop_target_model(model, dataloader, optimizer, device, args): + """Train the model for one epoch, returning average loss, accuracy, and AUC.""" + model.train() + total_loss = 0 + correct = 0 + total = 0 + all_labels = [] + all_probs = [] + criterion = torch.nn.CrossEntropyLoss() + for data in dataloader: + data = data.to(device) + optimizer.zero_grad() + + if args.explanation_mode in ['GNNExplainer', 'PGExplainer']: + out = model(data.x, data.edge_index, data.batch) + else: + logits, out = model(data.x, data.edge_index, data.batch) + loss = criterion(out, data.y) + loss.backward() + optimizer.step() + total_loss += loss.item() * data.num_graphs + + # Calculate accuracy + pred = out.argmax(dim=1) + correct += pred.eq(data.y).sum().item() + total += data.num_graphs + + all_labels.extend(data.y.cpu().numpy()) + probs = F.softmax(out, dim=1).detach().cpu().numpy() + all_probs.extend(probs) + + avg_loss = total_loss / len(dataloader.dataset) + accuracy = correct / total + + binary_labels = np.array(all_labels) + binary_probs = np.array([prob[1] for prob in all_probs]) + auc = safe_auc(binary_labels, binary_probs) + + return avg_loss, accuracy, auc + + +def evaluate_loop_target_model(model, dataloader, device, args): + """Evaluate the model on the validation or test set, returning average loss, accuracy, and AUC.""" + model.eval() + total_loss = 0 + correct = 0 + total = 0 + all_labels = [] + all_probs = [] + criterion = torch.nn.CrossEntropyLoss() + with torch.no_grad(): + for data in dataloader: + data = data.to(device) + if args.explanation_mode in ['GNNExplainer','PGExplainer']: + out = model(data.x, data.edge_index, data.batch) + else: + _, out = model(data.x, data.edge_index, data.batch) + loss = criterion(out, data.y) + total_loss += loss.item() * data.num_graphs + + pred = out.argmax(dim=1) + correct += pred.eq(data.y).sum().item() + total += data.num_graphs + + all_labels.extend(data.y.cpu().numpy()) + probs = F.softmax(out, dim=1).detach().cpu().numpy() + all_probs.extend(probs) + + avg_loss = total_loss / len(dataloader.dataset) + accuracy = correct / total + + binary_labels = np.array(all_labels) + binary_probs = np.array([prob[1] for prob in all_probs]) + auc = safe_auc(binary_labels, binary_probs) + + return avg_loss, accuracy, auc + +def convert_edge_scores_to_node_scores(edge_mask, edge_index, num_nodes): + """ + Convert edge importance scores to node importance scores. + + Parameters: + - edge_mask (Tensor): Edge importance scores, shape [num_edges]. + - edge_index (Tensor): Edge connections, shape [2, num_edges], indicating the two nodes connected by each edge. + - num_nodes (int): Number of nodes in the graph. + + Returns: + - node_scores (Tensor): Node importance scores, shape [num_nodes]. 
+ """ + # Initialize node importance scores tensor + node_scores = torch.zeros(num_nodes, device=edge_mask.device) # shape: [num_nodes] + + # Initialize node degrees + node_degrees = torch.zeros(num_nodes, device=edge_mask.device) # shape: [num_nodes] + + # Iterate through each edge to calculate the contribution of edge importance to node importance + for i in range(edge_index.shape[1]): + node1, node2 = edge_index[:, i] # Get the two nodes connected by the edge + importance = edge_mask[i] # Get the importance score of the edge + + node_scores[node1] += importance + node_scores[node2] += importance + + node_degrees[node1] += 1 + node_degrees[node2] += 1 + + node_degrees[node_degrees == 0] = 1 # Avoid division by zero + + # Calculate node importance, normalized by node degree + node_scores = node_scores / node_degrees + + return node_scores + +def custom_collate(batch): + """ + Custom collate function to batch PyTorch Geometric Data objects. + + Args: + batch: List[Data] + + Returns: + Batch object + """ + return Batch.from_data_list(batch) + + +def process_query_dataset(query_dataset): + """ + Process query_dataset uniformly, setting the target model prediction as target_pred, + and retaining the original label y and node_mask. + + Parameters: + - query_dataset: List[dict], containing 'original_data', 'pred', and 'node_mask' fields + + Returns: + - processed_data_list: List[Data], each Data object contains original features, y, target_pred, and node_mask + """ + processed_data_list = [] + for sample in query_dataset: + original_data = sample['original_data'] + pred = sample['pred'] + node_mask = sample['node_mask'] + + # Ensure pred is an integer + if isinstance(pred, torch.Tensor): + pred = pred.item() + elif isinstance(pred, (list, np.ndarray)): + pred = pred[0] + + # Create new Data object + new_data = Data( + x=original_data.x, + edge_index=original_data.edge_index, + edge_attr=getattr(original_data, 'edge_attr', None), + y=original_data.y, + target_pred=torch.tensor(pred, dtype=torch.long, device=original_data.x.device), + node_mask=node_mask + ) + + if hasattr(original_data, 'batch'): + new_data.batch = original_data.batch + + processed_data_list.append(new_data) + return processed_data_list + +def augment(dataset, augmentor, augmentation_ratio, operation_ratio=0.1, augmentation_type='combined'): + """ + Generate augmented samples based on the chosen augmentation strategy. 
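As a rough, self-contained sketch of the label-balanced allocation used below (rarer classes receive proportionally more augmented samples; names here are illustrative only, not part of the patch):

    from collections import Counter

    def allocate_augmentations(labels, budget):
        counts = Counter(labels)                        # e.g. {0: 80, 1: 20}
        inv = {c: 1.0 / n for c, n in counts.items()}   # inverse frequency per class
        total = sum(inv.values())
        alloc = {c: int(budget * w / total) for c, w in inv.items()}
        remainder = budget - sum(alloc.values())
        for i, c in enumerate(alloc):                   # hand out any rounding remainder
            if i < remainder:
                alloc[c] += 1
        return alloc

    # allocate_augmentations([0] * 80 + [1] * 20, budget=20) -> {0: 4, 1: 16}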
+ + Parameters: + - dataset: Original training dataset (list of Data objects) + - augmentor: DataAugmentor instance + - augmentation_ratio: Augmentation ratio (e.g., 0.2 means generating augmented samples equal to 20% of original data) + - operation_ratio: Ratio for adding or removing operations + - augmentation_type: Augmentation type, 'drop_node', 'drop_edge', 'add_edge', or 'combined' + + Returns: + - augmented_data_list: List of augmented Data objects + """ + augmented_data_list = [] + num_original_samples = len(dataset) + num_augmented_samples = int(num_original_samples * augmentation_ratio) + + if num_augmented_samples == 0: + return augmented_data_list + + # Group samples by label + label_to_samples = defaultdict(list) + for sample in dataset: + label = sample.y.item() + label_to_samples[label].append(sample) + + # Calculate inverse frequency and normalize + label_weights = {} + total_inverse = 0 + for label, samples in label_to_samples.items(): + Ni = len(samples) + if Ni == 0: + continue + inverse_freq = 1.0 / Ni + label_weights[label] = inverse_freq + total_inverse += inverse_freq + + for label in label_weights: + label_weights[label] /= total_inverse + + # Calculate number of augmented samples per label + num_augmented_samples_per_label = {} + for label, weight in label_weights.items(): + num_aug = int(num_augmented_samples * weight) + num_augmented_samples_per_label[label] = num_aug + + # Allocate remaining samples + remaining = num_augmented_samples - sum(num_augmented_samples_per_label.values()) + labels = list(label_weights.keys()) + for i in range(remaining): + label = labels[i % len(labels)] + num_augmented_samples_per_label[label] += 1 + + # Generate augmented samples + for label, num_aug in num_augmented_samples_per_label.items(): + samples = label_to_samples[label] + if len(samples) < 1: + print(f"Label {label} has insufficient samples to generate augmented samples (requires at least 1 sample).") + continue + + for _ in range(num_aug): + try: + sample = random.choice(samples) + + if augmentation_type == 'drop_node': + augmented_data = augmentor.drop_node(sample, drop_ratio=operation_ratio) + elif augmentation_type == 'drop_edge': + augmented_data = augmentor.drop_edge(sample, drop_ratio=operation_ratio) + elif augmentation_type == 'add_edge': + augmented_data = augmentor.add_edge(sample, add_ratio=operation_ratio) + elif augmentation_type == 'combined': + augmented_data = augmentor.combined_augmentation( + sample, + drop_node_ratio=operation_ratio, + drop_edge_ratio=operation_ratio, + add_edge_ratio=operation_ratio + ) + else: + print(f"Unknown augmentation type: {augmentation_type}") + continue + + if augmented_data is None: + continue + + # Skip single-node graphs + if augmented_data.x.size(0) <= 1: + continue + + augmented_data_list.append(augmented_data) + + except Exception as e: + print(f"Error generating augmented sample for label {label}: {e}") + continue + + return augmented_data_list + + +def train(model, dataloader, optimizer, device, align_weight=1.0, criterion=None, ranknet_loss_fn=None): + """ + Training function, using target_pred as labels and incorporating RankNet loss. + If align_weight is 0, skip the calculation of RankNet loss. 
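The total objective below is loss_pred + align_weight * ranknet_loss: match the target model's predicted labels while aligning the surrogate's CAM node ranking with the target's explanation scores.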
+ """ + model.train() + total_loss_pred = 0.0 + total_ranknet_loss = 0.0 + total_samples = 0 + + ex = CAM(model) + + for batch_samples in dataloader: + optimizer.zero_grad() + + all_data = batch_samples.to(device) + + node_emb, out_surr = model(all_data.x, all_data.edge_index, all_data.batch) + + loss_pred = criterion(out_surr, all_data.target_pred) + + # Calculate CAM scores + cam_scores = ex.get_cam_scores(all_data.target_pred, all_data.batch) + + node_masks = all_data.node_mask + + batch_ids = all_data.batch # [total_num_nodes] + + ranknet_loss = torch.tensor(0.0, device=device) + if align_weight != 0: + ranknet_loss = ranknet_loss_fn(cam_scores, node_masks, batch_ids) + + # Total loss + total_batch_loss = loss_pred + align_weight * ranknet_loss + + total_batch_loss.backward() + optimizer.step() + + batch_size = len(batch_samples) + total_loss_pred += loss_pred.item() * batch_size + total_ranknet_loss += ranknet_loss.item() * batch_size + total_samples += batch_size + + # Calculate average loss + avg_loss_pred = total_loss_pred / total_samples + avg_ranknet_loss = total_ranknet_loss / total_samples + + return avg_loss_pred, avg_ranknet_loss + + +def eval(model, dataloader, device): + """ + Evaluate the model on the validation set for accuracy and AUC, using target_pred as labels. + + Returns: validation accuracy, validation AUC + """ + model.eval() + total_correct = 0 + total_samples = 0 + all_targets = [] + all_probs = [] + + with torch.no_grad(): + for batch_samples in dataloader: + batch = batch_samples.to(device) + target_preds_tensor = batch.target_pred + + node_emb, out_surr = model(batch.x, batch.edge_index, batch.batch) + pred = out_surr.argmax(dim=1) + + total_correct += pred.eq(target_preds_tensor).sum().item() + total_samples += len(batch_samples) + + all_targets.extend(target_preds_tensor.cpu().numpy()) + if out_surr.size(1) > 1: + all_probs.extend(F.softmax(out_surr, dim=1)[:, 1].cpu().numpy()) + else: + all_probs.extend(torch.sigmoid(out_surr).cpu().numpy()) + + accuracy = total_correct / total_samples + + try: + auc = safe_auc(all_targets, all_probs) + except ValueError: + auc = float('nan') + + return accuracy, auc + + +def calculate_rank_correlation(pred_scores, true_scores, batch_ids): + correlations = [] + for b in torch.unique(batch_ids): + mask = (batch_ids == b) + p = pred_scores[mask].cpu().numpy() + t = true_scores[mask].cpu().numpy() + corr, _ = kendalltau(p, t) + if not np.isnan(corr): + correlations.append(corr) + return np.mean(correlations) + + +def calculate_order_accuracy(pred_scores, true_scores, batch_ids): + """ + Parameters: + - pred_scores: [total_num_nodes], predicted node importance scores + - true_scores: [total_num_nodes], true node importance scores + - batch_ids: [total_num_nodes], graph index for each node + + Returns: + - order_accuracy: Mean order accuracy across all graphs + """ + unique_batch = torch.unique(batch_ids) + per_graph_accuracies = [] + + for b in unique_batch: + mask = (batch_ids == b) + p = pred_scores[mask] + t = true_scores[mask] + num_nodes = p.size(0) + + if num_nodes < 2: + continue + + # Generate all possible node pairs (i, j) where i < j + indices_i, indices_j = torch.triu_indices(num_nodes, num_nodes, offset=1, device=p.device) + + s_i = t[indices_i] + s_j = t[indices_j] + + p_i = p[indices_i] + p_j = p[indices_j] + + # Calculate true relation + # 1 for s_i > s_j, -1 for s_i < s_j, 0 for s_i == s_j + true_relation = torch.where(s_i > s_j, torch.ones_like(s_i), + torch.where(s_i < s_j, torch.ones_like(s_i) * -1, 
torch.zeros_like(s_i))) + + pred_relation = torch.where(p_i > p_j, torch.ones_like(p_i), + torch.where(p_i < p_j, torch.ones_like(p_i) * -1, torch.zeros_like(p_i))) + + correct = true_relation.eq(pred_relation).float() + correct_total = correct.sum().item() + total = correct.numel() + + if total > 0: + graph_accuracy = correct_total / total + per_graph_accuracies.append(graph_accuracy) + + if len(per_graph_accuracies) == 0: + return float('nan') + + mean_accuracy = sum(per_graph_accuracies) / len(per_graph_accuracies) + + return mean_accuracy + + +def test(model, dataloader, device): + """ + Evaluate the model on the test set for accuracy, AUC, fidelity, order accuracy, and rank correlation. + """ + model.eval() + correct = 0 + total = 0 + all_targets = [] + all_probs = [] + all_predictions = [] + all_target_preds = [] + all_pred_scores = [] + all_true_scores = [] + all_batch_ids = [] + graph_offset = 0 + + ex = CAM(model) + + with torch.no_grad(): + for batch_samples in dataloader: + batch = batch_samples.to(device) + true_labels_tensor = batch.y + target_preds_tensor = batch.target_pred + + node_emb, out_surr = model(batch.x, batch.edge_index, batch.batch) + pred = out_surr.argmax(dim=1) + + correct += pred.eq(true_labels_tensor).sum().item() + total += len(batch_samples) + + all_predictions.append(pred) + all_target_preds.append(target_preds_tensor) + + all_targets.extend(true_labels_tensor.cpu().numpy()) + if out_surr.size(1) > 1: + all_probs.extend(F.softmax(out_surr, dim=1)[:, 1].cpu().numpy()) + else: + all_probs.extend(torch.sigmoid(out_surr).cpu().numpy()) + + cam_scores = ex.get_cam_scores(true_labels_tensor, batch.batch) + all_pred_scores.append(cam_scores.cpu()) + all_true_scores.append(batch.node_mask.cpu()) + all_batch_ids.append(batch.batch.cpu() + graph_offset) + + graph_offset += batch.num_graphs + + # Calculate accuracy + accuracy = correct / total + + # Calculate AUC + try: + auc = safe_auc(all_targets, all_probs) + except ValueError: + auc = float('nan') + + # Calculate fidelity + all_predictions = torch.cat(all_predictions) + all_target_preds = torch.cat(all_target_preds) + total_fidelity = all_target_preds.size(0) + fidelity_score = float('nan') if total_fidelity == 0 else \ + all_predictions.eq(all_target_preds).sum().item() / total_fidelity + + # Calculate rank correlation and order accuracy + all_pred_scores = torch.cat(all_pred_scores) + all_true_scores = torch.cat(all_true_scores) + all_batch_ids = torch.cat(all_batch_ids) + order_accuracy = calculate_order_accuracy(all_pred_scores, all_true_scores, all_batch_ids) + rank_correlation = calculate_rank_correlation(all_pred_scores, all_true_scores, all_batch_ids) + + return accuracy, auc, fidelity_score, order_accuracy, rank_correlation + +class RankNetLoss(nn.Module): + def __init__(self): + super(RankNetLoss, self).__init__() + pass + + def forward(self, pred_scores, true_scores, batch_ids): + """ + Compute RankNet loss for batched data. 
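A rough, self-contained sketch of the pairwise objective computed here (sigmoid of score differences trained against 1 / 0.5 / 0 ordering labels; values below are illustrative only, not part of the patch):

    import torch
    import torch.nn.functional as F

    p = torch.tensor([0.9, 0.1, 0.4])   # surrogate (CAM) scores for one graph's nodes
    t = torch.tensor([1.0, 0.0, 0.0])   # target explanation scores (node_mask)
    i, j = torch.triu_indices(3, 3, offset=1)
    y = 0.5 * torch.ones(i.numel())     # 0.5 where the target ranks a pair as tied
    y[t[i] > t[j]] = 1.0
    y[t[i] < t[j]] = 0.0
    loss = F.binary_cross_entropy(torch.sigmoid(p[i] - p[j]), y)
    # the loss shrinks as the surrogate's node ranking matches the target's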
+ + Parameters: + - pred_scores: [total_num_nodes], predicted node importance scores (CAM) + - true_scores: [total_num_nodes], true node importance scores (node_mask) + - batch_ids: [total_num_nodes], graph index for each node + + Returns: + - loss: Average RankNet loss + """ + unique_batch = torch.unique(batch_ids) + total_loss = 0.0 + count = 0 + + for b in unique_batch: + mask = (batch_ids == b) + p = pred_scores[mask] + t = true_scores[mask] + num_nodes = p.size(0) + + if num_nodes < 2: + continue + + indices_i, indices_j = torch.triu_indices(num_nodes, num_nodes, offset=1, device=p.device) + + s_i = t[indices_i] + s_j = t[indices_j] + + p_i = p[indices_i] + p_j = p[indices_j] + + # Label y_ij = 1 if s_i > s_j, else 0, if s_i == s_j then y_ij = 0.5 + y_ij = torch.zeros_like(s_i, dtype=torch.float, device=p.device) + y_ij[s_i > s_j] = 1.0 + y_ij[s_i == s_j] = 0.5 + + surr_diff = p_i - p_j + sigmoid_diff = torch.sigmoid(surr_diff) + loss = F.binary_cross_entropy(sigmoid_diff, y_ij, reduction='mean') + + total_loss += loss + count += 1 + + if count == 0: + return torch.tensor(0.0, device=pred_scores.device, requires_grad=True) + + return total_loss / count + +class DataAugmentor: + def __init__(self): + pass + + def drop_node(self, sample, drop_ratio=0.1) -> Optional[Data]: + """ + Drop nodes with the lowest importance and their associated edges. + + Args: + sample: Input graph data (torch_geometric.data.Data) + drop_ratio: Proportion of nodes to drop + + Returns: + Augmented graph data, or None if invalid + """ + node_mask = sample.node_mask if hasattr(sample, 'node_mask') else torch.ones(sample.num_nodes, dtype=torch.float, device=sample.x.device) + + num_nodes = sample.num_nodes + num_drop = int(drop_ratio * num_nodes) + if num_drop == 0 or num_drop >= num_nodes: + return None + + _, drop_indices = torch.topk(node_mask, k=num_drop, largest=False) + + keep_mask = torch.ones(num_nodes, dtype=torch.bool, device=sample.x.device) + keep_mask[drop_indices] = False + keep_indices = torch.nonzero(keep_mask, as_tuple=False).squeeze() + + new_edge_index, edge_attr, edge_mask = subgraph( + keep_indices, + sample.edge_index, + edge_attr=sample.edge_attr if hasattr(sample, 'edge_attr') else None, + relabel_nodes=True, + num_nodes=num_nodes, + return_edge_mask=True + ) + + new_x = sample.x[keep_indices] + new_node_mask = node_mask[keep_indices] + + # Create augmented graph data + augmented_data = Data( + x=new_x, + edge_index=new_edge_index, + edge_attr=edge_attr, + y=sample.y, + target_pred=sample.target_pred if hasattr(sample, 'target_pred') else None, + node_mask=new_node_mask + ) + + return augmented_data + + def drop_edge(self, sample, drop_ratio=0.1) -> Optional[Data]: + """ + Drop edges between low-importance nodes. 
+ + Args: + sample: Input graph data (torch_geometric.data.Data) + drop_ratio: Proportion of nodes considered low-importance + + Returns: + Augmented graph data, or None if invalid + """ + node_mask = sample.node_mask if hasattr(sample, 'node_mask') else torch.ones(sample.num_nodes, dtype=torch.float, device=sample.x.device) + + num_nodes = sample.num_nodes + num_low_importance = int(drop_ratio * num_nodes) + if num_low_importance < 2: + return None + + _, indices = torch.topk(node_mask, k=num_low_importance, largest=False) + low_importance_nodes = set(indices.cpu().tolist()) + src_nodes = sample.edge_index[0] + dst_nodes = sample.edge_index[1] + + edge_mask = torch.ones(sample.edge_index.size(1), dtype=torch.bool, device=sample.x.device) + for node in low_importance_nodes: + mask = (src_nodes == node) | (dst_nodes == node) + edge_mask = edge_mask & ~mask + + new_edge_index = sample.edge_index[:, edge_mask] + edge_attr = sample.edge_attr[edge_mask] if hasattr(sample, 'edge_attr') and sample.edge_attr is not None else None + + # Create augmented graph data + augmented_data = Data( + x=sample.x, + edge_index=new_edge_index, + edge_attr=edge_attr, + y=sample.y, + target_pred=sample.target_pred if hasattr(sample, 'target_pred') else None, + node_mask=node_mask + ) + + return augmented_data + + def add_edge(self, sample, add_ratio=0.1) -> Optional[Data]: + """ + Add edges between low-importance nodes. + + Args: + sample: Input graph data (torch_geometric.data.Data) + add_ratio: Proportion of nodes considered low-importance + + Returns: + Augmented graph data, or None if invalid + """ + node_mask = sample.node_mask if hasattr(sample, 'node_mask') else torch.ones(sample.num_nodes, dtype=torch.float, device=sample.x.device) + num_nodes = sample.num_nodes + num_low_importance = int(add_ratio * num_nodes) + if num_low_importance < 2: + return None + + _, indices = torch.topk(node_mask, k=num_low_importance, largest=False) + low_importance_nodes = indices.tolist() + + new_edges = [] + for i in range(len(low_importance_nodes)): + for j in range(i + 1, len(low_importance_nodes)): + u, v = low_importance_nodes[i], low_importance_nodes[j] + new_edges.append([u, v]) + new_edges.append([v, u]) + + if not new_edges: + return None + + new_edges_tensor = torch.tensor(new_edges, dtype=torch.long, device=sample.x.device).t() + new_edge_index = torch.cat([sample.edge_index, new_edges_tensor], dim=1) + + # Remove duplicate edges + unique_edges, unique_idx = torch.unique(new_edge_index.t(), dim=0, return_inverse=True) + new_edge_index = unique_edges.t() + + # Handle edge attributes + if hasattr(sample, 'edge_attr') and sample.edge_attr is not None: + num_new_edges = new_edge_index.size(1) - sample.edge_index.size(1) + if num_new_edges > 0: + default_attr = torch.ones((num_new_edges, sample.edge_attr.size(1)), + dtype=sample.edge_attr.dtype, + device=sample.x.device) + edge_attr = torch.cat([sample.edge_attr, default_attr], dim=0) + else: + edge_attr = sample.edge_attr + edge_attr = edge_attr[unique_idx] + else: + edge_attr = None + + # Create augmented graph data + augmented_data = Data( + x=sample.x, + edge_index=new_edge_index, + edge_attr=edge_attr, + y=sample.y, + target_pred=sample.target_pred if hasattr(sample, 'target_pred') else None, + node_mask=node_mask + ) + + return augmented_data + + def combined_augmentation(self, sample, drop_node_ratio=0.1, drop_edge_ratio=0.1, add_edge_ratio=0.1) -> Optional[Data]: + """ + Randomly apply one of the graph augmentation methods (drop node, drop edge, add edge). 
+ + Args: + sample: Input graph data (torch_geometric.data.Data) + drop_node_ratio: Proportion of nodes to drop + drop_edge_ratio: Proportion of nodes for edge dropping + add_edge_ratio: Proportion of nodes for edge addition + + Returns: + Augmented graph data or None (if augmentation fails) + """ + chosen_method = random.choice(['drop_node', 'drop_edge', 'add_edge']) + + try: + if chosen_method == 'drop_node': + return self.drop_node(sample, drop_ratio=drop_node_ratio) + elif chosen_method == 'drop_edge': + return self.drop_edge(sample, drop_ratio=drop_edge_ratio) + elif chosen_method == 'add_edge': + return self.add_edge(sample, add_ratio=add_edge_ratio) + + except Exception as e: + print(f"Error occurred during data augmentation: {str(e)}") + return None + + + + +class PGExplainer(ExplainerAlgorithm): + r"""The PGExplainer model from the `"Parameterized Explainer for Graph + Neural Network" `_ paper. + + Internally, it utilizes a neural network to identify subgraph structures + that play a crucial role in the predictions made by a GNN. + Importantly, the :class:`PGExplainer` needs to be trained via + :meth:`~PGExplainer.train` before being able to generate explanations: + + .. code-block:: python + + explainer = Explainer( + model=model, + algorithm=PGExplainer(epochs=30, lr=0.003), + explanation_type='phenomenon', + edge_mask_type='object', + model_config=ModelConfig(...), + ) + + # Train against a variety of node-level or graph-level predictions: + for epoch in range(30): + for index in [...]: # Indices to train against. + loss = explainer.algorithm.train(epoch, model, x, edge_index, + target=target, index=index) + + # Get the final explanations: + explanation = explainer(x, edge_index, target=target, index=0) + + Args: + epochs (int): The number of epochs to train. + lr (float, optional): The learning rate to apply. + (default: :obj:`0.003`). + **kwargs (optional): Additional hyper-parameters to override default + settings in + :attr:`~torch_geometric.explain.algorithm.PGExplainer.coeffs`. + """ + + coeffs = { + 'edge_size': 0.05, + 'edge_ent': 1.0, + 'temp': [5.0, 2.0], + 'bias': 0.01, + } + + def __init__(self, epochs: int, lr: float = 0.003, **kwargs): + super().__init__() + self.epochs = epochs + self.lr = lr + self.coeffs.update(kwargs) + + self.mlp = Sequential( + Linear(-1, 64), + ReLU(), + Linear(64, 1), + ) + self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=lr) + self._curr_epoch = -1 + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + reset(self.mlp) + + def train( + self, + epoch: int, + model: torch.nn.Module, + x: Tensor, + edge_index: Tensor, + *, + target: Tensor, + index: Optional[Union[int, Tensor]] = None, + **kwargs, + ): + r"""Trains the underlying explainer model. + Needs to be called before being able to make predictions. + + Args: + epoch (int): The current epoch of the training phase. + model (torch.nn.Module): The model to explain. + x (torch.Tensor): The input node features of a + homogeneous graph. + edge_index (torch.Tensor): The input edge indices of a homogeneous + graph. + target (torch.Tensor): The target of the model. + index (int or torch.Tensor, optional): The index of the model + output to explain. Needs to be a single index. + (default: :obj:`None`) + **kwargs (optional): Additional keyword arguments passed to + :obj:`model`. 
+ """ + if isinstance(x, dict) or isinstance(edge_index, dict): + raise ValueError(f"Heterogeneous graphs not yet supported in " + f"'{self.__class__.__name__}'") + + if self.model_config.task_level == ModelTaskLevel.node: + if index is None: + raise ValueError(f"The 'index' argument needs to be provided " + f"in '{self.__class__.__name__}' for " + f"node-level explanations") + if isinstance(index, Tensor) and index.numel() > 1: + raise ValueError(f"Only scalars are supported for the 'index' " + f"argument in '{self.__class__.__name__}'") + + z = get_embeddings(model, x, edge_index, **kwargs)[-1] + + self.optimizer.zero_grad() + temperature = self._get_temperature(epoch) + + inputs = self._get_inputs(z, edge_index, index) + logits = self.mlp(inputs).view(-1) + edge_mask = self._concrete_sample(logits, temperature) + set_masks(model, edge_mask, edge_index, apply_sigmoid=True) + + if self.model_config.task_level == ModelTaskLevel.node: + _, hard_edge_mask = self._get_hard_masks(model, index, edge_index, + num_nodes=x.size(0)) + edge_mask = edge_mask[hard_edge_mask] + + y_hat, y = model(x, edge_index, **kwargs), target + + + if index is not None: + y_hat, y = y_hat[index], y[index] + + loss = self._loss(y_hat, y, edge_mask) + loss.backward() + self.optimizer.step() + + clear_masks(model) + self._curr_epoch = epoch + + return float(loss) + + def forward( + self, + model: torch.nn.Module, + x: Tensor, + edge_index: Tensor, + *, + target: Tensor, + index: Optional[Union[int, Tensor]] = None, + **kwargs, + ) -> Explanation: + if isinstance(x, dict) or isinstance(edge_index, dict): + raise ValueError(f"Heterogeneous graphs not yet supported in " + f"'{self.__class__.__name__}'") + + if self._curr_epoch < self.epochs - 1: # Safety check: + raise ValueError(f"'{self.__class__.__name__}' is not yet fully " + f"trained (got {self._curr_epoch + 1} epochs " + f"from {self.epochs} epochs). 
Please first train " + f"the underlying explainer model by running " + f"`explainer.algorithm.train(...)`.") + + hard_edge_mask = None + if self.model_config.task_level == ModelTaskLevel.node: + if index is None: + raise ValueError(f"The 'index' argument needs to be provided " + f"in '{self.__class__.__name__}' for " + f"node-level explanations") + if isinstance(index, Tensor) and index.numel() > 1: + raise ValueError(f"Only scalars are supported for the 'index' " + f"argument in '{self.__class__.__name__}'") + + # We need to compute hard masks to properly clean up edges and + # nodes attributions not involved during message passing: + _, hard_edge_mask = self._get_hard_masks(model, index, edge_index, + num_nodes=x.size(0)) + + z = get_embeddings(model, x, edge_index, **kwargs)[-1] + + inputs = self._get_inputs(z, edge_index, index) + logits = self.mlp(inputs).view(-1) + + edge_mask = self._post_process_mask(logits, hard_edge_mask, + apply_sigmoid=True) + + return Explanation(edge_mask=edge_mask) + + def supports(self) -> bool: + explanation_type = self.explainer_config.explanation_type + if explanation_type != ExplanationType.phenomenon: + logging.error(f"'{self.__class__.__name__}' only supports " + f"phenomenon explanations " + f"got (`explanation_type={explanation_type.value}`)") + return False + + task_level = self.model_config.task_level + if task_level not in {ModelTaskLevel.node, ModelTaskLevel.graph}: + logging.error(f"'{self.__class__.__name__}' only supports " + f"node-level or graph-level explanations " + f"got (`task_level={task_level.value}`)") + return False + + node_mask_type = self.explainer_config.node_mask_type + if node_mask_type is not None: + logging.error(f"'{self.__class__.__name__}' does not support " + f"explaining input node features " + f"got (`node_mask_type={node_mask_type.value}`)") + return False + + return True + + ########################################################################### + + def _get_inputs(self, embedding: Tensor, edge_index: Tensor, + index: Optional[int] = None) -> Tensor: + zs = [embedding[edge_index[0]], embedding[edge_index[1]]] + if self.model_config.task_level == ModelTaskLevel.node: + assert index is not None + zs.append(embedding[index].view(1, -1).repeat(zs[0].size(0), 1)) + return torch.cat(zs, dim=-1) + + def _get_temperature(self, epoch: int) -> float: + temp = self.coeffs['temp'] + return temp[0] * pow(temp[1] / temp[0], epoch / self.epochs) + + def _concrete_sample(self, logits: Tensor, + temperature: float = 1.0) -> Tensor: + bias = self.coeffs['bias'] + eps = (1 - 2 * bias) * torch.rand_like(logits) + bias + return (eps.log() - (1 - eps).log() + logits) / temperature + + def _loss(self, y_hat: Tensor, y: Tensor, edge_mask: Tensor) -> Tensor: + + if self.model_config.mode == ModelMode.binary_classification: + # loss = self._loss_binary_classification(y_hat, y) + loss_fn = nn.CrossEntropyLoss() + loss = loss_fn(y_hat, y) + elif self.model_config.mode == ModelMode.multiclass_classification: + loss = self._loss_multiclass_classification(y_hat, y) + elif self.model_config.mode == ModelMode.regression: + loss = self._loss_regression(y_hat, y) + + # Regularization loss: + mask = edge_mask.sigmoid() + size_loss = mask.sum() * self.coeffs['edge_size'] + mask = 0.99 * mask + 0.005 + mask_ent = -mask * mask.log() - (1 - mask) * (1 - mask).log() + mask_ent_loss = mask_ent.mean() * self.coeffs['edge_ent'] + + return loss + size_loss + mask_ent_loss \ No newline at end of file diff --git a/pygip/models/nn/__init__.py 
b/pygip/models/nn/__init__.py index d40027c..db41bbf 100644 --- a/pygip/models/nn/__init__.py +++ b/pygip/models/nn/__init__.py @@ -1 +1,2 @@ from .backbones import GCN, GraphSAGE, ShadowNet, AttackNet +from .backbones import SurrogateModelGraphClassification,TargetModelGraphClassification,GCNGraphClassification,GraphSAGEGraphClassification,GATGraphClassification,GINGraphClassification,Classifier,CAM,GradCAM,GradientExplainer diff --git a/pygip/models/nn/backbones.py b/pygip/models/nn/backbones.py index 09daf84..190b649 100644 --- a/pygip/models/nn/backbones.py +++ b/pygip/models/nn/backbones.py @@ -4,6 +4,7 @@ from dgl.nn.pytorch import GraphConv, SAGEConv from torch_geometric.nn import GATConv from torch_geometric.nn import GCNConv +from torch_geometric.nn import GINConv, SAGEConv, global_mean_pool class GCN(nn.Module): @@ -121,3 +122,310 @@ def __init__(self, in_channels, hidden_channels, out_channels): def forward(self, x, edge_index): x = F.relu(self.conv1(x, edge_index)) return self.conv2(x, edge_index) + +## Graph Classification Models +class SurrogateModelGraphClassification(nn.Module): + def __init__(self, encoder, predictor): + super(SurrogateModelGraphClassification, self).__init__() + self.encoder = encoder + self.predictor = predictor + + def forward(self, x, edge_index, batch): + node_embeddings, graph_embeddings = self.encoder(x, edge_index, batch) + out = self.predictor(graph_embeddings) + return node_embeddings, out + + + +class TargetModelGraphClassification(nn.Module): + def __init__(self, encoder, predictor, explanation_mode): + super(TargetModelGraphClassification, self).__init__() + self.encoder = encoder + self.predictor = predictor + self.explanation_mode = explanation_mode + + def forward(self, x, edge_index, batch): + node_embeddings, graph_embeddings = self.encoder(x, edge_index, batch) + out = self.predictor(graph_embeddings) + if self.explanation_mode in ['GNNExplainer','PGExplainer']: + return out + return node_embeddings, out + + + +class GCNGraphClassification(nn.Module): + def __init__(self, input_dim, hidden_dim, num_layers): + super(GCNGraphClassification, self).__init__() + self.convs = nn.ModuleList() + + for i in range(num_layers): + if i == 0: + self.convs.append(GCNConv(input_dim, hidden_dim)) + else: + self.convs.append(GCNConv(hidden_dim, hidden_dim)) + + def forward(self, x, edge_index, batch): + for conv in self.convs: + x = conv(x, edge_index) + x = F.relu(x) + + node_embeddings = x + graph_embeddings = global_mean_pool(x, batch) + + return node_embeddings, graph_embeddings + + + +class GraphSAGEGraphClassification(nn.Module): + def __init__(self, input_dim, hidden_dim, num_layers): + super(GraphSAGEGraphClassification, self).__init__() + self.convs = nn.ModuleList() + + for i in range(num_layers): + if i == 0: + self.convs.append(SAGEConv(input_dim, hidden_dim)) + else: + self.convs.append(SAGEConv(hidden_dim, hidden_dim)) + + def forward(self, x, edge_index, batch): + for conv in self.convs: + x = conv(x, edge_index) + x = F.relu(x) + + node_embeddings = x + graph_embeddings = global_mean_pool(x, batch) + + return node_embeddings, graph_embeddings + + + +class GATGraphClassification(nn.Module): + def __init__(self, input_dim, hidden_dim, num_layers, heads=4): + super(GATGraphClassification, self).__init__() + self.convs = nn.ModuleList() + + for i in range(num_layers): + if i == 0: + self.convs.append(GATConv(input_dim, hidden_dim // heads, heads=heads)) + else: + self.convs.append(GATConv(hidden_dim, hidden_dim // heads, heads=heads)) + + def 
forward(self, x, edge_index, batch): + for conv in self.convs: + x = conv(x, edge_index) + x = F.relu(x) + + node_embeddings = x + graph_embeddings = global_mean_pool(x, batch) + + return node_embeddings, graph_embeddings + + + +class GINGraphClassification(nn.Module): + def __init__(self, input_dim, hidden_dim, num_layers): + super(GINGraphClassification, self).__init__() + self.convs = nn.ModuleList() + self.bns = nn.ModuleList() + + for i in range(num_layers): + if i == 0: + nn_seq = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim) + ) + else: + nn_seq = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim) + ) + self.convs.append(GINConv(nn_seq)) + self.bns.append(nn.BatchNorm1d(hidden_dim)) + + def forward(self, x, edge_index, batch): + for conv, bn in zip(self.convs, self.bns): + x = conv(x, edge_index) + x = bn(x) + x = F.relu(x) + + node_embeddings = x + graph_embeddings = global_mean_pool(x, batch) + + return node_embeddings, graph_embeddings + + + +class Classifier(nn.Module): + def __init__(self, input_dim, output_dim): + super(Classifier, self).__init__() + self.fc = nn.Linear(input_dim, output_dim) + + def forward(self, x): + return self.fc(x) + + + +class CAM: + def __init__(self, model): + self.model = model + self.activations = None + self.classifier_weights = None + self._register_hooks() + + def _register_hooks(self): + def forward_hook(module, input, output): + self.activations = output + + classifier = self.model.predictor + + if isinstance(classifier, nn.Linear): + self.classifier_weights = classifier.weight # [num_classes, hidden_dim] + elif hasattr(classifier, 'fc') and isinstance(classifier.fc, nn.Linear): + self.classifier_weights = classifier.fc.weight # [num_classes, hidden_dim] + else: + raise ValueError("Classifier should be an instance of nn.Linear or have a linear layer named 'fc'") + + last_conv = self.model.encoder.convs[-1] + last_conv.register_forward_hook(forward_hook) + + def get_cam_scores(self, target_classes, batch_ids): + """ + Generate CAM scores based on captured activations and classifier weights. + + Parameters: + - target_classes: [batch_size] tensor, each element is a class index + - batch_ids: [total_num_nodes] tensor, indicating the graph index for each node + + Returns: + - cam_scores: list, each element is the CAM score for the corresponding graph [num_nodes_in_graph] + """ + if self.activations is None: + raise ValueError("No activations recorded. 
Ensure a forward pass has been done before generating CAM.") + + cam_scores = [] + + num_graphs = batch_ids.max().item() + 1 + for graph_id in range(num_graphs): + cls = target_classes[graph_id].item() + weight = self.classifier_weights[cls] # [hidden_dim] w^c_k + + # Get node indices belonging to the current graph + node_indices = (batch_ids == graph_id).nonzero(as_tuple=False).squeeze() + if node_indices.numel() == 0: + cam = torch.tensor([], device=self.activations.device) + else: + activation = self.activations[node_indices] # [num_nodes_in_graph, hidden_dim] F^l_k(X, A) = σ(V F^(l-1)(X, A)W^l_k) + cam = torch.matmul(activation, weight) # [num_nodes_in_graph] L^c_CAM[n] = ReLU(∑_k w^c_k F^L_{k,n}(X, A)) + + if cam.dim() == 0: # single node graph + cam = cam.unsqueeze(0) + + cam_scores.append(cam) + + + cam_scores = torch.cat(cam_scores, dim=0) + + return cam_scores + + + +class GradCAM: + def __init__(self, model): + self.model = model + self.activations = None + self._register_hooks() + + def _register_hooks(self): + def forward_hook(module, input, output): + self.activations = output + + last_conv = self.model.encoder.convs[-1] + last_conv.register_forward_hook(forward_hook) + + def get_gradcam_scores(self, input, target_class): + self.model.zero_grad() + + node_embeddings, output = self.model(input.x, input.edge_index, input.batch) + + batch_size = len(target_class) + scores = output[range(batch_size), target_class] + + gradients = torch.autograd.grad( + scores, + self.activations, + grad_outputs=torch.ones_like(scores), + retain_graph=True + )[0] + + if self.activations is None: + raise ValueError("Activations were not captured") + + weights = self.compute_weights(gradients, input.batch) + + # Calculate Grad-CAM scores + batch = input.batch + num_graphs = batch.max().item() + 1 + gradcam_scores = [] + + for graph_id in range(num_graphs): + mask = batch == graph_id + if mask.any(): + curr_activations = self.activations[mask] # [num_nodes, hidden_dim] + curr_weights = weights[graph_id] # [hidden_dim] + + gradcam = torch.matmul(curr_activations, curr_weights) # [num_nodes] + gradcam = F.relu(gradcam) + + gradcam_scores.append(gradcam) + + return torch.cat(gradcam_scores) + + def compute_weights(self, gradients, batch): + num_graphs = batch.max().item() + 1 + weights = [] + + for graph_id in range(num_graphs): + mask = batch == graph_id + graph_grads = gradients[mask] + alpha = graph_grads.mean(dim=0) + weights.append(alpha) + + return torch.stack(weights) + + + +class GradientExplainer: + def __init__(self, model): + self.model = model + + def get_gradient_scores(self, input_data, target_classes): + input_data.x.requires_grad = True + self.model.zero_grad() + + # Forward + _, output = self.model(input_data.x, input_data.edge_index, input_data.batch) + + # Calculate gradients for each graph + batch = input_data.batch + num_graphs = batch.max().item() + 1 + normalized_scores = [] + + # Calculate gradients for each graph's target class + for graph_id in range(num_graphs): + mask = batch == graph_id + target_class = target_classes[graph_id] + score = output[graph_id, target_class] + + grad = torch.autograd.grad(score, input_data.x, + retain_graph=True, + create_graph=False)[0] # [num_nodes, feature_dim] + + curr_grad = grad[mask] + relu_grad = F.relu(curr_grad) + scores = torch.norm(relu_grad, p=2, dim=1) + + normalized_scores.append(scores) + + return torch.cat(normalized_scores) From 89fad8561ad7b0799b07e77948c9f10e803143e3 Mon Sep 17 00:00:00 2001 From: Bachina-Pranav Date: Thu, 2 Oct 
2025 20:00:48 +0530 Subject: [PATCH 2/2] moved graph classification models to EGSteal folder --- pygip/models/attack/EGSteal/EGSteal.py | 2 +- pygip/models/attack/EGSteal/models.py | 314 +++++++++++++++++++++++++ pygip/models/nn/backbones.py | 309 +----------------------- 3 files changed, 316 insertions(+), 309 deletions(-) create mode 100644 pygip/models/attack/EGSteal/models.py diff --git a/pygip/models/attack/EGSteal/EGSteal.py b/pygip/models/attack/EGSteal/EGSteal.py index fdec560..7f23773 100644 --- a/pygip/models/attack/EGSteal/EGSteal.py +++ b/pygip/models/attack/EGSteal/EGSteal.py @@ -6,7 +6,7 @@ from torch.utils.data import random_split from .utils import * from pygip.models.attack.base import BaseAttack -from pygip.models.nn import SurrogateModelGraphClassification,TargetModelGraphClassification,GCNGraphClassification,GraphSAGEGraphClassification,GATGraphClassification,GINGraphClassification,Classifier,CAM,GradCAM,GradientExplainer +from .models import SurrogateModelGraphClassification,TargetModelGraphClassification,GCNGraphClassification,GraphSAGEGraphClassification,GATGraphClassification,GINGraphClassification,Classifier,CAM,GradCAM,GradientExplainer import numpy as np from torch_geometric.explain import Explainer, GNNExplainer import os.path as osp diff --git a/pygip/models/attack/EGSteal/models.py b/pygip/models/attack/EGSteal/models.py new file mode 100644 index 0000000..d78d205 --- /dev/null +++ b/pygip/models/attack/EGSteal/models.py @@ -0,0 +1,314 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from dgl.nn.pytorch import GraphConv, SAGEConv +from torch_geometric.nn import GATConv +from torch_geometric.nn import GCNConv +from torch_geometric.nn import GINConv, SAGEConv, global_mean_pool + +## Graph Classification Models +class SurrogateModelGraphClassification(nn.Module): + def __init__(self, encoder, predictor): + super(SurrogateModelGraphClassification, self).__init__() + self.encoder = encoder + self.predictor = predictor + + def forward(self, x, edge_index, batch): + node_embeddings, graph_embeddings = self.encoder(x, edge_index, batch) + out = self.predictor(graph_embeddings) + return node_embeddings, out + + + +class TargetModelGraphClassification(nn.Module): + def __init__(self, encoder, predictor, explanation_mode): + super(TargetModelGraphClassification, self).__init__() + self.encoder = encoder + self.predictor = predictor + self.explanation_mode = explanation_mode + + def forward(self, x, edge_index, batch): + node_embeddings, graph_embeddings = self.encoder(x, edge_index, batch) + out = self.predictor(graph_embeddings) + if self.explanation_mode in ['GNNExplainer','PGExplainer']: + return out + return node_embeddings, out + + + +class GCNGraphClassification(nn.Module): + def __init__(self, input_dim, hidden_dim, num_layers): + super(GCNGraphClassification, self).__init__() + self.convs = nn.ModuleList() + + for i in range(num_layers): + if i == 0: + self.convs.append(GCNConv(input_dim, hidden_dim)) + else: + self.convs.append(GCNConv(hidden_dim, hidden_dim)) + + def forward(self, x, edge_index, batch): + for conv in self.convs: + x = conv(x, edge_index) + x = F.relu(x) + + node_embeddings = x + graph_embeddings = global_mean_pool(x, batch) + + return node_embeddings, graph_embeddings + + + +class GraphSAGEGraphClassification(nn.Module): + def __init__(self, input_dim, hidden_dim, num_layers): + super(GraphSAGEGraphClassification, self).__init__() + self.convs = nn.ModuleList() + + for i in range(num_layers): + if i == 0: + 
self.convs.append(SAGEConv(input_dim, hidden_dim)) + else: + self.convs.append(SAGEConv(hidden_dim, hidden_dim)) + + def forward(self, x, edge_index, batch): + for conv in self.convs: + x = conv(x, edge_index) + x = F.relu(x) + + node_embeddings = x + graph_embeddings = global_mean_pool(x, batch) + + return node_embeddings, graph_embeddings + + + +class GATGraphClassification(nn.Module): + def __init__(self, input_dim, hidden_dim, num_layers, heads=4): + super(GATGraphClassification, self).__init__() + self.convs = nn.ModuleList() + + for i in range(num_layers): + if i == 0: + self.convs.append(GATConv(input_dim, hidden_dim // heads, heads=heads)) + else: + self.convs.append(GATConv(hidden_dim, hidden_dim // heads, heads=heads)) + + def forward(self, x, edge_index, batch): + for conv in self.convs: + x = conv(x, edge_index) + x = F.relu(x) + + node_embeddings = x + graph_embeddings = global_mean_pool(x, batch) + + return node_embeddings, graph_embeddings + + + +class GINGraphClassification(nn.Module): + def __init__(self, input_dim, hidden_dim, num_layers): + super(GINGraphClassification, self).__init__() + self.convs = nn.ModuleList() + self.bns = nn.ModuleList() + + for i in range(num_layers): + if i == 0: + nn_seq = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim) + ) + else: + nn_seq = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim) + ) + self.convs.append(GINConv(nn_seq)) + self.bns.append(nn.BatchNorm1d(hidden_dim)) + + def forward(self, x, edge_index, batch): + for conv, bn in zip(self.convs, self.bns): + x = conv(x, edge_index) + x = bn(x) + x = F.relu(x) + + node_embeddings = x + graph_embeddings = global_mean_pool(x, batch) + + return node_embeddings, graph_embeddings + + + +class Classifier(nn.Module): + def __init__(self, input_dim, output_dim): + super(Classifier, self).__init__() + self.fc = nn.Linear(input_dim, output_dim) + + def forward(self, x): + return self.fc(x) + + + +class CAM: + def __init__(self, model): + self.model = model + self.activations = None + self.classifier_weights = None + self._register_hooks() + + def _register_hooks(self): + def forward_hook(module, input, output): + self.activations = output + + classifier = self.model.predictor + + if isinstance(classifier, nn.Linear): + self.classifier_weights = classifier.weight # [num_classes, hidden_dim] + elif hasattr(classifier, 'fc') and isinstance(classifier.fc, nn.Linear): + self.classifier_weights = classifier.fc.weight # [num_classes, hidden_dim] + else: + raise ValueError("Classifier should be an instance of nn.Linear or have a linear layer named 'fc'") + + last_conv = self.model.encoder.convs[-1] + last_conv.register_forward_hook(forward_hook) + + def get_cam_scores(self, target_classes, batch_ids): + """ + Generate CAM scores based on captured activations and classifier weights. + + Parameters: + - target_classes: [batch_size] tensor, each element is a class index + - batch_ids: [total_num_nodes] tensor, indicating the graph index for each node + + Returns: + - cam_scores: list, each element is the CAM score for the corresponding graph [num_nodes_in_graph] + """ + if self.activations is None: + raise ValueError("No activations recorded. 
Ensure a forward pass has been done before generating CAM.") + + cam_scores = [] + + num_graphs = batch_ids.max().item() + 1 + for graph_id in range(num_graphs): + cls = target_classes[graph_id].item() + weight = self.classifier_weights[cls] # [hidden_dim] w^c_k + + # Get node indices belonging to the current graph + node_indices = (batch_ids == graph_id).nonzero(as_tuple=False).squeeze() + if node_indices.numel() == 0: + cam = torch.tensor([], device=self.activations.device) + else: + activation = self.activations[node_indices] # [num_nodes_in_graph, hidden_dim] F^l_k(X, A) = σ(V F^(l-1)(X, A)W^l_k) + cam = torch.matmul(activation, weight) # [num_nodes_in_graph] L^c_CAM[n] = ReLU(∑_k w^c_k F^L_{k,n}(X, A)) + + if cam.dim() == 0: # single node graph + cam = cam.unsqueeze(0) + + cam_scores.append(cam) + + + cam_scores = torch.cat(cam_scores, dim=0) + + return cam_scores + + + +class GradCAM: + def __init__(self, model): + self.model = model + self.activations = None + self._register_hooks() + + def _register_hooks(self): + def forward_hook(module, input, output): + self.activations = output + + last_conv = self.model.encoder.convs[-1] + last_conv.register_forward_hook(forward_hook) + + def get_gradcam_scores(self, input, target_class): + self.model.zero_grad() + + node_embeddings, output = self.model(input.x, input.edge_index, input.batch) + + batch_size = len(target_class) + scores = output[range(batch_size), target_class] + + gradients = torch.autograd.grad( + scores, + self.activations, + grad_outputs=torch.ones_like(scores), + retain_graph=True + )[0] + + if self.activations is None: + raise ValueError("Activations were not captured") + + weights = self.compute_weights(gradients, input.batch) + + # Calculate Grad-CAM scores + batch = input.batch + num_graphs = batch.max().item() + 1 + gradcam_scores = [] + + for graph_id in range(num_graphs): + mask = batch == graph_id + if mask.any(): + curr_activations = self.activations[mask] # [num_nodes, hidden_dim] + curr_weights = weights[graph_id] # [hidden_dim] + + gradcam = torch.matmul(curr_activations, curr_weights) # [num_nodes] + gradcam = F.relu(gradcam) + + gradcam_scores.append(gradcam) + + return torch.cat(gradcam_scores) + + def compute_weights(self, gradients, batch): + num_graphs = batch.max().item() + 1 + weights = [] + + for graph_id in range(num_graphs): + mask = batch == graph_id + graph_grads = gradients[mask] + alpha = graph_grads.mean(dim=0) + weights.append(alpha) + + return torch.stack(weights) + + + +class GradientExplainer: + def __init__(self, model): + self.model = model + + def get_gradient_scores(self, input_data, target_classes): + input_data.x.requires_grad = True + self.model.zero_grad() + + # Forward + _, output = self.model(input_data.x, input_data.edge_index, input_data.batch) + + # Calculate gradients for each graph + batch = input_data.batch + num_graphs = batch.max().item() + 1 + normalized_scores = [] + + # Calculate gradients for each graph's target class + for graph_id in range(num_graphs): + mask = batch == graph_id + target_class = target_classes[graph_id] + score = output[graph_id, target_class] + + grad = torch.autograd.grad(score, input_data.x, + retain_graph=True, + create_graph=False)[0] # [num_nodes, feature_dim] + + curr_grad = grad[mask] + relu_grad = F.relu(curr_grad) + scores = torch.norm(relu_grad, p=2, dim=1) + + normalized_scores.append(scores) + + return torch.cat(normalized_scores) diff --git a/pygip/models/nn/backbones.py b/pygip/models/nn/backbones.py index 190b649..f174740 100644 --- 
a/pygip/models/nn/backbones.py +++ b/pygip/models/nn/backbones.py @@ -121,311 +121,4 @@ def __init__(self, in_channels, hidden_channels, out_channels): def forward(self, x, edge_index): x = F.relu(self.conv1(x, edge_index)) - return self.conv2(x, edge_index) - -## Graph Classification Models -class SurrogateModelGraphClassification(nn.Module): - def __init__(self, encoder, predictor): - super(SurrogateModelGraphClassification, self).__init__() - self.encoder = encoder - self.predictor = predictor - - def forward(self, x, edge_index, batch): - node_embeddings, graph_embeddings = self.encoder(x, edge_index, batch) - out = self.predictor(graph_embeddings) - return node_embeddings, out - - - -class TargetModelGraphClassification(nn.Module): - def __init__(self, encoder, predictor, explanation_mode): - super(TargetModelGraphClassification, self).__init__() - self.encoder = encoder - self.predictor = predictor - self.explanation_mode = explanation_mode - - def forward(self, x, edge_index, batch): - node_embeddings, graph_embeddings = self.encoder(x, edge_index, batch) - out = self.predictor(graph_embeddings) - if self.explanation_mode in ['GNNExplainer','PGExplainer']: - return out - return node_embeddings, out - - - -class GCNGraphClassification(nn.Module): - def __init__(self, input_dim, hidden_dim, num_layers): - super(GCNGraphClassification, self).__init__() - self.convs = nn.ModuleList() - - for i in range(num_layers): - if i == 0: - self.convs.append(GCNConv(input_dim, hidden_dim)) - else: - self.convs.append(GCNConv(hidden_dim, hidden_dim)) - - def forward(self, x, edge_index, batch): - for conv in self.convs: - x = conv(x, edge_index) - x = F.relu(x) - - node_embeddings = x - graph_embeddings = global_mean_pool(x, batch) - - return node_embeddings, graph_embeddings - - - -class GraphSAGEGraphClassification(nn.Module): - def __init__(self, input_dim, hidden_dim, num_layers): - super(GraphSAGEGraphClassification, self).__init__() - self.convs = nn.ModuleList() - - for i in range(num_layers): - if i == 0: - self.convs.append(SAGEConv(input_dim, hidden_dim)) - else: - self.convs.append(SAGEConv(hidden_dim, hidden_dim)) - - def forward(self, x, edge_index, batch): - for conv in self.convs: - x = conv(x, edge_index) - x = F.relu(x) - - node_embeddings = x - graph_embeddings = global_mean_pool(x, batch) - - return node_embeddings, graph_embeddings - - - -class GATGraphClassification(nn.Module): - def __init__(self, input_dim, hidden_dim, num_layers, heads=4): - super(GATGraphClassification, self).__init__() - self.convs = nn.ModuleList() - - for i in range(num_layers): - if i == 0: - self.convs.append(GATConv(input_dim, hidden_dim // heads, heads=heads)) - else: - self.convs.append(GATConv(hidden_dim, hidden_dim // heads, heads=heads)) - - def forward(self, x, edge_index, batch): - for conv in self.convs: - x = conv(x, edge_index) - x = F.relu(x) - - node_embeddings = x - graph_embeddings = global_mean_pool(x, batch) - - return node_embeddings, graph_embeddings - - - -class GINGraphClassification(nn.Module): - def __init__(self, input_dim, hidden_dim, num_layers): - super(GINGraphClassification, self).__init__() - self.convs = nn.ModuleList() - self.bns = nn.ModuleList() - - for i in range(num_layers): - if i == 0: - nn_seq = nn.Sequential( - nn.Linear(input_dim, hidden_dim), - nn.ReLU(), - nn.Linear(hidden_dim, hidden_dim) - ) - else: - nn_seq = nn.Sequential( - nn.Linear(hidden_dim, hidden_dim), - nn.ReLU(), - nn.Linear(hidden_dim, hidden_dim) - ) - self.convs.append(GINConv(nn_seq)) - 
self.bns.append(nn.BatchNorm1d(hidden_dim)) - - def forward(self, x, edge_index, batch): - for conv, bn in zip(self.convs, self.bns): - x = conv(x, edge_index) - x = bn(x) - x = F.relu(x) - - node_embeddings = x - graph_embeddings = global_mean_pool(x, batch) - - return node_embeddings, graph_embeddings - - - -class Classifier(nn.Module): - def __init__(self, input_dim, output_dim): - super(Classifier, self).__init__() - self.fc = nn.Linear(input_dim, output_dim) - - def forward(self, x): - return self.fc(x) - - - -class CAM: - def __init__(self, model): - self.model = model - self.activations = None - self.classifier_weights = None - self._register_hooks() - - def _register_hooks(self): - def forward_hook(module, input, output): - self.activations = output - - classifier = self.model.predictor - - if isinstance(classifier, nn.Linear): - self.classifier_weights = classifier.weight # [num_classes, hidden_dim] - elif hasattr(classifier, 'fc') and isinstance(classifier.fc, nn.Linear): - self.classifier_weights = classifier.fc.weight # [num_classes, hidden_dim] - else: - raise ValueError("Classifier should be an instance of nn.Linear or have a linear layer named 'fc'") - - last_conv = self.model.encoder.convs[-1] - last_conv.register_forward_hook(forward_hook) - - def get_cam_scores(self, target_classes, batch_ids): - """ - Generate CAM scores based on captured activations and classifier weights. - - Parameters: - - target_classes: [batch_size] tensor, each element is a class index - - batch_ids: [total_num_nodes] tensor, indicating the graph index for each node - - Returns: - - cam_scores: list, each element is the CAM score for the corresponding graph [num_nodes_in_graph] - """ - if self.activations is None: - raise ValueError("No activations recorded. 
Ensure a forward pass has been done before generating CAM.") - - cam_scores = [] - - num_graphs = batch_ids.max().item() + 1 - for graph_id in range(num_graphs): - cls = target_classes[graph_id].item() - weight = self.classifier_weights[cls] # [hidden_dim] w^c_k - - # Get node indices belonging to the current graph - node_indices = (batch_ids == graph_id).nonzero(as_tuple=False).squeeze() - if node_indices.numel() == 0: - cam = torch.tensor([], device=self.activations.device) - else: - activation = self.activations[node_indices] # [num_nodes_in_graph, hidden_dim] F^l_k(X, A) = σ(V F^(l-1)(X, A)W^l_k) - cam = torch.matmul(activation, weight) # [num_nodes_in_graph] L^c_CAM[n] = ReLU(∑_k w^c_k F^L_{k,n}(X, A)) - - if cam.dim() == 0: # single node graph - cam = cam.unsqueeze(0) - - cam_scores.append(cam) - - - cam_scores = torch.cat(cam_scores, dim=0) - - return cam_scores - - - -class GradCAM: - def __init__(self, model): - self.model = model - self.activations = None - self._register_hooks() - - def _register_hooks(self): - def forward_hook(module, input, output): - self.activations = output - - last_conv = self.model.encoder.convs[-1] - last_conv.register_forward_hook(forward_hook) - - def get_gradcam_scores(self, input, target_class): - self.model.zero_grad() - - node_embeddings, output = self.model(input.x, input.edge_index, input.batch) - - batch_size = len(target_class) - scores = output[range(batch_size), target_class] - - gradients = torch.autograd.grad( - scores, - self.activations, - grad_outputs=torch.ones_like(scores), - retain_graph=True - )[0] - - if self.activations is None: - raise ValueError("Activations were not captured") - - weights = self.compute_weights(gradients, input.batch) - - # Calculate Grad-CAM scores - batch = input.batch - num_graphs = batch.max().item() + 1 - gradcam_scores = [] - - for graph_id in range(num_graphs): - mask = batch == graph_id - if mask.any(): - curr_activations = self.activations[mask] # [num_nodes, hidden_dim] - curr_weights = weights[graph_id] # [hidden_dim] - - gradcam = torch.matmul(curr_activations, curr_weights) # [num_nodes] - gradcam = F.relu(gradcam) - - gradcam_scores.append(gradcam) - - return torch.cat(gradcam_scores) - - def compute_weights(self, gradients, batch): - num_graphs = batch.max().item() + 1 - weights = [] - - for graph_id in range(num_graphs): - mask = batch == graph_id - graph_grads = gradients[mask] - alpha = graph_grads.mean(dim=0) - weights.append(alpha) - - return torch.stack(weights) - - - -class GradientExplainer: - def __init__(self, model): - self.model = model - - def get_gradient_scores(self, input_data, target_classes): - input_data.x.requires_grad = True - self.model.zero_grad() - - # Forward - _, output = self.model(input_data.x, input_data.edge_index, input_data.batch) - - # Calculate gradients for each graph - batch = input_data.batch - num_graphs = batch.max().item() + 1 - normalized_scores = [] - - # Calculate gradients for each graph's target class - for graph_id in range(num_graphs): - mask = batch == graph_id - target_class = target_classes[graph_id] - score = output[graph_id, target_class] - - grad = torch.autograd.grad(score, input_data.x, - retain_graph=True, - create_graph=False)[0] # [num_nodes, feature_dim] - - curr_grad = grad[mask] - relu_grad = F.relu(curr_grad) - scores = torch.norm(relu_grad, p=2, dim=1) - - normalized_scores.append(scores) - - return torch.cat(normalized_scores) + return self.conv2(x, edge_index) \ No newline at end of file
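
For context on the second commit, the following is a minimal, hypothetical usage sketch of the relocated graph-classification classes (GINGraphClassification, Classifier, TargetModelGraphClassification, CAM) from pygip/models/attack/EGSteal/models.py. The hidden size, class count, and toy graphs below are illustrative placeholders, not values taken from this patch.

import torch
from torch_geometric.data import Data, Batch

# Module path introduced by this patch; dimensions are assumptions for illustration.
from pygip.models.attack.EGSteal.models import (
    GINGraphClassification, Classifier, TargetModelGraphClassification, CAM,
)

num_features, hidden_dim, num_classes = 14, 128, 2

encoder = GINGraphClassification(input_dim=num_features, hidden_dim=hidden_dim, num_layers=3)
predictor = Classifier(input_dim=hidden_dim, output_dim=num_classes)
target_model = TargetModelGraphClassification(encoder, predictor, explanation_mode='CAM')

# Two tiny random graphs batched together.
g1 = Data(x=torch.randn(4, num_features), edge_index=torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))
g2 = Data(x=torch.randn(3, num_features), edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]]))
batch = Batch.from_data_list([g1, g2])

# CAM hooks the last GIN layer, so it must be constructed before the forward pass
# that populates its activations.
cam = CAM(target_model)
_, logits = target_model(batch.x, batch.edge_index, batch.batch)  # explanation_mode='CAM' keeps the (node_emb, logits) output
pred_classes = logits.argmax(dim=1)  # [num_graphs]

# Per-node importance scores, concatenated over all nodes in the batch.
node_scores = cam.get_cam_scores(pred_classes, batch.batch)
print(node_scores.shape)  # torch.Size([7]) for the 4 + 3 nodes above

Because Classifier exposes its linear layer as .fc, CAM picks up the class weights through its hasattr(classifier, 'fc') branch; with explanation_mode set to 'GNNExplainer' or 'PGExplainer' the target model would instead return only the logits, as required by torch_geometric's Explainer interface.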