From fea4b34eebbf9e503de784763efa0a2ac48bacc6 Mon Sep 17 00:00:00 2001
From: Gabriella Munoz
Date: Tue, 13 Jan 2026 19:55:09 -0500
Subject: [PATCH] feat(attack): integrate EGSteal into PyGIP (PyG + TU datasets)

---
 .gitignore                                |  18 +
 examples/egsteal_attack.py                |  18 +
 pygip/datasets/__init__.py                |   1 +
 pygip/datasets/datasets.py                |  86 +++-
 pygip/models/attack/__init__.py           |  60 +--
 pygip/models/attack/egsteal/__init__.py   |   4 +
 pygip/models/attack/egsteal/attack.py     | 550 +++++++++++++++++++++++
 pygip/models/attack/egsteal/eg_models.py  | 259 +++++++++++
 pygip/models/attack/egsteal/eg_utils.py   | 345 ++++++++++++++
 9 files changed, 1301 insertions(+), 40 deletions(-)
 create mode 100644 examples/egsteal_attack.py
 create mode 100644 pygip/models/attack/egsteal/__init__.py
 create mode 100644 pygip/models/attack/egsteal/attack.py
 create mode 100644 pygip/models/attack/egsteal/eg_models.py
 create mode 100644 pygip/models/attack/egsteal/eg_utils.py

diff --git a/.gitignore b/.gitignore
index e0afa39..ea92871 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,3 +18,12 @@ dist/
 
 #virtual environments folder
 .venv
+
+# Local datasets / artifacts (do not commit)
+data/
+*.pt
+*.pth
+*.zip
+
+# accidental local files
+2.2.0
diff --git a/examples/egsteal_attack.py b/examples/egsteal_attack.py
new file mode 100644
index 0000000..579ebc5
--- /dev/null
+++ b/examples/egsteal_attack.py
@@ -0,0 +1,18 @@
+from pygip.datasets import TUGraph
+from pygip.models.attack import EGStealAttack
+
+dataset = TUGraph(name="NCI109", api_type="pyg")
+
+config = {
+    "gnn_backbone": "GIN",
+    "gnn_layers": 3,
+    "hidden_dim": 128,
+    "epochs": 5,  # set to 200 for a full run
+    "batch_size": 64,
+    "explanation_mode": "CAM",
+    "align_weight": 1.0,
+}
+
+attack = EGStealAttack(dataset, config=config)
+print(attack.attack())
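+
+# Expected output: the metrics dict returned by EGStealAttack.attack(), with
+# keys "test_acc", "test_auc", "fidelity_score", and "rank_correlation".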
diff --git a/pygip/datasets/__init__.py b/pygip/datasets/__init__.py
index 4decc35..204540a 100644
--- a/pygip/datasets/__init__.py
+++ b/pygip/datasets/__init__.py
@@ -7,6 +7,7 @@
     Photo,
     CoauthorCS,
     CoauthorPhysics,
+    TUGraph,
 )
 
 __all__ = [
diff --git a/pygip/datasets/datasets.py b/pygip/datasets/datasets.py
index a5776e3..52dfe99 100644
--- a/pygip/datasets/datasets.py
+++ b/pygip/datasets/datasets.py
@@ -1,17 +1,33 @@
-import dgl
+# --- Optional DGL imports (only required when api_type == "dgl") ---
+try:
+    import dgl  # type: ignore
+    from dgl import DGLGraph  # type: ignore
+
+    # DGL datasets (graph classification / node classification)
+    from dgl.data import (  # type: ignore
+        AmazonCoBuyComputerDataset,  # Amazon-Computer
+        AmazonCoBuyPhotoDataset,  # Amazon-Photo
+        CoauthorCSDataset,  # Coauthor-CS
+        CoauthorPhysicsDataset,  # Coauthor-Physics
+        CoraGraphDataset,
+        CiteseerGraphDataset,
+        PubmedGraphDataset,
+    )
+except ImportError:
+    dgl = None
+    DGLGraph = None
+
+    AmazonCoBuyComputerDataset = None
+    AmazonCoBuyPhotoDataset = None
+    CoauthorCSDataset = None
+    CoauthorPhysicsDataset = None
+    CoraGraphDataset = None
+    CiteseerGraphDataset = None
+    PubmedGraphDataset = None
+# ---------------------------------------------------------------
 import numpy as np
 import torch
-from dgl import DGLGraph
-from dgl.data import AmazonCoBuyComputerDataset  # Amazon-Computer
-from dgl.data import AmazonCoBuyPhotoDataset  # Amazon-Photo
-from dgl.data import CoauthorCSDataset, CoauthorPhysicsDataset
-from dgl.data import FakeNewsDataset
-from dgl.data import FlickrDataset
-from dgl.data import GINDataset
-from dgl.data import MUTAGDataset
-from dgl.data import RedditDataset
-from dgl.data import YelpDataset
-from dgl.data import citation_graph  # Cora, CiteSeer, PubMed
+
 from sklearn.model_selection import StratifiedShuffleSplit
 from torch_geometric.data import Data as PyGData
 from torch_geometric.datasets import Amazon  # Amazon Computers, Photo
@@ -62,6 +78,13 @@ class Dataset(object):
     def __init__(self, api_type='dgl', path='./data'):
         assert api_type in {'dgl', 'pyg'}, 'API type must be dgl or pyg'
         self.api_type = api_type
+        if self.api_type == "dgl" and dgl is None:
+            raise ImportError(
+                "DGL is not installed, but api_type='dgl' was requested. "
+                "Install DGL (or run on a platform that supports DGL wheels) "
+                "or use api_type='pyg'."
+            )
+
         self.path = path
         self.dataset_name = self.get_name()
 
@@ -260,6 +283,34 @@ def __repr__(self):
                 f"#Nodes={self.num_nodes}, #Features={self.num_features}, "
                 f"#Classes={self.num_classes})")
 
+class TUGraph(Dataset):
+    """
+    PyG wrapper for TU graph classification datasets (e.g., NCI109, AIDS, Mutagenicity).
+    These are graph-level classification tasks, so we populate graph_dataset and leave
+    graph_data unused.
+    """
+    def __init__(self, name: str, api_type: str = "pyg", path: str = "./data"):
+        self.name = name
+        super().__init__(api_type=api_type, path=path)
+
+    def get_name(self):
+        return self.name
+
+    def load_dgl_data(self):
+        raise ImportError("TUGraph only supports api_type='pyg' (DGL is not required).")
+
+    def load_pyg_data(self):
+        # torch_geometric.datasets.TUDataset is already imported at the top of this module
+        self.graph_dataset = TUDataset(root=self.path, name=self.name)
+        self.graph_data = None  # graph classification datasets are list-like
+
+    def _load_meta_data(self):
+        # Override: the base _load_meta_data assumes a single PyGData in self.graph_data.
+        # For TU datasets, metadata comes from the dataset object instead.
+        self.num_nodes = 0  # varies per graph
+        self.num_features = self.graph_dataset.num_features
+        self.num_classes = self.graph_dataset.num_classes
+
+
 class Cora(Dataset):
 
     def __init__(self, api_type='dgl', path='./data'):
diff --git a/pygip/models/attack/__init__.py b/pygip/models/attack/__init__.py
index 8cbeb84..5b2a1a5 100644
--- a/pygip/models/attack/__init__.py
+++ b/pygip/models/attack/__init__.py
@@ -1,31 +1,59 @@
-from .AdvMEA import AdvMEA
-from .CEGA import CEGA
-from .DataFreeMEA import (
-    DFEATypeI,
-    DFEATypeII,
-    DFEATypeIII
-)
-from .mea.MEA import (
-    ModelExtractionAttack0,
-    ModelExtractionAttack1,
-    ModelExtractionAttack2,
-    ModelExtractionAttack3,
-    ModelExtractionAttack4,
-    ModelExtractionAttack5
-)
-from .Realistic import RealisticAttack
+"""
+Attack module exports.
+
+Some attacks depend on DGL, and DGL wheels may be unavailable on some platforms
+(e.g., macOS). We import DGL-dependent attacks conditionally so that PyG-only
+workflows still work.
+"""
+
+from .egsteal import EGStealAttack
+
+# Everything else stays behind the DGL gate below to avoid import-time crashes;
+# attacks that turn out to be PyG-only can be moved above it later.
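+#
+# A consumer can check for a missing optional attack at runtime, e.g.:
+#
+#   from pygip.models.attack import AdvMEA
+#   if AdvMEA is None:
+#       raise RuntimeError("AdvMEA requires DGL; install dgl or pick a PyG-only attack")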
+
+try:
+    import dgl  # noqa: F401
+
+    # Import all attacks that require DGL here:
+    from .AdvMEA import AdvMEA
+    from .CEGA import CEGA
+    from .DataFreeMEA import DFEATypeI, DFEATypeII, DFEATypeIII
+    from .mea.MEA import (
+        ModelExtractionAttack0,
+        ModelExtractionAttack1,
+        ModelExtractionAttack2,
+        ModelExtractionAttack3,
+        ModelExtractionAttack4,
+        ModelExtractionAttack5,
+    )
+    from .Realistic import RealisticAttack
+
+except ImportError:
+    AdvMEA = None
+    CEGA = None
+    DFEATypeI = DFEATypeII = DFEATypeIII = None
+    ModelExtractionAttack0 = ModelExtractionAttack1 = ModelExtractionAttack2 = None
+    ModelExtractionAttack3 = ModelExtractionAttack4 = ModelExtractionAttack5 = None
+    RealisticAttack = None
 
 __all__ = [
-    'AdvMEA',
-    'CEGA',
-    'RealisticAttack',
-    'DFEATypeI',
-    'DFEATypeII',
-    'DFEATypeIII',
-    'ModelExtractionAttack0',
-    'ModelExtractionAttack1',
-    'ModelExtractionAttack2',
-    'ModelExtractionAttack3',
-    'ModelExtractionAttack4',
-    'ModelExtractionAttack5',
+    "EGStealAttack",
+    "AdvMEA",
+    "CEGA",
+    "RealisticAttack",
+    "DFEATypeI",
+    "DFEATypeII",
+    "DFEATypeIII",
+    "ModelExtractionAttack0",
+    "ModelExtractionAttack1",
+    "ModelExtractionAttack2",
+    "ModelExtractionAttack3",
+    "ModelExtractionAttack4",
+    "ModelExtractionAttack5",
 ]
diff --git a/pygip/models/attack/egsteal/__init__.py b/pygip/models/attack/egsteal/__init__.py
new file mode 100644
index 0000000..6c00d14
--- /dev/null
+++ b/pygip/models/attack/egsteal/__init__.py
@@ -0,0 +1,4 @@
+from .attack import EGStealAttack
+
+__all__ = ["EGStealAttack"]
+
diff --git a/pygip/models/attack/egsteal/attack.py b/pygip/models/attack/egsteal/attack.py
new file mode 100644
index 0000000..7d4478f
--- /dev/null
+++ b/pygip/models/attack/egsteal/attack.py
@@ -0,0 +1,550 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+import math
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.optim import Adam
+from torch.utils.data import random_split
+from torch_geometric.data import Batch, Data
+from torch_geometric.loader import DataLoader
+from torch_geometric.explain import Explainer, GNNExplainer
+from torch_geometric.explain.config import ModelConfig, ModelMode, ModelTaskLevel
+
+from pygip.models.attack.base import BaseAttack
+
+from .eg_models import (
+    CAM,
+    GAT,
+    GCN,
+    GIN,
+    GraphSAGE,
+    Classifier,
+    GradientExplainer,
+    GradCAM,
+    SurrogateModel,
+    TargetModel,
+)
+from .eg_utils import DataAugmentor, PGExplainer, RankNetLoss, safe_auc, set_seed
+
+
+def custom_collate(batch: List[Data]) -> Batch:
+    return Batch.from_data_list(batch)
+
+
+def process_query_dataset(query_dataset: List[dict]) -> List[Data]:
+    """Flatten query records into PyG Data objects carrying the target's outputs."""
+    processed = []
+    for sample in query_dataset:
+        original_data = sample["original_data"]
+        pred = sample["pred"]
+        node_mask = sample["node_mask"]
+
+        if isinstance(pred, torch.Tensor):
+            pred = pred.item()
+        elif isinstance(pred, (list, np.ndarray)):
+            pred = pred[0]
+
+        new_data = Data(
+            x=original_data.x,
+            edge_index=original_data.edge_index,
+            edge_attr=getattr(original_data, "edge_attr", None),
+            y=original_data.y,
+            target_pred=torch.tensor(pred, dtype=torch.long),
+            node_mask=node_mask,
+        )
+        processed.append(new_data)
+    return processed
+
+
+def convert_edge_scores_to_node_scores(edge_mask: torch.Tensor, edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
+    """Average each node's incident edge-importance scores into a node score."""
+    node_scores = torch.zeros(num_nodes, device=edge_mask.device)
+    node_degrees = torch.zeros(num_nodes, device=edge_mask.device)
+
+    for i in range(edge_index.shape[1]):
+        n1, n2 = edge_index[:, i]
+        imp = edge_mask[i]
+        node_scores[n1] += imp
+        node_scores[n2] += imp
+        node_degrees[n1] += 1
+        node_degrees[n2] += 1
+
+    node_degrees[node_degrees == 0] = 1
+    return node_scores / node_degrees
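+
+
+# A vectorized equivalent of the loop above (untested sketch, same inputs) could
+# use index_add_ instead of Python-level iteration over the edges:
+#
+#   node_scores = torch.zeros(num_nodes, device=edge_mask.device)
+#   node_degrees = torch.zeros(num_nodes, device=edge_mask.device)
+#   for row in (edge_index[0], edge_index[1]):
+#       node_scores.index_add_(0, row, edge_mask)
+#       node_degrees.index_add_(0, row, torch.ones_like(edge_mask))
+#   return node_scores / node_degrees.clamp(min=1)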
+
+
+@dataclass
+class EGStealConfig:
+    seed: int = 43
+    gnn_backbone: str = "GIN"  # GIN / GCN / GAT / GraphSAGE
+    gnn_layers: int = 3
+    hidden_dim: int = 128
+    gat_heads: int = 4
+
+    explanation_mode: str = "CAM"  # CAM / GradCAM / Grad / GNNExplainer / PGExplainer
+    gnnexplainer_epochs: int = 100
+    pgexplainer_epochs: int = 30
+
+    epochs: int = 50  # start smaller; set to 200 for a full run
+    lr: float = 1e-3
+    batch_size: int = 64
+
+    # query/training split ratios (defaults follow the original data-preparation script)
+    target_ratio: float = 0.4
+    target_val_ratio: float = 0.2
+    test_ratio: float = 0.2
+    shadow_ratio: float = 0.4
+
+    # surrogate alignment + augmentation (defaults follow the original surrogate script)
+    align_weight: float = 1.0
+    augmentation_ratio: float = 0.2
+    operation_ratio: float = 0.05
+    augmentation_type: str = "combined"  # drop_node / drop_edge / add_edge / combined
+
+
+class EGStealAttack(BaseAttack):
+    supported_api_types = {"pyg"}
+    supported_datasets = set()  # leave empty unless you want to restrict
+
+    def __init__(
+        self,
+        dataset,
+        attack_node_fraction: Optional[float] = None,
+        model_path: Optional[str] = None,
+        device: Optional[str] = None,
+        config: Optional[dict] = None,
+    ):
+        super().__init__(dataset, attack_node_fraction=attack_node_fraction, model_path=model_path, device=device)
+        self.cfg = EGStealConfig(**(config or {}))
+        set_seed(self.cfg.seed)
+
+    # -----------------------
+    # Public API
+    # -----------------------
+    def attack(self) -> Dict[str, float]:
+        full_dataset = self._get_graph_dataset_list()
+        target_train, target_val, test_ds, shadow_ds = self._split_dataset(full_dataset)
+
+        target_model = self._train_target_model(target_train, target_val)
+        query_dataset_shadow = self._query_target_model(target_model, shadow_ds)
+        query_dataset_test = self._query_target_model(target_model, test_ds)
+
+        processed_shadow = process_query_dataset(query_dataset_shadow)
+        processed_test = process_query_dataset(query_dataset_test)
+
+        surrogate_model = self._train_attack_model(processed_shadow, processed_test)
+
+        # Evaluate
+        test_acc, test_auc, fidelity, rank_corr = self._evaluate_surrogate(surrogate_model, processed_test, target_model)
+
+        return {
+            "test_acc": float(test_acc),
+            "test_auc": float(test_auc),
+            "fidelity_score": float(fidelity),
+            "rank_correlation": float(rank_corr),
+        }
+
+    # -----------------------
+    # Dataset helpers
+    # -----------------------
+    def _get_graph_dataset_list(self):
+        # PyGIP's Dataset stores data differently depending on dataset type.
+        # For TU-style graph classification, graph_dataset is a list-like dataset.
+        if getattr(self.dataset, "graph_dataset", None) is not None:
+            return self.dataset.graph_dataset
+        if getattr(self.dataset, "graph_data", None) is not None:
+            # If it's a single Data object, wrap it
+            gd = self.dataset.graph_data
+            return [gd] if isinstance(gd, Data) else gd
+        raise ValueError("Could not find dataset.graph_dataset or dataset.graph_data")
+
+    def _split_dataset(self, dataset):
+        n = len(dataset)
+        target_num = int(n * self.cfg.target_ratio)
+        test_num = int(n * self.cfg.test_ratio)
+        shadow_num = n - target_num - test_num
+
+        target_ds, test_ds, shadow_ds = random_split(dataset, [target_num, test_num, shadow_num])
+
+        target_train_num = int(target_num * (1 - self.cfg.target_val_ratio))
+        target_val_num = target_num - target_train_num
+        target_train, target_val = random_split(target_ds, [target_train_num, target_val_num])
+
+        return target_train, target_val, test_ds, shadow_ds
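+
+    # Worked example of the default split: with n = 1000 graphs, target_ratio=0.4
+    # and test_ratio=0.2 give 400 target / 200 test / 400 shadow graphs;
+    # target_val_ratio=0.2 then splits the target pool into 320 train / 80 val.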
+
+    # -----------------------
+    # Model builders
+    # -----------------------
+    def _build_encoder(self, input_dim: int):
+        if self.cfg.gnn_backbone == "GIN":
+            return GIN(input_dim=input_dim, hidden_dim=self.cfg.hidden_dim, num_layers=self.cfg.gnn_layers)
+        if self.cfg.gnn_backbone == "GCN":
+            return GCN(input_dim=input_dim, hidden_dim=self.cfg.hidden_dim, num_layers=self.cfg.gnn_layers)
+        if self.cfg.gnn_backbone == "GAT":
+            return GAT(input_dim=input_dim, hidden_dim=self.cfg.hidden_dim, num_layers=self.cfg.gnn_layers, heads=self.cfg.gat_heads)
+        if self.cfg.gnn_backbone == "GraphSAGE":
+            return GraphSAGE(input_dim=input_dim, hidden_dim=self.cfg.hidden_dim, num_layers=self.cfg.gnn_layers)
+        raise ValueError(f"Unknown backbone {self.cfg.gnn_backbone}")
+
+    def _build_classifier(self, num_classes: int):
+        return Classifier(input_dim=self.cfg.hidden_dim, output_dim=num_classes)
+
+    # -----------------------
+    # Target model training
+    # -----------------------
+    def _train_target_model(self, train_ds, val_ds):
+        device = self.device
+        num_features = self.num_features
+        num_classes = self.num_classes
+
+        encoder = self._build_encoder(num_features).to(device)
+        predictor = self._build_classifier(num_classes).to(device)
+        model = TargetModel(encoder=encoder, predictor=predictor, explanation_mode=self.cfg.explanation_mode).to(device)
+
+        opt = Adam(model.parameters(), lr=self.cfg.lr)
+        criterion = torch.nn.CrossEntropyLoss()
+
+        train_loader = DataLoader(train_ds, batch_size=self.cfg.batch_size, shuffle=True)
+        val_loader = DataLoader(val_ds, batch_size=self.cfg.batch_size, shuffle=False)
+
+        best_auc = -math.inf
+        best_state = None
+
+        for _ in range(self.cfg.epochs):
+            model.train()
+            for batch in train_loader:
+                batch = batch.to(device)
+                if batch.x is None:
+                    batch.x = torch.ones((batch.num_nodes, 1), device=device)
+
+                opt.zero_grad()
+                if self.cfg.explanation_mode in ["GNNExplainer", "PGExplainer"]:
+                    out = model(batch.x, batch.edge_index, batch.batch)
+                else:
+                    _, out = model(batch.x, batch.edge_index, batch.batch)
+                loss = criterion(out, batch.y)
+                loss.backward()
+                opt.step()
+
+            # keep the checkpoint with the best validation AUC
+            acc, auc = self._eval_classifier(model, val_loader)
+            if not math.isnan(auc) and auc >= best_auc:
+                best_auc = auc
+                best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
+
+        if best_state is not None:
+            model.load_state_dict(best_state)
+
+        return model
+
+    def _eval_classifier(self, model, loader):
+        device = self.device
+        model.eval()
+        y_true, y_prob, y_pred = [], [], []
+
+        with torch.no_grad():
+            for batch in loader:
+                batch = batch.to(device)
+                if batch.x is None:
+                    batch.x = torch.ones((batch.num_nodes, 1), device=device)
+
+                if self.cfg.explanation_mode in ["GNNExplainer", "PGExplainer"]:
+                    out = model(batch.x, batch.edge_index, batch.batch)
+                else:
+                    _, out = model(batch.x, batch.edge_index, batch.batch)
+
+                probs = F.softmax(out, dim=1)
+                pred = out.argmax(dim=1)
+
+                y_true.extend(batch.y.detach().cpu().tolist())
+                y_pred.extend(pred.detach().cpu().tolist())
+
+                # binary AUC uses the probability of class 1
+                if probs.size(1) == 2:
+                    y_prob.extend(probs[:, 1].detach().cpu().tolist())
+                else:
+                    # fallback: max prob (not ideal for true multiclass AUC)
+                    y_prob.extend(probs.max(dim=1).values.detach().cpu().tolist())
+
+        acc = float(np.mean(np.array(y_pred) == np.array(y_true))) if len(y_true) else 0.0
+        auc = float(safe_auc(y_true, y_prob)) if len(y_true) else 0.5
+        return acc, auc
+
+    # -----------------------
+    # Querying + explanations
+    # -----------------------
+    def _query_target_model(self, target_model, dataset_split) -> List[dict]:
+        device = self.device
+        target_model.eval()
+        loader = DataLoader(dataset_split, batch_size=min(self.cfg.batch_size, 256), shuffle=False)
+
+        results = []
+
+        # set up explainers
+        cam = CAM(target_model) if self.cfg.explanation_mode == "CAM" else None
+        gradcam = GradCAM(target_model) if self.cfg.explanation_mode == "GradCAM" else None
+        grad = GradientExplainer(target_model) if self.cfg.explanation_mode == "Grad" else None
+
+        gnnexplainer = None
+        pgexplainer = None
+        if self.cfg.explanation_mode == "GNNExplainer":
+            gnnexplainer = Explainer(
+                model=target_model,
+                algorithm=GNNExplainer(epochs=self.cfg.gnnexplainer_epochs),
+                explanation_type="phenomenon",
+                node_mask_type="attributes",
+                edge_mask_type="object",
+                model_config=ModelConfig(
+                    mode=ModelMode.multiclass_classification,
+                    task_level=ModelTaskLevel.graph,
+                    return_type="raw",
+                ),
+            )
+
+        if self.cfg.explanation_mode == "PGExplainer":
+            pgexplainer = Explainer(
+                model=target_model,
+                algorithm=PGExplainer(epochs=self.cfg.pgexplainer_epochs),
+                explanation_type="phenomenon",
+                edge_mask_type="object",
+                model_config=ModelConfig(
+                    mode=ModelMode.multiclass_classification,
+                    task_level=ModelTaskLevel.graph,
+                    return_type="raw",
+                ),
+            )
+            pgexplainer.algorithm = pgexplainer.algorithm.to(device)
+
+        for batch_data in loader:
+            batch_data = batch_data.to(device)
+            if batch_data.x is None:
+                batch_data.x = torch.ones((batch_data.num_nodes, 1), device=device)
+
+            batch = batch_data.batch
+
+            with torch.no_grad():
+                if self.cfg.explanation_mode in ["GNNExplainer", "PGExplainer"]:
+                    out = target_model(batch_data.x, batch_data.edge_index, batch)
+                else:
+                    _, out = target_model(batch_data.x, batch_data.edge_index, batch)
+
+            preds = out.argmax(dim=1)
+
+            # get a node_mask aligned with the nodes in the batch
+            if self.cfg.explanation_mode == "CAM":
+                # CAM needs a forward pass WITH hooks active (already happened above)
+                node_mask = cam.get_cam_scores(preds, batch)
+            elif self.cfg.explanation_mode == "GradCAM":
+                node_mask = gradcam.get_gradcam_scores(batch_data, preds)
+            elif self.cfg.explanation_mode == "Grad":
+                node_mask = grad.get_gradient_scores(batch_data, preds)
+            elif self.cfg.explanation_mode == "GNNExplainer":
+                # phenomenon explanations require the prediction being explained
+                exp = gnnexplainer(batch_data.x, batch_data.edge_index, target=preds, batch=batch)
+                # node_mask_type="attributes" yields [num_nodes, num_features];
+                # aggregate to a single importance score per node
+                node_mask = exp.node_mask.sum(dim=1)
+            elif self.cfg.explanation_mode == "PGExplainer":
+                # train the PGExplainer MLP on this batch before querying it
+                for epoch in range(self.cfg.pgexplainer_epochs):
+                    _ = pgexplainer.algorithm.train(epoch, target_model, batch_data.x, batch_data.edge_index, target=preds, batch=batch)
+                exp = pgexplainer(batch_data.x, batch_data.edge_index, target=preds, batch=batch)
+                node_mask = convert_edge_scores_to_node_scores(exp.edge_mask, exp.edge_index, batch_data.x.size(0))
+            else:
+                raise ValueError(f"Unknown explanation_mode {self.cfg.explanation_mode}")
+
+            # split results per graph
+            original_graphs = batch_data.to_data_list()
+            num_nodes_per_graph = batch_data.ptr[1:] - batch_data.ptr[:-1]
+            node_masks_list = torch.split(node_mask.detach().cpu(), num_nodes_per_graph.tolist())
+            preds_list = preds.detach().cpu().tolist()
+
+            for original_data, pred, nm in zip(original_graphs, preds_list, node_masks_list):
+                results.append(
+                    {
+                        "original_data": original_data.to("cpu"),
+                        "pred": pred,
+                        "node_mask": nm.to("cpu"),
+                    }
+                )
+
+        return results
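+
+    # Each query record is one dict per shadow/test graph, e.g. (shapes illustrative):
+    #   {"original_data": Data(x=[N, F], edge_index=[2, E], y=[1]),
+    #    "pred": 1,                          # target model's hard label
+    #    "node_mask": tensor of shape [N]}   # per-node explanation scores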
+
+    # -----------------------
+    # Surrogate training (attack model)
+    # -----------------------
+    def _train_attack_model(self, processed_shadow: List[Data], processed_test: List[Data]):
+        device = self.device
+        num_features = self.num_features
+        num_classes = self.num_classes
+
+        encoder = self._build_encoder(num_features).to(device)
+        predictor = self._build_classifier(num_classes).to(device)
+        model = SurrogateModel(encoder=encoder, predictor=predictor).to(device)
+
+        criterion = torch.nn.CrossEntropyLoss()
+        ranknet = RankNetLoss().to(device)
+        opt = Adam(model.parameters(), lr=self.cfg.lr)
+
+        augmentor = DataAugmentor()
+
+        # create the CAM wrapper once: each CAM() registers a forward hook on the
+        # encoder's last conv layer, so constructing it per epoch would stack hooks
+        cam_surr = CAM(model)
+
+        best_fidelity = -math.inf
+        best_state = None
+
+        for _ in range(self.cfg.epochs):
+            # explanation-aware augmentation of the shadow set
+            augmented = []
+            if self.cfg.augmentation_ratio > 0:
+                num_aug = int(len(processed_shadow) * self.cfg.augmentation_ratio)
+                if num_aug > 0:
+                    for i in np.random.choice(len(processed_shadow), size=num_aug, replace=False):
+                        s = processed_shadow[int(i)]
+                        if self.cfg.augmentation_type == "drop_node":
+                            a = augmentor.drop_node(s, drop_ratio=self.cfg.operation_ratio)
+                        elif self.cfg.augmentation_type == "drop_edge":
+                            a = augmentor.drop_edge(s, drop_ratio=self.cfg.operation_ratio)
+                        elif self.cfg.augmentation_type == "add_edge":
+                            a = augmentor.add_edge(s, add_ratio=self.cfg.operation_ratio)
+                        else:
+                            a = augmentor.combined_augmentation(
+                                s,
+                                drop_node_ratio=self.cfg.operation_ratio,
+                                drop_edge_ratio=self.cfg.operation_ratio,
+                                add_edge_ratio=self.cfg.operation_ratio,
+                            )
+                        if a is not None:
+                            augmented.append(a)
+
+            train_data = processed_shadow + augmented
+            train_loader = DataLoader(train_data, batch_size=self.cfg.batch_size, shuffle=True, collate_fn=custom_collate)
+            test_loader = DataLoader(processed_test, batch_size=self.cfg.batch_size, shuffle=False, collate_fn=custom_collate)
+
+            model.train()
+
+            for batch in train_loader:
+                batch = batch.to(device)
+                if batch.x is None:
+                    batch.x = torch.ones((batch.num_nodes, 1), device=device)
+
+                opt.zero_grad()
+                node_emb, out = model(batch.x, batch.edge_index, batch.batch)
+
+                # prediction loss against the target model's labels
+                loss_pred = criterion(out, batch.target_pred)
+
+                # ranking alignment loss between surrogate CAM and the queried node_mask
+                # (the forward above already populated cam_surr's hooked activations)
+                preds = out.argmax(dim=1)
+                cam_scores = cam_surr.get_cam_scores(preds, batch.batch)
+
+                true_mask = batch.node_mask.view(-1)
+                align = ranknet(cam_scores.view(-1), true_mask.to(device), batch.batch)
+
+                loss = loss_pred + self.cfg.align_weight * align
+                loss.backward()
+                opt.step()
+
+            # track the best checkpoint by fidelity on the test split
+            fidelity = self._fidelity(model, test_loader)
+            if fidelity >= best_fidelity:
+                best_fidelity = fidelity
+                best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
+
+        if best_state is not None:
+            model.load_state_dict(best_state)
+        return model
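+
+    # The alignment term treats explanation stealing as learning-to-rank: within
+    # each graph, any node pair (i, j) with true_mask[i] > true_mask[j] should get
+    # sigmoid(cam_i - cam_j) close to 1. Two-node sketch with hypothetical values:
+    #
+    #   true_mask = [0.9, 0.1]   ->  pair label y = 1.0
+    #   cam       = [2.0, -1.0]  ->  sigmoid(3.0) ~ 0.95, BCE ~ 0.05 (well aligned)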
+
+    # -----------------------
+    # Metrics
+    # -----------------------
+    def _fidelity(self, surrogate, loader):
+        device = self.device
+        surrogate.eval()
+        hits = 0
+        total = 0
+        with torch.no_grad():
+            for batch in loader:
+                batch = batch.to(device)
+                if batch.x is None:
+                    batch.x = torch.ones((batch.num_nodes, 1), device=device)
+                _, out = surrogate(batch.x, batch.edge_index, batch.batch)
+                pred = out.argmax(dim=1)
+                hits += int((pred == batch.target_pred).sum().item())
+                total += int(batch.target_pred.size(0))
+        return hits / total if total else 0.0
+
+    def _evaluate_surrogate(self, surrogate, processed_test, target_model):
+        device = self.device
+        loader = DataLoader(processed_test, batch_size=self.cfg.batch_size, shuffle=False, collate_fn=custom_collate)
+
+        test_acc, test_auc = self._eval_classifier_like_surrogate(surrogate, loader)
+        fidelity = self._fidelity(surrogate, loader)
+
+        # rank correlation: Kendall's tau per graph, averaged over the test split
+        try:
+            from scipy.stats import kendalltau
+        except Exception:
+            kendalltau = None
+
+        surrogate.eval()
+        cam_surr = CAM(surrogate)
+
+        taus = []
+        with torch.no_grad():
+            for batch in loader:
+                batch = batch.to(device)
+                if batch.x is None:
+                    batch.x = torch.ones((batch.num_nodes, 1), device=device)
+
+                _, out = surrogate(batch.x, batch.edge_index, batch.batch)
+                preds = out.argmax(dim=1)
+                cam_scores = cam_surr.get_cam_scores(preds, batch.batch).detach().cpu()
+                true_scores = batch.node_mask.view(-1).detach().cpu()
+                batch_ids = batch.batch.detach().cpu()
+
+                if kendalltau is None:
+                    continue
+
+                for gid in torch.unique(batch_ids):
+                    m = batch_ids == gid
+                    p = cam_scores[m].numpy()
+                    t = true_scores[m].numpy()
+                    if len(p) < 2:
+                        continue
+                    tau = kendalltau(p, t).correlation
+                    if tau is not None and not np.isnan(tau):
+                        taus.append(float(tau))
+
+        rank_corr = float(np.mean(taus)) if taus else 0.0
+        return test_acc, test_auc, fidelity, rank_corr
+
+    def _eval_classifier_like_surrogate(self, surrogate, loader):
+        device = self.device
+        surrogate.eval()
+        y_true, y_prob, y_pred = [], [], []
+
+        with torch.no_grad():
+            for batch in loader:
+                batch = batch.to(device)
+                if batch.x is None:
+                    batch.x = torch.ones((batch.num_nodes, 1), device=device)
+
+                _, out = surrogate(batch.x, batch.edge_index, batch.batch)
+                probs = F.softmax(out, dim=1)
+                pred = out.argmax(dim=1)
+
+                # accuracy/AUC are measured against the original labels y,
+                # matching the reporting of the original surrogate script
+                y_true.extend(batch.y.detach().cpu().tolist())
+                y_pred.extend(pred.detach().cpu().tolist())
+                if probs.size(1) == 2:
+                    y_prob.extend(probs[:, 1].detach().cpu().tolist())
+                else:
+                    y_prob.extend(probs.max(dim=1).values.detach().cpu().tolist())
+
+        acc = float(np.mean(np.array(y_pred) == np.array(y_true))) if len(y_true) else 0.0
+        auc = float(safe_auc(y_true, y_prob)) if len(y_true) else 0.5
+        return acc, auc
+
diff --git a/pygip/models/attack/egsteal/eg_models.py b/pygip/models/attack/egsteal/eg_models.py
new file mode 100644
index 0000000..a1bf5fd
--- /dev/null
+++ b/pygip/models/attack/egsteal/eg_models.py
@@ -0,0 +1,259 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch_geometric.nn import GCNConv, GATConv, GINConv, SAGEConv, global_mean_pool
+
+
+class SurrogateModel(nn.Module):
+    def __init__(self, encoder, predictor):
+        super().__init__()
+        self.encoder = encoder
+        self.predictor = predictor
+
+    def forward(self, x, edge_index, batch):
+        node_embeddings, graph_embeddings = self.encoder(x, edge_index, batch)
+        out = self.predictor(graph_embeddings)
+        return node_embeddings, out
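+
+
+# Note: SurrogateModel always returns (node_embeddings, logits). TargetModel
+# (below) returns logits only for the mask-based explainer modes so that it can
+# be wrapped directly by torch_geometric.explain.Explainer.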
+class TargetModel(nn.Module):
+    def __init__(self, encoder, predictor, explanation_mode: str):
+        super().__init__()
+        self.encoder = encoder
+        self.predictor = predictor
+        self.explanation_mode = explanation_mode
+
+    def forward(self, x, edge_index, batch):
+        node_embeddings, graph_embeddings = self.encoder(x, edge_index, batch)
+        out = self.predictor(graph_embeddings)
+        if self.explanation_mode in ["GNNExplainer", "PGExplainer"]:
+            return out
+        return node_embeddings, out
+
+
+class GCN(nn.Module):
+    def __init__(self, input_dim, hidden_dim, num_layers):
+        super().__init__()
+        self.convs = nn.ModuleList()
+        for i in range(num_layers):
+            if i == 0:
+                self.convs.append(GCNConv(input_dim, hidden_dim))
+            else:
+                self.convs.append(GCNConv(hidden_dim, hidden_dim))
+
+    def forward(self, x, edge_index, batch):
+        for conv in self.convs:
+            x = F.relu(conv(x, edge_index))
+        node_embeddings = x
+        graph_embeddings = global_mean_pool(x, batch)
+        return node_embeddings, graph_embeddings
+
+
+class GraphSAGE(nn.Module):
+    def __init__(self, input_dim, hidden_dim, num_layers):
+        super().__init__()
+        self.convs = nn.ModuleList()
+        for i in range(num_layers):
+            if i == 0:
+                self.convs.append(SAGEConv(input_dim, hidden_dim))
+            else:
+                self.convs.append(SAGEConv(hidden_dim, hidden_dim))
+
+    def forward(self, x, edge_index, batch):
+        for conv in self.convs:
+            x = F.relu(conv(x, edge_index))
+        node_embeddings = x
+        graph_embeddings = global_mean_pool(x, batch)
+        return node_embeddings, graph_embeddings
+
+
+class GAT(nn.Module):
+    def __init__(self, input_dim, hidden_dim, num_layers, heads=4):
+        super().__init__()
+        # hidden_dim is assumed to be divisible by heads, so that
+        # heads * (hidden_dim // heads) == hidden_dim after concatenation
+        self.convs = nn.ModuleList()
+        for i in range(num_layers):
+            if i == 0:
+                self.convs.append(GATConv(input_dim, hidden_dim // heads, heads=heads))
+            else:
+                self.convs.append(GATConv(hidden_dim, hidden_dim // heads, heads=heads))
+
+    def forward(self, x, edge_index, batch):
+        for conv in self.convs:
+            x = F.relu(conv(x, edge_index))
+        node_embeddings = x
+        graph_embeddings = global_mean_pool(x, batch)
+        return node_embeddings, graph_embeddings
+
+
+class GIN(nn.Module):
+    def __init__(self, input_dim, hidden_dim, num_layers):
+        super().__init__()
+        self.convs = nn.ModuleList()
+        self.bns = nn.ModuleList()
+        for i in range(num_layers):
+            if i == 0:
+                nn_seq = nn.Sequential(
+                    nn.Linear(input_dim, hidden_dim),
+                    nn.ReLU(),
+                    nn.Linear(hidden_dim, hidden_dim),
+                )
+            else:
+                nn_seq = nn.Sequential(
+                    nn.Linear(hidden_dim, hidden_dim),
+                    nn.ReLU(),
+                    nn.Linear(hidden_dim, hidden_dim),
+                )
+            self.convs.append(GINConv(nn_seq))
+            self.bns.append(nn.BatchNorm1d(hidden_dim))
+
+    def forward(self, x, edge_index, batch):
+        for conv, bn in zip(self.convs, self.bns):
+            x = F.relu(bn(conv(x, edge_index)))
+        node_embeddings = x
+        graph_embeddings = global_mean_pool(x, batch)
+        return node_embeddings, graph_embeddings
+
+
+class Classifier(nn.Module):
+    def __init__(self, input_dim, output_dim):
+        super().__init__()
+        self.fc = nn.Linear(input_dim, output_dim)
+
+    def forward(self, x):
+        return self.fc(x)
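+
+
+# CAM for graphs: given last-layer node activations A (shape [N, H]) and linear
+# classifier weights W (shape [C, H]), the importance of node v for class c is
+# the dot product A[v] . W[c]; no gradients are needed, hence the forward hook.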
+class CAM:
+    def __init__(self, model):
+        self.model = model
+        self.activations = None
+        self.classifier_weights = None
+        self._register_hooks()
+
+    def _register_hooks(self):
+        def forward_hook(module, input, output):
+            self.activations = output
+
+        classifier = self.model.predictor
+        if isinstance(classifier, nn.Linear):
+            self.classifier_weights = classifier.weight
+        elif hasattr(classifier, "fc") and isinstance(classifier.fc, nn.Linear):
+            self.classifier_weights = classifier.fc.weight
+        else:
+            raise ValueError("Classifier must be nn.Linear or have .fc as nn.Linear")
+
+        last_conv = self.model.encoder.convs[-1]
+        last_conv.register_forward_hook(forward_hook)
+
+    def get_cam_scores(self, target_classes, batch_ids):
+        if self.activations is None:
+            raise ValueError("No activations recorded. Run a forward pass first.")
+
+        num_graphs = batch_ids.max().item() + 1
+        cam_scores = []
+        for graph_id in range(num_graphs):
+            cls = target_classes[graph_id].item()
+            weight = self.classifier_weights[cls]
+            node_indices = (batch_ids == graph_id).nonzero(as_tuple=False).squeeze()
+            if node_indices.numel() == 0:
+                cam = torch.tensor([], device=self.activations.device)
+            else:
+                activation = self.activations[node_indices]
+                cam = torch.matmul(activation, weight)
+                if cam.dim() == 0:
+                    cam = cam.unsqueeze(0)
+            cam_scores.append(cam)
+
+        return torch.cat(cam_scores, dim=0)
+
+
+class GradCAM:
+    def __init__(self, model):
+        self.model = model
+        self.activations = None
+        self._register_hooks()
+
+    def _register_hooks(self):
+        def forward_hook(module, input, output):
+            self.activations = output
+
+        last_conv = self.model.encoder.convs[-1]
+        last_conv.register_forward_hook(forward_hook)
+
+    def compute_weights(self, gradients, batch):
+        num_graphs = batch.max().item() + 1
+        weights = []
+        for graph_id in range(num_graphs):
+            mask = batch == graph_id
+            graph_grads = gradients[mask]
+            alpha = graph_grads.mean(dim=0)
+            weights.append(alpha)
+        return torch.stack(weights)
+
+    def get_gradcam_scores(self, input_data, target_classes):
+        self.model.zero_grad()
+        _, output = self.model(input_data.x, input_data.edge_index, input_data.batch)
+
+        # check BEFORE differentiating: the hook must have captured activations
+        if self.activations is None:
+            raise ValueError("Activations were not captured")
+
+        batch_size = len(target_classes)
+        scores = output[range(batch_size), target_classes]
+
+        gradients = torch.autograd.grad(
+            scores,
+            self.activations,
+            grad_outputs=torch.ones_like(scores),
+            retain_graph=True,
+        )[0]
+
+        weights = self.compute_weights(gradients, input_data.batch)
+
+        batch = input_data.batch
+        num_graphs = batch.max().item() + 1
+        gradcam_scores = []
+        for graph_id in range(num_graphs):
+            mask = batch == graph_id
+            if mask.any():
+                curr_activations = self.activations[mask]
+                curr_weights = weights[graph_id]
+                gradcam = torch.matmul(curr_activations, curr_weights)
+                gradcam = F.relu(gradcam)
+                gradcam_scores.append(gradcam)
+
+        return torch.cat(gradcam_scores)
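+
+
+# GradientExplainer implements the "Grad" mode: node importance is the L2 norm
+# of the ReLU-filtered gradient of the predicted-class score with respect to
+# the node features, computed separately for each graph in the batch.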
+class GradientExplainer:
+    def __init__(self, model):
+        self.model = model
+
+    def get_gradient_scores(self, input_data, target_classes):
+        input_data.x.requires_grad = True
+        self.model.zero_grad()
+
+        _, output = self.model(input_data.x, input_data.edge_index, input_data.batch)
+
+        batch = input_data.batch
+        num_graphs = batch.max().item() + 1
+        scores_all = []
+
+        for graph_id in range(num_graphs):
+            mask = batch == graph_id
+            target_class = target_classes[graph_id]
+            score = output[graph_id, target_class]
+
+            grad = torch.autograd.grad(
+                score,
+                input_data.x,
+                retain_graph=True,
+                create_graph=False,
+            )[0]
+
+            curr_grad = grad[mask]
+            relu_grad = F.relu(curr_grad)
+            scores = torch.norm(relu_grad, p=2, dim=1)
+            scores_all.append(scores)
+
+        return torch.cat(scores_all)
+
diff --git a/pygip/models/attack/egsteal/eg_utils.py b/pygip/models/attack/egsteal/eg_utils.py
new file mode 100644
index 0000000..70b90c6
--- /dev/null
+++ b/pygip/models/attack/egsteal/eg_utils.py
@@ -0,0 +1,345 @@
+import os
+import random
+import logging
+from typing import Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from sklearn.metrics import roc_auc_score
+from torch import Tensor
+from torch.nn import ReLU, Sequential
+from torch_geometric.data import Data
+from torch_geometric.explain import Explanation
+from torch_geometric.explain.algorithm import ExplainerAlgorithm
+from torch_geometric.explain.algorithm.utils import clear_masks, set_masks
+from torch_geometric.explain.config import ExplanationType, ModelMode, ModelTaskLevel
+from torch_geometric.nn import Linear
+from torch_geometric.nn.inits import reset
+from torch_geometric.utils import coalesce, get_embeddings, subgraph
+
+
+def set_seed(seed: int):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+        # the CUBLAS workspace config must be set before deterministic mode is enforced
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+        torch.use_deterministic_algorithms(True)
+
+
+def safe_auc(y_true, y_pred) -> float:
+    y_true = np.array(y_true)
+    y_pred = np.array(y_pred)
+    if len(np.unique(y_true)) == 1:
+        return 0.5
+    if len(np.unique(y_pred)) == 1:
+        return 0.5
+    return roc_auc_score(y_true, y_pred)
+
+
+class RankNetLoss(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, pred_scores, true_scores, batch_ids):
+        unique_batch = torch.unique(batch_ids)
+        total_loss = 0.0
+        count = 0
+
+        for b in unique_batch:
+            mask = batch_ids == b
+            p = pred_scores[mask]
+            t = true_scores[mask]
+            num_nodes = p.size(0)
+            if num_nodes < 2:
+                continue
+
+            indices_i, indices_j = torch.triu_indices(
+                num_nodes, num_nodes, offset=1, device=p.device
+            )
+
+            s_i = t[indices_i]
+            s_j = t[indices_j]
+            p_i = p[indices_i]
+            p_j = p[indices_j]
+
+            # pairwise labels: 1 if node i should outrank node j, 0.5 for ties
+            y_ij = torch.zeros_like(s_i, dtype=torch.float, device=p.device)
+            y_ij[s_i > s_j] = 1.0
+            y_ij[s_i == s_j] = 0.5
+
+            sigmoid_diff = torch.sigmoid(p_i - p_j)
+            loss = F.binary_cross_entropy(sigmoid_diff, y_ij, reduction="mean")
+
+            total_loss += loss
+            count += 1
+
+        if count == 0:
+            return torch.tensor(0.0, device=pred_scores.device, requires_grad=True)
+        return total_loss / count
+
+
+class DataAugmentor:
+    def drop_node(self, sample: Data, drop_ratio=0.1) -> Optional[Data]:
+        node_mask = (
+            sample.node_mask
+            if hasattr(sample, "node_mask")
+            else torch.ones(sample.num_nodes, dtype=torch.float, device=sample.x.device)
+        )
+        num_nodes = sample.num_nodes
+        num_drop = int(drop_ratio * num_nodes)
+        if num_drop == 0 or num_drop >= num_nodes:
+            return None
+
+        _, drop_indices = torch.topk(node_mask, k=num_drop, largest=False)
+        keep_mask = torch.ones(num_nodes, dtype=torch.bool, device=sample.x.device)
+        keep_mask[drop_indices] = False
+        # view(-1) keeps a 1-D index tensor even when only a single node survives
+        keep_indices = torch.nonzero(keep_mask, as_tuple=False).view(-1)
+
+        new_edge_index, edge_attr, _ = subgraph(
+            keep_indices,
+            sample.edge_index,
+            edge_attr=sample.edge_attr if hasattr(sample, "edge_attr") else None,
+            relabel_nodes=True,
+            num_nodes=num_nodes,
+            return_edge_mask=True,
+        )
+
+        new_x = sample.x[keep_indices]
+        new_node_mask = node_mask[keep_indices]
+
+        return Data(
+            x=new_x,
+            edge_index=new_edge_index,
+            edge_attr=edge_attr,
+            y=sample.y,
+            target_pred=sample.target_pred if hasattr(sample, "target_pred") else None,
+            node_mask=new_node_mask,
+        )
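+
+    # Example (hypothetical numbers): with 10 nodes and drop_ratio=0.2, the two
+    # nodes with the LOWEST node_mask scores are removed, so the augmentation
+    # perturbs the graph where the target's explanation says it matters least.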
+
+    def drop_edge(self, sample: Data, drop_ratio=0.1) -> Optional[Data]:
+        node_mask = (
+            sample.node_mask
+            if hasattr(sample, "node_mask")
+            else torch.ones(sample.num_nodes, dtype=torch.float, device=sample.x.device)
+        )
+        num_nodes = sample.num_nodes
+        num_low = int(drop_ratio * num_nodes)
+        if num_low < 2:
+            return None
+
+        _, indices = torch.topk(node_mask, k=num_low, largest=False)
+        low_nodes = set(indices.cpu().tolist())
+
+        src_nodes = sample.edge_index[0]
+        dst_nodes = sample.edge_index[1]
+        edge_keep = torch.ones(sample.edge_index.size(1), dtype=torch.bool, device=sample.x.device)
+
+        # drop every edge incident to a low-importance node
+        for node in low_nodes:
+            mask = (src_nodes == node) | (dst_nodes == node)
+            edge_keep = edge_keep & ~mask
+
+        new_edge_index = sample.edge_index[:, edge_keep]
+        edge_attr = sample.edge_attr[edge_keep] if hasattr(sample, "edge_attr") and sample.edge_attr is not None else None
+
+        return Data(
+            x=sample.x,
+            edge_index=new_edge_index,
+            edge_attr=edge_attr,
+            y=sample.y,
+            target_pred=sample.target_pred if hasattr(sample, "target_pred") else None,
+            node_mask=node_mask,
+        )
+
+    def add_edge(self, sample: Data, add_ratio=0.1) -> Optional[Data]:
+        node_mask = (
+            sample.node_mask
+            if hasattr(sample, "node_mask")
+            else torch.ones(sample.num_nodes, dtype=torch.float, device=sample.x.device)
+        )
+        num_nodes = sample.num_nodes
+        num_low = int(add_ratio * num_nodes)
+        if num_low < 2:
+            return None
+
+        _, indices = torch.topk(node_mask, k=num_low, largest=False)
+        low_nodes = indices.tolist()
+
+        # densely connect the low-importance nodes (both directions)
+        new_edges = []
+        for i in range(len(low_nodes)):
+            for j in range(i + 1, len(low_nodes)):
+                u, v = low_nodes[i], low_nodes[j]
+                new_edges.append([u, v])
+                new_edges.append([v, u])
+
+        if not new_edges:
+            return None
+
+        new_edges_tensor = torch.tensor(new_edges, dtype=torch.long, device=sample.x.device).t()
+        new_edge_index = torch.cat([sample.edge_index, new_edges_tensor], dim=1)
+
+        if hasattr(sample, "edge_attr") and sample.edge_attr is not None:
+            # give the injected edges a default all-ones attribute row
+            default_attr = torch.ones((new_edges_tensor.size(1), sample.edge_attr.size(1)),
+                                      dtype=sample.edge_attr.dtype, device=sample.x.device)
+            edge_attr = torch.cat([sample.edge_attr, default_attr], dim=0)
+            # coalesce deduplicates edges while keeping edge_attr aligned with the
+            # surviving edges (reduce="max" avoids summing duplicated attributes)
+            new_edge_index, edge_attr = coalesce(new_edge_index, edge_attr, num_nodes=num_nodes, reduce="max")
+        else:
+            edge_attr = None
+            new_edge_index = coalesce(new_edge_index, num_nodes=num_nodes)
+
+        return Data(
+            x=sample.x,
+            edge_index=new_edge_index,
+            edge_attr=edge_attr,
+            y=sample.y,
+            target_pred=sample.target_pred if hasattr(sample, "target_pred") else None,
+            node_mask=node_mask,
+        )
+
+    def combined_augmentation(self, sample: Data,
+                              drop_node_ratio=0.1, drop_edge_ratio=0.1, add_edge_ratio=0.1) -> Optional[Data]:
+        chosen = random.choice(["drop_node", "drop_edge", "add_edge"])
+        try:
+            if chosen == "drop_node":
+                return self.drop_node(sample, drop_ratio=drop_node_ratio)
+            if chosen == "drop_edge":
+                return self.drop_edge(sample, drop_ratio=drop_edge_ratio)
+            return self.add_edge(sample, add_ratio=add_edge_ratio)
+        except Exception:
+            return None
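+
+
+# DataAugmentor usage sketch (assuming a processed shadow sample `s`):
+#
+#   aug = DataAugmentor().combined_augmentation(s, 0.05, 0.05, 0.05)
+#   if aug is not None:   # every augmentation can bail out on tiny graphs
+#       train_data.append(aug)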
+
+
+# Local variant of PyG's PGExplainer, adapted to support "phenomenon"
+# explanations (explaining the target model's predictions rather than the model itself).
+class PGExplainer(ExplainerAlgorithm):
+    coeffs = {
+        "edge_size": 0.05,
+        "edge_ent": 1.0,
+        "temp": [5.0, 2.0],
+        "bias": 0.01,
+    }
+
+    def __init__(self, epochs: int, lr: float = 0.003, **kwargs):
+        super().__init__()
+        self.epochs = epochs
+        self.lr = lr
+        self.coeffs.update(kwargs)
+
+        self.mlp = Sequential(
+            Linear(-1, 64),
+            ReLU(),
+            Linear(64, 1),
+        )
+        self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=lr)
+        self._curr_epoch = -1
+
+    def reset_parameters(self):
+        reset(self.mlp)
+
+    def train(self, epoch: int, model: torch.nn.Module, x: Tensor, edge_index: Tensor, *,
+              target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs):
+        if isinstance(x, dict) or isinstance(edge_index, dict):
+            raise ValueError("Heterogeneous graphs not supported")
+
+        if self.model_config.task_level == ModelTaskLevel.node:
+            if index is None:
+                raise ValueError("index required for node-level explanations")
+            if isinstance(index, Tensor) and index.numel() > 1:
+                raise ValueError("index must be scalar")
+
+        z = get_embeddings(model, x, edge_index, **kwargs)[-1]
+        self.optimizer.zero_grad()
+        temperature = self._get_temperature(epoch)
+
+        inputs = self._get_inputs(z, edge_index, index)
+        logits = self.mlp(inputs).view(-1)
+        edge_mask = self._concrete_sample(logits, temperature)
+        set_masks(model, edge_mask, edge_index, apply_sigmoid=True)
+
+        if self.model_config.task_level == ModelTaskLevel.node:
+            _, hard_edge_mask = self._get_hard_masks(model, index, edge_index, num_nodes=x.size(0))
+            edge_mask = edge_mask[hard_edge_mask]
+
+        y_hat, y = model(x, edge_index, **kwargs), target
+        if index is not None:
+            y_hat, y = y_hat[index], y[index]
+
+        loss = self._loss(y_hat, y, edge_mask)
+        loss.backward()
+        self.optimizer.step()
+
+        clear_masks(model)
+        self._curr_epoch = epoch
+        return float(loss)
+
+    def forward(self, model: torch.nn.Module, x: Tensor, edge_index: Tensor, *,
+                target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs) -> Explanation:
+        if self._curr_epoch < self.epochs - 1:
+            raise ValueError("PGExplainer is not fully trained yet")
+
+        hard_edge_mask = None
+        if self.model_config.task_level == ModelTaskLevel.node:
+            if index is None:
+                raise ValueError("index required for node-level explanations")
+            if isinstance(index, Tensor) and index.numel() > 1:
+                raise ValueError("index must be scalar")
+            _, hard_edge_mask = self._get_hard_masks(model, index, edge_index, num_nodes=x.size(0))
+
+        z = get_embeddings(model, x, edge_index, **kwargs)[-1]
+        inputs = self._get_inputs(z, edge_index, index)
+        logits = self.mlp(inputs).view(-1)
+        edge_mask = self._post_process_mask(logits, hard_edge_mask, apply_sigmoid=True)
+        return Explanation(edge_mask=edge_mask)
+
+    def supports(self) -> bool:
+        if self.explainer_config.explanation_type != ExplanationType.phenomenon:
+            logging.error("PGExplainer only supports phenomenon explanations")
+            return False
+        if self.model_config.task_level not in {ModelTaskLevel.node, ModelTaskLevel.graph}:
+            logging.error("PGExplainer supports node-level or graph-level only")
+            return False
+        if self.explainer_config.node_mask_type is not None:
+            logging.error("PGExplainer does not support feature masks")
+            return False
+        return True
+
+    def _get_inputs(self, embedding: Tensor, edge_index: Tensor, index: Optional[int] = None) -> Tensor:
+        zs = [embedding[edge_index[0]], embedding[edge_index[1]]]
+        if self.model_config.task_level == ModelTaskLevel.node:
+            assert index is not None
+            zs.append(embedding[index].view(1, -1).repeat(zs[0].size(0), 1))
+        return torch.cat(zs, dim=-1)
+
+    def _get_temperature(self, epoch: int) -> float:
+        temp = self.coeffs["temp"]
+        return temp[0] * pow(temp[1] / temp[0], epoch / self.epochs)
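+
+    # With the default temp=[5.0, 2.0] and epochs=30, the concrete-sampling
+    # temperature anneals geometrically from 5.0 at epoch 0 down to
+    # 5.0 * (2.0 / 5.0) ** (29 / 30) ~ 2.06 at the last training epoch.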
+
+    def _concrete_sample(self, logits: Tensor, temperature: float = 1.0) -> Tensor:
+        bias = self.coeffs["bias"]
+        eps = (1 - 2 * bias) * torch.rand_like(logits) + bias
+        return (eps.log() - (1 - eps).log() + logits) / temperature
+
+    def _loss(self, y_hat: Tensor, y: Tensor, edge_mask: Tensor) -> Tensor:
+        if self.model_config.mode == ModelMode.binary_classification:
+            loss = self._loss_binary_classification(y_hat, y)
+        elif self.model_config.mode == ModelMode.multiclass_classification:
+            loss = self._loss_multiclass_classification(y_hat, y)
+        else:
+            loss = self._loss_regression(y_hat, y)
+
+        # regularization: keep the edge mask sparse and push entries toward 0/1
+        mask = edge_mask.sigmoid()
+        size_loss = mask.sum() * self.coeffs["edge_size"]
+        mask = 0.99 * mask + 0.005
+        mask_ent = -mask * mask.log() - (1 - mask) * (1 - mask).log()
+        mask_ent_loss = mask_ent.mean() * self.coeffs["edge_ent"]
+        return loss + size_loss + mask_ent_loss
+