diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 1a9048e..0000000 --- a/.gitignore +++ /dev/null @@ -1,38 +0,0 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# PyInstaller -# Usually contains a build/ and dist/ folder -*.manifest -*.spec - -# Pytype -.pytype/ - -# Cython debug symbols -cython_debug/ - -# VS Code -.vscode/ - -# JetBrains IDEs -.idea/ - -# macOS system files -.DS_Store -.AppleDouble -.LSOverride - -# Icon must end with two \r -Icon - -# Thumbnails -._* - -# macOS network shares -.AppleSharePDS - -#virtual environments folder -.venv diff --git a/datasets/__pycache__/__init__.cpython-311.pyc b/datasets/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..1b2962d Binary files /dev/null and b/datasets/__pycache__/__init__.cpython-311.pyc differ diff --git a/datasets/__pycache__/datasets.cpython-311.pyc b/datasets/__pycache__/datasets.cpython-311.pyc new file mode 100644 index 0000000..ef32b4b Binary files /dev/null and b/datasets/__pycache__/datasets.cpython-311.pyc differ diff --git a/models/defense/GNN_Fingers.py b/models/defense/GNN_Fingers.py new file mode 100644 index 0000000..253c656 --- /dev/null +++ b/models/defense/GNN_Fingers.py @@ -0,0 +1,990 @@ +import importlib +import numpy as np +from tqdm import tqdm +import copy +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from datasets import dataset +#from models.nn import GraphSAGE +#from dgl.dataloading import NeighborSampler, NodeCollator +from torch_geometric.nn import GCNConv, GATConv +from torch_geometric.utils import erdos_renyi_graph, to_dense_adj, dense_to_sparse +#from torch_geometric.loader import NeighborLoader, DataLoader +from torch_geometric.data import Data as PyGData +from models.defense.base import BaseDefense + + +class Univerifier(nn.Module): + """ + Unified Verification Mechanism - Binary classifier that takes concatenated outputs + from suspect models and predicts whether they are pirated or irrelevant. + """ + def __init__(self, input_dim, hidden_dims=[128, 64,32,16,8,4]): + super(Univerifier, self).__init__() + layers = [] + prev_dim = input_dim + + for hidden_dim in hidden_dims: + layers.append(nn.Linear(prev_dim, hidden_dim)) + layers.append(nn.LeakyReLU()) + layers.append(nn.Dropout(0.1)) + prev_dim = hidden_dim + + layers.append(nn.Linear(prev_dim, 2)) + self.classifier = nn.Sequential(*layers) + + def forward(self, x): + return F.softmax(self.classifier(x), dim=1) + +class GNNFingers(BaseDefense): + """ + GNNFingers: A Fingerprinting Framework for Verifying Ownerships of Graph Neural Networks + Implementation based on the paper by You et al. 
(2024)
+    """
+    supported_api_types = {"dgl"}
+
+    def __init__(self, dataset, attack_node_fraction=0.2, device=None, attack_name=None,
+                 num_fingerprints=64, fingerprint_nodes=32, lambda_threshold=0.7,
+                 batch_size=32, num_neighbors=[5, 5]):
+        """
+        Initialize GNNFingers defense framework
+
+        Parameters
+        ----------
+        dataset : Dataset
+            The original dataset containing the graph to defend
+        attack_node_fraction : float
+            Fraction of nodes to consider for attack
+        device : torch.device
+            Device to run computations on
+        attack_name : str
+            Name of the attack class to use
+        num_fingerprints : int
+            Number of graph fingerprints to generate
+        fingerprint_nodes : int
+            Number of nodes in each fingerprint graph
+        lambda_threshold : float
+            Threshold for Univerifier classification
+        batch_size : int
+            Batch size for training
+        num_neighbors : list
+            Number of neighbors to sample at each layer
+        """
+        super().__init__(dataset, attack_node_fraction, device)
+        self.attack_name = attack_name or "ModelExtractionAttack0"
+        self.dataset = dataset
+        self.graph = dataset.graph_data
+
+        # Extract dataset properties
+        self.node_number = dataset.num_nodes
+        self.feature_number = dataset.num_features
+        self.label_number = dataset.num_classes
+        self.attack_node_number = int(self.node_number * attack_node_fraction)
+
+        # Training parameters
+        # self.batch_size = batch_size
+        # self.num_neighbors = num_neighbors
+
+        # Convert DGL to PyG data
+        self.pyg_data = self._dgl_to_pyg(self.graph)
+
+        # Extract features and labels
+        self.features = self.pyg_data.x
+        self.labels = self.pyg_data.y
+
+        # Extract masks
+        self.train_mask = self.pyg_data.train_mask
+        self.test_mask = self.pyg_data.test_mask
+
+        # GNNFingers parameters
+        self.num_fingerprints = num_fingerprints
+        self.fingerprint_nodes = fingerprint_nodes
+        self.lambda_threshold = lambda_threshold
+
+        # Initialize components
+        self.target_gnn = None
+        self.positive_gnns = []  # Pirated GNNs
+        self.negative_gnns = []  # Irrelevant GNNs
+        self.graph_fingerprints = None
+        self.univerifier = None
+
+        # Move tensors to device
+        if self.device != 'cpu':
+            self.graph = self.graph.to(self.device)
+            self.features = self.features.to(self.device)
+            self.labels = self.labels.to(self.device)
+            self.train_mask = self.train_mask.to(self.device)
+            self.test_mask = self.test_mask.to(self.device)
+
+    def _dgl_to_pyg(self, dgl_graph):
+        """Convert DGL graph to PyTorch Geometric Data object"""
+        # Extract edge indices
+        edge_index = torch.stack(dgl_graph.edges())
+        x = dgl_graph.ndata.get('feat')
+        y = dgl_graph.ndata.get('label')
+
+        train_mask = dgl_graph.ndata.get('train_mask')
+        val_mask = dgl_graph.ndata.get('val_mask')
+        test_mask = dgl_graph.ndata.get('test_mask')
+
+        data = PyGData(x=x, edge_index=edge_index, y=y,
+                       train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)
+
+        return data
+
+    # def _create_dataloaders(self, graph_data):
+    #     """Create train and test dataloaders with neighbor sampling"""
+    #     # For DGL graphs
+    #     if hasattr(graph_data, 'ndata'):
+    #         # DGL graph
+    #         sampler = NeighborSampler(self.num_neighbors)
+    #         train_nids = graph_data.ndata['train_mask'].nonzero(as_tuple=True)[0].to(self.device)
+    #         test_nids = graph_data.ndata['test_mask'].nonzero(as_tuple=True)[0].to(self.device)
+
+    #         train_collator = NodeCollator(graph_data, train_nids, sampler)
+    #         test_collator = NodeCollator(graph_data, test_nids, sampler)
+
+    #         train_dataloader = DataLoader(
+    #             train_collator.dataset,
+    #             batch_size=self.batch_size,
+    #             shuffle=True,
+    #
collate_fn=train_collator.collate, + # drop_last=False + # ) + + # test_dataloader = DataLoader( + # test_collator.dataset, + # batch_size=self.batch_size, + # shuffle=False, + # collate_fn=test_collator.collate, + # drop_last=False + # ) + + # return train_dataloader, test_dataloader + + # else: + # # PyG data + # train_loader = NeighborLoader( + # graph_data, + # num_neighbors=self.num_neighbors, + # batch_size=self.batch_size, + # shuffle=True, + # input_nodes=graph_data.train_mask, + # ) + + # test_loader = NeighborLoader( + # graph_data, + # num_neighbors=self.num_neighbors, + # batch_size=self.batch_size, + # shuffle=False, + # input_nodes=graph_data.test_mask, + # ) + + # return train_loader, test_loader + + def _get_attack_class(self, attack_name): + """Dynamically import and return the specified attack class""" + try: + attack_module = importlib.import_module('models.attack') + attack_class = getattr(attack_module, attack_name) + return attack_class + except (ImportError, AttributeError) as e: + print(f"Error loading attack class '{attack_name}': {e}") + print("Falling back to ModelExtractionAttack0") + attack_module = importlib.import_module('models.attack') + return getattr(attack_module, "ModelExtractionAttack0") + + def defend(self, attack_name=None): + """ + Main defense workflow for GNNFingers + """ + attack_name = attack_name or self.attack_name + AttackClass = self._get_attack_class(attack_name) + print(f"Using attack method: {attack_name}") + + # Step 1: Train target model + print("Training target GNN...") + self.target_gnn = self._train_gnn_model(self.pyg_data, "Target GNN") + + # Step 2: Prepare positive and negative GNNs + print("Preparing positive (pirated) GNNs...") + self.positive_gnns = self._prepare_positive_gnns(self.target_gnn, num_models=50) + + print("Preparing negative (irrelevant) GNNs...") + self.negative_gnns = self._prepare_negative_gnns(num_models=50) + + # Step 3: Initialize graph fingerprints + print("Initializing graph fingerprints...") + self.graph_fingerprints = self._initialize_graph_fingerprints() + + # Step 4: Initialize Univerifier + output_dim = self._get_output_dimension(self.target_gnn, self.graph_fingerprints[0]) + self.univerifier = Univerifier(input_dim=output_dim * self.num_fingerprints) + self.univerifier = self.univerifier.to(self.device) + + # Step 5: Joint learning of fingerprints and Univerifier + print("Joint learning of fingerprints and Univerifier...") + self._joint_learning() + + # Step 6: Attack target model + print("Attacking target model...") + attack = AttackClass(self.dataset, attack_node_fraction=0.2) + attack_results = attack.attack() + suspect_model = attack.net2 if hasattr(attack, 'net2') else None + + # Step 7: Verify ownership + if suspect_model is not None: + print("Verifying ownership of suspect model...") + verification_result = self._verify_ownership(suspect_model) + print(f"Ownership verification result: {verification_result}") + + return { + "attack_results": attack_results, + "verification_result": verification_result, + "target_accuracy": self._evaluate_model(self.target_gnn, self.pyg_data), + "suspect_accuracy": self._evaluate_model(suspect_model, self.pyg_data) + } + + return {"attack_results": attack_results, "verification_result": "No suspect model found"} + + def _train_gnn_model(self, data, model_name="GNN", epochs=100): + """Train a GNN model on the given data""" + model = GCNConvGNN( + in_channels=data.x.size(1), + hidden_channels=128, + out_channels=self.label_number, + num_layers=3 + ).to(self.device) 
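+        # Shape note (a sketch of the assumed data layout): data.x.size(1) is
+        # the node-feature width and self.label_number the class count, so this
+        # default 3-layer, 128-unit GCN maps [num_nodes, num_features] logits
+        # to [num_nodes, num_classes].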
+ + optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) + criterion = nn.CrossEntropyLoss() + + best_acc = 0 + best_model = None + + for epoch in range(epochs): + model.train() + optimizer.zero_grad() + + # Forward pass + out = model(data.x, data.edge_index) + loss = criterion(out[data.train_mask], data.y[data.train_mask]) + + # Backward pass + loss.backward() + optimizer.step() + + # Evaluate ONLY on test nodes + if epoch % 10 == 0: + model.eval() + with torch.no_grad(): + out = model(data.x, data.edge_index) + pred = out[data.test_mask].argmax(dim=1) + correct = (pred == data.y[data.test_mask]).sum().item() + total = data.test_mask.sum().item() + acc = correct / total if total > 0 else 0 + + if acc > best_acc: + best_acc = acc + best_model = copy.deepcopy(model) + + print(f"{model_name} trained with accuracy: {best_acc:.4f}") + return best_model + + def _prepare_positive_gnns(self, target_model, num_models=50): + """Prepare pirated GNNs using obfuscation techniques""" + positive_models = [] + + for i in range(num_models): + # Apply different obfuscation techniques + if i % 3 == 0: + # Fine-tuning - fine-tune different numbers of layers + layers_to_finetune = random.randint(1, 3) + model = self._fine_tune_model(copy.deepcopy(target_model), self.pyg_data, + epochs=10, num_layers_to_finetune=layers_to_finetune) + elif i % 3 == 1: + # Partial retraining - retrain different numbers of layers + layers_to_retrain = random.randint(1, 3) + model = self._partial_retrain_model(copy.deepcopy(target_model), self.pyg_data, + epochs=15, num_layers_to_retrain=layers_to_retrain) + else: + # Distillation - use different temperatures and architectures + temperature = random.choice([1.5, 2.0, 3.0, 4.0]) + model = self._distill_model(target_model, self.pyg_data, + epochs=30, temperature=temperature) + + positive_models.append(model) + + return positive_models + + + def _prepare_negative_gnns(self, num_models=50): + """Prepare irrelevant GNNs""" + negative_models = [] + + for i in range(num_models): + # Train from scratch with different architectures or data + if i % 2 == 0: + # Different architecture + model = self._train_different_architecture(self.pyg_data) + else: + # Different training data (subset) + model = self._train_on_subset(self.pyg_data) + + negative_models.append(model) + + return negative_models + + def _fine_tune_model(self, model, data, epochs=10, num_layers_to_finetune=1): + """Fine-tune a model on the same data, but only update the last K layers""" + # Freeze all layers initially + for param in model.parameters(): + param.requires_grad = False + + # Unfreeze the last K layers for fine-tuning + if hasattr(model, 'convs'): + # For models with convs attribute (like GCNConvGNN, GATConvGNN) + total_layers = len(model.convs) + layers_to_finetune = min(num_layers_to_finetune, total_layers) + + for i in range(total_layers - layers_to_finetune, total_layers): + for param in model.convs[i].parameters(): + param.requires_grad = True + else: + # For other model types, try to find the last layers + all_params = list(model.parameters()) + layers_to_finetune = min(num_layers_to_finetune, len(all_params)) + + for param in all_params[-layers_to_finetune:]: + param.requires_grad = True + + # Only optimize parameters that require gradients + trainable_params = filter(lambda p: p.requires_grad, model.parameters()) + optimizer = optim.Adam(trainable_params, lr=0.001) + criterion = nn.CrossEntropyLoss() + + # Count how many parameters are being fine-tuned + num_trainable = sum(p.numel() for p 
in model.parameters() if p.requires_grad) + num_total = sum(p.numel() for p in model.parameters()) + print(f"Fine-tuning {num_trainable}/{num_total} parameters ({num_trainable/num_total*100:.1f}%)") + + for epoch in range(epochs): + model.train() + optimizer.zero_grad() + + out = model(data.x, data.edge_index) + loss = criterion(out[data.train_mask], data.y[data.train_mask]) + + loss.backward() + optimizer.step() + + # Unfreeze all parameters for future use + for param in model.parameters(): + param.requires_grad = True + + return model + + def _partial_retrain_model(self, model, data, epochs=10, num_layers_to_retrain=2): + """Partially retrain a model with random initialization of K layers before resuming training""" + # Randomly initialize selected K layers + if hasattr(model, 'convs'): + # For models with convs attribute (like GCNConvGNN, GATConvGNN) + total_layers = len(model.convs) + layers_to_retrain = min(num_layers_to_retrain, total_layers) + + # Randomly select K layers to retrain (not necessarily the last ones) + layer_indices = random.sample(range(total_layers), layers_to_retrain) + + print(f"Partially retraining layers: {layer_indices}") + + for idx in layer_indices: + model.convs[idx].reset_parameters() # Random reinitialization + else: + # For other model types, try to find and reset random layers + # This is a fallback approach + all_layers = list(model.children()) + total_layers = len(all_layers) + layers_to_retrain = min(num_layers_to_retrain, total_layers) + + layer_indices = random.sample(range(total_layers), layers_to_retrain) + + for idx in layer_indices: + if hasattr(all_layers[idx], 'reset_parameters'): + all_layers[idx].reset_parameters() + + # Train the entire model (both retrained and original layers) + optimizer = optim.Adam(model.parameters(), lr=0.01) + criterion = nn.CrossEntropyLoss() + + best_acc = 0 + best_model = None + + for epoch in range(epochs): + model.train() + optimizer.zero_grad() + + out = model(data.x, data.edge_index) + loss = criterion(out[data.train_mask], data.y[data.train_mask]) + + loss.backward() + optimizer.step() + + # Track best model + if epoch % 5 == 0: + acc = self._evaluate_model(model, data) + if acc > best_acc: + best_acc = acc + best_model = copy.deepcopy(model) + + print(f"Partial retraining completed. 
Best accuracy: {best_acc:.4f}") + return best_model if best_model is not None else model + def _distill_model(self, teacher_model, data, epochs=30, temperature=2.0): + """Distill knowledge from teacher to student model with different architecture""" + # Create student model with different architecture + if isinstance(teacher_model, GCNConvGNN): + # If teacher is GCN, use GAT as student + student_model = GATConvGNN( + in_channels=data.x.size(1), + hidden_channels=96, # Different hidden size + out_channels=self.label_number, + num_layers=2, # Different number of layers + heads=3 # Different number of heads + ).to(self.device) + else: + # If teacher is GAT or other, use GCN as student + student_model = GCNConvGNN( + in_channels=data.x.size(1), + hidden_channels=64, # Different hidden size + out_channels=self.label_number, + num_layers=3 # Different number of layers + ).to(self.device) + + optimizer = optim.Adam(student_model.parameters(), lr=0.01, weight_decay=1e-4) + + # Combined loss: KL divergence for distillation + cross entropy for ground truth + kl_loss = nn.KLDivLoss(reduction='batchmean') + ce_loss = nn.CrossEntropyLoss() + + teacher_model.eval() + + best_acc = 0 + best_student = None + + for epoch in range(epochs): + student_model.train() + optimizer.zero_grad() + + # Get teacher predictions (with temperature scaling) + with torch.no_grad(): + teacher_logits = teacher_model(data.x, data.edge_index) + teacher_probs = F.softmax(teacher_logits / temperature, dim=1) + + # Get student predictions + student_logits = student_model(data.x, data.edge_index) + student_log_probs = F.log_softmax(student_logits / temperature, dim=1) + + # Distillation loss (KL divergence between teacher and student) + distill_loss = kl_loss(student_log_probs[data.train_mask], + teacher_probs[data.train_mask]) * (temperature ** 2) + + # Student's own classification loss + class_loss = ce_loss(student_logits[data.train_mask], + data.y[data.train_mask]) + + # Combined loss (weighted sum) + loss = 0.7 * distill_loss + 0.3 * class_loss + + loss.backward() + optimizer.step() + + # Track best student model + if epoch % 5 == 0: + student_model.eval() + with torch.no_grad(): + out = student_model(data.x, data.edge_index) + pred = out[data.test_mask].argmax(dim=1) + correct = (pred == data.y[data.test_mask]).sum().item() + total = data.test_mask.sum().item() + acc = correct / total if total > 0 else 0 + + if acc > best_acc: + best_acc = acc + best_student = copy.deepcopy(student_model) + + print(f"Distillation Epoch {epoch}: Loss={loss.item():.4f}, " + f"Distill={distill_loss.item():.4f}, Class={class_loss.item():.4f}, " + f"Acc={acc:.4f}") + + print(f"Distillation completed. 
Best student accuracy: {best_acc:.4f}")
+        return best_student if best_student is not None else student_model
+
+    def _train_different_architecture(self, data):
+        """Train a model with a different architecture for negative GNNs"""
+        # Use the opposite architecture of the target model (the check must be
+        # against the model class GCNConvGNN, not the GCNConv layer class)
+        if isinstance(self.target_gnn, GCNConvGNN):
+            model = GATConvGNN(
+                in_channels=data.x.size(1),
+                hidden_channels=64,
+                out_channels=self.label_number,
+                num_layers=2,
+                heads=4
+            ).to(self.device)
+        else:
+            model = GCNConvGNN(
+                in_channels=data.x.size(1),
+                hidden_channels=64,
+                out_channels=self.label_number,
+                num_layers=3
+            ).to(self.device)
+
+        return self._train_gnn_model_with_data(model, data, epochs=50)
+
+    def _train_on_subset(self, data, subset_ratio=0.7):
+        """Train on a subset of the data"""
+        # Create subset mask
+        num_train = int(data.train_mask.sum().item() * subset_ratio)
+        subset_mask = torch.zeros_like(data.train_mask)
+        train_indices = data.train_mask.nonzero(as_tuple=True)[0]
+        selected_indices = random.sample(range(len(train_indices)), min(num_train, len(train_indices)))
+        subset_mask[train_indices[selected_indices]] = True
+
+        # Create subset data
+        subset_data = PyGData(
+            x=data.x,
+            edge_index=data.edge_index,
+            y=data.y,
+            train_mask=subset_mask,
+            test_mask=data.test_mask
+        )
+        # Use the opposite architecture of the target model (GCNConvGNN does
+        # not accept a `heads` argument, so that keyword belongs to the GAT branch)
+        if isinstance(self.target_gnn, GCNConvGNN):
+            model = GATConvGNN(
+                in_channels=data.x.size(1),
+                hidden_channels=64,
+                out_channels=self.label_number,
+                num_layers=2,
+                heads=4
+            ).to(self.device)
+        else:
+            model = GCNConvGNN(
+                in_channels=data.x.size(1),
+                hidden_channels=64,
+                out_channels=self.label_number,
+                num_layers=3
+            ).to(self.device)
+
+        return self._train_gnn_model_with_data(model, subset_data, epochs=50)
+
+    def _train_gnn_model_with_data(self, model, data, epochs=100):
+        """Train a specific model on specific data"""
+        optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
+        criterion = nn.CrossEntropyLoss()
+
+        best_acc = 0
+        best_model = None
+
+        for epoch in range(epochs):
+            model.train()
+            optimizer.zero_grad()
+
+            out = model(data.x, data.edge_index)
+            loss = criterion(out[data.train_mask], data.y[data.train_mask])
+
+            loss.backward()
+            optimizer.step()
+
+            if epoch % 10 == 0:
+                acc = self._evaluate_model(model, data)
+                if acc > best_acc:
+                    best_acc = acc
+                    best_model = copy.deepcopy(model)
+
+        return best_model if best_model is not None else model
+
+    def _initialize_graph_fingerprints(self):
+        """Initialize graph fingerprints for node classification task"""
+        fingerprints = []
+
+        for _ in range(self.num_fingerprints):
+            # Initialize random graph using Erdos-Renyi model
+            edge_index = erdos_renyi_graph(self.fingerprint_nodes, 0.3)
+
+            # Initialize node features
+            if self.features is not None:
+                x = torch.randn(self.fingerprint_nodes, self.features.size(1))
+            else:
+                x = torch.randn(self.fingerprint_nodes, 16)  # Default feature dimension
+
+            # Create PyG Data object
+            fingerprint_data = PyGData(x=x, edge_index=edge_index)
+            fingerprints.append(fingerprint_data)
+
+        return fingerprints
+
+    def _get_output_dimension(self, model, fingerprint_data):
+        """Get the flattened output size of a model for a given fingerprint"""
+        model.eval()
+        with torch.no_grad():
+            output = model(fingerprint_data.x.to(self.device),
+                           fingerprint_data.edge_index.to(self.device))
+        # The Univerifier consumes the flattened per-fingerprint output
+        # (fingerprint_nodes * num_classes values), not just the class count,
+        # because the fingerprint outputs are concatenated with .view(1, -1)
+        return output.view(-1).size(0)
+
+    def _joint_learning(self, epochs=100, fingerprint_lr=0.01, univerifier_lr=0.001):
+        """Joint learning of graph fingerprints and Univerifier"""
+        # Only the continuous node features are optimized by gradient descent;
+        # edge_index is a discrete integer tensor and is left fixed here
+        for fp in self.graph_fingerprints:
+            fp.x.requires_grad_(True)
+        fingerprint_optimizer = optim.Adam([fp.x for fp in
self.graph_fingerprints], lr=fingerprint_lr)
+        univerifier_optimizer = optim.Adam(self.univerifier.parameters(), lr=univerifier_lr)
+        # The Univerifier already applies a softmax, so train it with a
+        # negative log-likelihood on the log of its output probabilities
+        criterion = nn.NLLLoss()
+
+        # Prepare all models for training
+        all_models = [self.target_gnn] + self.positive_gnns + self.negative_gnns
+        labels = torch.cat([
+            torch.ones(len(self.positive_gnns) + 1),   # Target + positive models
+            torch.zeros(len(self.negative_gnns))       # Negative models
+        ]).long().to(self.device)
+
+        for epoch in range(epochs):
+            # Forward pass through all models
+            all_outputs = []
+            for model in all_models:
+                model_outputs = []
+                for fingerprint in self.graph_fingerprints:
+                    model.eval()
+                    # Keep the autograd graph so gradients can reach the
+                    # fingerprint features (a no_grad block would cut them off)
+                    output = model(fingerprint.x.to(self.device),
+                                   fingerprint.edge_index.to(self.device))
+                    model_outputs.append(output)
+
+                # Concatenate all fingerprint outputs
+                concatenated = torch.cat(model_outputs, dim=0).view(1, -1)
+                all_outputs.append(concatenated)
+
+            # Stack all outputs
+            all_outputs = torch.cat(all_outputs, dim=0)
+
+            # Univerifier prediction
+            univerifier_out = self.univerifier(all_outputs)
+
+            # Calculate loss
+            loss = criterion(torch.log(univerifier_out + 1e-10), labels)
+
+            # Backward pass
+            fingerprint_optimizer.zero_grad()
+            univerifier_optimizer.zero_grad()
+            loss.backward()
+
+            # Update fingerprints and Univerifier
+            fingerprint_optimizer.step()
+            univerifier_optimizer.step()
+
+            if epoch % 10 == 0:
+                print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
+
+    def _verify_ownership(self, suspect_model):
+        """Verify if a suspect model is pirated from the target model"""
+        # Get outputs for all fingerprints
+        target_outputs = []
+        suspect_outputs = []
+
+        for fingerprint in self.graph_fingerprints:
+            self.target_gnn.eval()
+            suspect_model.eval()
+
+            with torch.no_grad():
+                target_out = self.target_gnn(fingerprint.x.to(self.device),
+                                             fingerprint.edge_index.to(self.device))
+                suspect_out = suspect_model(fingerprint.x.to(self.device),
+                                            fingerprint.edge_index.to(self.device))
+
+            target_outputs.append(target_out)
+            suspect_outputs.append(suspect_out)
+
+        # Concatenate outputs
+        target_concat = torch.cat(target_outputs, dim=0).view(1, -1)
+        suspect_concat = torch.cat(suspect_outputs, dim=0).view(1, -1)
+
+        # Get Univerifier prediction
+        self.univerifier.eval()
+        with torch.no_grad():
+            prediction = self.univerifier(suspect_concat)
+            confidence = prediction[0, 1].item()  # Probability of being pirated
+
+        return confidence > self.lambda_threshold, confidence
+
+    def _evaluate_model(self, model, data):
+        """Evaluate model accuracy"""
+        model.eval()
+        with torch.no_grad():
+            out = model(data.x.to(self.device), data.edge_index.to(self.device))
+            pred = out.argmax(dim=1)
+            correct = (pred[data.test_mask] == data.y[data.test_mask]).sum().item()
+            total = data.test_mask.sum().item()
+            return correct / total if total > 0 else 0
+
+
+    # def _joint_learning_alternating(self, e1=5, e2=5, alpha=0.01, beta=0.001):
+    #     """
+    #     Implementation of Joint Learning Approach
+
+    #     Parameters:
+    #     e1: number of fingerprint update epochs
+    #     e2: number of Univerifier update epochs
+    #     alpha: learning rate for fingerprints
+    #     beta: learning rate for Univerifier
+    #     """
+    #     flag = 0  # 0: update fingerprints, 1: update Univerifier
+    #     convergence_threshold = 1e-4
+    #     prev_loss = float('inf')
+    #     convergence_count = 0
+
+    #     print("Starting alternating optimization...")
+
+    #     while convergence_count < 3:  # Converge if loss doesn't improve for 3 cycles
+    #         # Compute total loss L (lines 4-10)
+    #         L = 0
+    #
all_models = [self.target_gnn] + self.positive_gnns + self.negative_gnns + + # for model_idx, model in enumerate(all_models): + # # Get concatenated outputs from all fingerprints for this model + # fingerprint_outputs = [] + # for fingerprint in self.graph_fingerprints: + # model.eval() + # with torch.no_grad(): + # output = model(fingerprint.x.to(self.device), + # fingerprint.edge_index.to(self.device)) + # fingerprint_outputs.append(output) + + # # Concatenate all fingerprint outputs (flattened) + # concatenated = torch.cat([out.view(-1) for out in fingerprint_outputs]).unsqueeze(0) + + # # Get Univerifier prediction + # univerifier_out = self.univerifier(concatenated) + # o_plus = univerifier_out[0, 1] # Probability of being pirated + + # # Accumulate loss according to algorithm + # if model_idx < len(self.positive_gnns) + 1: # Target or positive model + # L += torch.log(o_plus + 1e-10) + # else: # Negative model + # L += torch.log(1 - o_plus + 1e-10) + + # current_loss = -L.item() # Negative since we're maximizing + + # # Check convergence + # if abs(prev_loss - current_loss) < convergence_threshold: + # convergence_count += 1 + # else: + # convergence_count = 0 + # prev_loss = current_loss + + # print(f"Cycle loss: {current_loss:.6f}, Flag: {flag}, Convergence count: {convergence_count}") + + # # Alternating optimization (lines 11-21) + # if flag == 0: + # # Update fingerprints for e1 epochs + # print(f"Updating fingerprints for {e1} epochs...") + # for e in range(e1): + # self._update_fingerprints_single_epoch(alpha) + # flag = 1 + # else: + # # Update Univerifier for e2 epochs + # print(f"Updating Univerifier for {e2} epochs...") + # for e in range(e2): + # self._update_univerifier_single_epoch(beta) + # flag = 0 + + # print("Alternating optimization converged!") + + + + # def _update_fingerprints_single_epoch(self, alpha=0.01, top_k_edges=10): + # """Update fingerprints for one epoch""" + # attribute_ranges = self._get_attribute_ranges() + + # for fingerprint in self.graph_fingerprints: + # # Make copies that require gradients + # x_tensor = fingerprint.x.clone().detach().requires_grad_(True) + # adj_dense = to_dense_adj(fingerprint.edge_index, + # max_num_nodes=self.fingerprint_nodes)[0] + # adj_tensor = adj_dense.clone().detach().requires_grad_(True) + + # # Compute loss for this fingerprint + # loss = self._compute_single_fingerprint_loss(x_tensor, adj_tensor) + + # # Compute gradients + # if x_tensor.grad is not None: + # x_tensor.grad.zero_() + # if adj_tensor.grad is not None: + # adj_tensor.grad.zero_() + + # loss.backward() + + # # Update node attributes with gradient and clipping + # if x_tensor.grad is not None: + # new_x = x_tensor + alpha * x_tensor.grad + # fingerprint.x = self._clip_attributes(new_x.detach(), attribute_ranges) + + # # Update adjacency matrix using paper's discrete method + # if adj_tensor.grad is not None: + # self._update_adjacency_discrete(fingerprint, adj_tensor, alpha, top_k_edges) + + # def _update_univerifier_single_epoch(self, beta=0.001): + # """Update Univerifier for one epoch""" + # optimizer = optim.Adam(self.univerifier.parameters(), lr=beta) + + # # Compute loss for all models + # L = 0 + # all_models = [self.target_gnn] + self.positive_gnns + self.negative_gnns + + # for model_idx, model in enumerate(all_models): + # # Get concatenated outputs from all fingerprints + # fingerprint_outputs = [] + # for fingerprint in self.graph_fingerprints: + # model.eval() + # with torch.no_grad(): + # output = model(fingerprint.x.to(self.device), + # 
fingerprint.edge_index.to(self.device)) + # fingerprint_outputs.append(output) + + # concatenated = torch.cat([out.view(-1) for out in fingerprint_outputs]).unsqueeze(0) + + # # Univerifier prediction + # univerifier_out = self.univerifier(concatenated) + # o_plus = univerifier_out[0, 1] + + # # Accumulate loss + # if model_idx < len(self.positive_gnns) + 1: # Target or positive + # L += torch.log(o_plus + 1e-10) + # else: # Negative + # L += torch.log(1 - o_plus + 1e-10) + + # # Optimization step + # optimizer.zero_grad() + # (-L).backward() # Minimize negative log likelihood + # optimizer.step() + + # def _update_adjacency_discrete(self, fingerprint, adj_tensor, alpha, top_k_edges): + # """Update adjacency matrix using paper's discrete optimization rules""" + # adj_grad = adj_tensor.grad + + # if adj_grad is None: + # return + + # # Get top-K edges with largest absolute gradient values + # flat_grad = adj_grad.view(-1) + # flat_abs_grad = torch.abs(flat_grad) + # top_values, top_indices = torch.topk(flat_abs_grad, min(top_k_edges, flat_abs_grad.numel())) + + # current_adj = adj_tensor.detach().clone() + # current_adj.requires_grad_(False) + + # for idx in top_indices: + # if top_values[idx] < 1e-8: # Skip very small gradients + # continue + + # # Convert flat index to (u, v) coordinates + # u = idx // self.fingerprint_nodes + # v = idx % self.fingerprint_nodes + + # if u >= self.fingerprint_nodes or v >= self.fingerprint_nodes: + # continue + + # grad_value = adj_grad[u, v].item() + # current_value = current_adj[u, v].item() + + # # Apply paper's rules: + # # 1. If edge exists and gradient <= 0: remove edge + # # 2. If edge doesn't exist and gradient >= 0: add edge + # if current_value > 0.5: # Edge exists + # if grad_value <= 0: + # current_adj[u, v] = 0 + # current_adj[v, u] = 0 # Undirected graph + # else: # Edge doesn't exist + # if grad_value >= 0: + # current_adj[u, v] = 1 + # current_adj[v, u] = 1 # Undirected graph + + # # Convert back to sparse and update fingerprint + # new_edge_index = dense_to_sparse(current_adj)[0] + # fingerprint.edge_index = new_edge_index + + # def _compute_single_fingerprint_loss(self, x_tensor, adj_tensor): + # """Compute loss contribution from a single fingerprint""" + # edge_index_sparse = dense_to_sparse(adj_tensor)[0] + # loss = 0 + + # # All models + # all_models = [self.target_gnn] + self.positive_gnns + self.negative_gnns + + # for model_idx, model in enumerate(all_models): + # model_out = model(x_tensor, edge_index_sparse) + # concat_out = model_out.view(1, -1) + # univerifier_out = self.univerifier(concat_out) + # o_plus = univerifier_out[0, 1] + + # if model_idx < len(self.positive_gnns) + 1: # Target or positive + # loss += torch.log(o_plus + 1e-10) + # else: # Negative + # loss += torch.log(1 - o_plus + 1e-10) + + # return loss + + +# GNN Model Definitions +class GCNConvGNN(nn.Module): + """GCN-based GNN model""" + def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3): + super(GCNConvGNN, self).__init__() + self.convs = nn.ModuleList() + + # Input layer + self.convs.append(GCNConv(in_channels, hidden_channels)) + + # Hidden layers + for _ in range(num_layers - 2): + self.convs.append(GCNConv(hidden_channels, hidden_channels)) + + # Output layer + self.convs.append(GCNConv(hidden_channels, out_channels)) + + def forward(self, x, edge_index): + for i, conv in enumerate(self.convs): + x = conv(x, edge_index) + if i < len(self.convs) - 1: + x = F.relu(x) + x = F.dropout(x, training=self.training, p=0.5) + return x + +class 
GATConvGNN(nn.Module): + + """GAT-based GNN model""" + def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, heads=4): + super(GATConvGNN, self).__init__() + self.convs = nn.ModuleList() + + # Input layer + self.convs.append(GATConv(in_channels, hidden_channels, heads=heads)) + + # Hidden layers + for _ in range(num_layers - 2): + self.convs.append(GATConv(hidden_channels * heads, hidden_channels, heads=heads)) + + # Output layer + self.convs.append(GATConv(hidden_channels * heads, out_channels, heads=1)) + + def forward(self, x, edge_index): + for i, conv in enumerate(self.convs): + x = conv(x, edge_index) + if i < len(self.convs) - 1: + x = F.elu(x) + x = F.dropout(x, training=self.training, p=0.6) + return x + + + + diff --git a/models/defense/GNN_Fingers_Examples/GNN_Fingers_Batching.py b/models/defense/GNN_Fingers_Examples/GNN_Fingers_Batching.py new file mode 100644 index 0000000..fd80aa8 --- /dev/null +++ b/models/defense/GNN_Fingers_Examples/GNN_Fingers_Batching.py @@ -0,0 +1,1095 @@ +import sys +sys.path.append('.') +import importlib +import numpy as np +from tqdm import tqdm +import copy +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from datasets import Cora, PubMed +#from models.nn import GraphSAGE +#from dgl.dataloading import NeighborSampler, NodeCollator +from torch_geometric.nn import GCNConv, GATConv +from torch_geometric.utils import erdos_renyi_graph, to_dense_adj, dense_to_sparse +from torch_geometric.loader import NeighborLoader +from torch_geometric.data import Data as PyGData +from models.defense.base import BaseDefense + + +class LearnableGraphFingerprint(nn.Module): + """ + A learnable graph fingerprint that converts PyG Data components to learnable parameters + """ + def __init__(self, num_nodes, feature_dim): + super(LearnableGraphFingerprint, self).__init__() + self.num_nodes = num_nodes + self.feature_dim = feature_dim + + # Initialize node features as learnable parameters + self.x = nn.Parameter(torch.randn(num_nodes, feature_dim)) + + # Initialize adjacency matrix as learnable parameters (dense representation) + self.adj_matrix = nn.Parameter(torch.zeros(num_nodes, num_nodes)) + + @classmethod + def from_pyg_data(cls, x, edge_index, num_nodes, feature_dim): + """Create learnable fingerprint from PyG Data components""" + fingerprint = cls(num_nodes, feature_dim) + + # Initialize node features + fingerprint.x.data = x.clone() + + # Initialize adjacency matrix from edge_index + with torch.no_grad(): + # Convert sparse edge_index to dense adjacency matrix + dense_adj = to_dense_adj(edge_index, max_num_nodes=num_nodes)[0] + fingerprint.adj_matrix.data = dense_adj + return fingerprint + + def forward(self, return_pyg_data=True): + """Return the graph structure using straight-through estimator""" + # Get discrete adjacency matrix (0.0 or 1.0) using straight-through estimator + adj_binary = (self.adj_matrix > 0.5).float() + adj_binary_st = adj_binary + (self.adj_matrix - self.adj_matrix.detach()) + + # Convert dense adjacency to sparse edge_index + edge_index, edge_attr = dense_to_sparse(adj_binary_st) + + if return_pyg_data: + # Return as PyG Data object + return PyGData(x=self.x, edge_index=edge_index, edge_attr=edge_attr) + else: + # Return raw components + return self.x, edge_index, adj_binary + + def get_discrete_adjacency(self): + """Get the actual discrete adjacency matrix (for verification)""" + with torch.no_grad(): + return (self.adj_matrix > 0.5).float() + + def 
to_pyg_data(self): + """Convert to PyG Data object (without gradient tracking)""" + with torch.no_grad(): + adj_binary = (self.adj_matrix > 0.5).float() + edge_index, edge_attr = dense_to_sparse(adj_binary) + return PyGData(x=self.x.detach(), edge_index=edge_index, edge_attr=edge_attr) + + def get_original_components(self): + """Get the original PyG Data components (for debugging)""" + with torch.no_grad(): + adj_binary = (self.adj_matrix > 0.5).float() + edge_index, edge_attr = dense_to_sparse(adj_binary) + return self.x.detach(), edge_index + + +class Univerifier(nn.Module): + """ + Unified Verification Mechanism - Binary classifier that takes concatenated outputs + from suspect models and predicts whether they are pirated or irrelevant. + """ + def __init__(self, input_dim, hidden_dims=[128, 64,32,16,8,4]): + super(Univerifier, self).__init__() + layers = [] + prev_dim = input_dim + + for hidden_dim in hidden_dims: + layers.append(nn.Linear(prev_dim, hidden_dim)) + layers.append(nn.LeakyReLU()) + layers.append(nn.Dropout(0.1)) + prev_dim = hidden_dim + + layers.append(nn.Linear(prev_dim, 2)) + self.classifier = nn.Sequential(*layers) + + def forward(self, x): + return F.softmax(self.classifier(x), dim=1) + +class GNNFingers(BaseDefense): + """ + GNNFingers: A Fingerprinting Framework for Verifying Ownerships of Graph Neural Networks + Implementation based on the paper by You et al. (2024) + """ + supported_api_types = {"dgl"} + + def __init__(self, dataset, attack_node_fraction=0.2, device=None, attack_name=None, + num_fingerprints=64, fingerprint_nodes=32, lambda_threshold=0.7, + fingerprint_update_epochs=5, univerifier_update_epochs=3, + fingerprint_lr=0.01, univerifier_lr=0.001, top_k_ratio=0.1, + epochs=100, batch_size=32, num_neighbors=[10, 5]): + """ + Initialize GNNFingers defense framework + + Parameters + ---------- + dataset : Dataset + The original dataset containing the graph to defend + attack_node_fraction : float + Fraction of nodes to consider for attack + device : torch.device + Device to run computations on + attack_name : str + Name of the attack class to use + num_fingerprints : int + Number of graph fingerprints to generate + fingerprint_nodes : int + Number of nodes in each fingerprint graph + lambda_threshold : float + Threshold for Univerifier classification + fingerprint_update_epochs: int + Number of Epochs to update Fingerprint + univerifier_update_epochs: int + Number of Epochs to update Univerifier + fingerprint_lr: float + Learning rate for fingerprint update + univerifier_lr: float + Learning rate for Univerifier update + top_k_ratio: float + top k gradients of fingerprint adjacency matrix to select + epochs: int + total number of epochs to run experiment + batch_size : int + Batch size for training + num_neighbors : list + Number of neighbors to sample at each layer + """ + super().__init__(dataset, attack_node_fraction, device) + self.attack_name = attack_name or "ModelExtractionAttack0" + self.dataset = dataset + self.graph = dataset.graph_data + + # Extract dataset properties + self.node_number = dataset.num_nodes + self.feature_number = dataset.num_features + self.label_number = dataset.num_classes + self.attack_node_number = int(self.node_number * attack_node_fraction) + + # Training parameters + self.batch_size = batch_size + self.num_neighbors = num_neighbors + + # Convert DGL to PyG data + self.pyg_data = self._dgl_to_pyg(self.graph) + + # Extract features and labels + self.features = self.pyg_data.x + self.labels = self.pyg_data.y + + # Extract 
masks + self.train_mask = self.pyg_data.train_mask + self.test_mask = self.pyg_data.test_mask + + # GNNFingers parameters + self.num_fingerprints = num_fingerprints + self.fingerprint_nodes = fingerprint_nodes + self.lambda_threshold = lambda_threshold + + # Initialize components + self.target_gnn = None + self.positive_gnns = [] # Pirated GNNs + self.negative_gnns = [] # Irrelevant GNNs + self.graph_fingerprints = None + self.univerifier = None + + self.fingerprint_lr = fingerprint_lr + self.fingerprint_update_epochs = fingerprint_update_epochs + self.univerifier_update_epochs = univerifier_update_epochs + self.univerifier_lr = univerifier_lr + self.top_k_ratio = top_k_ratio + self.epochs = epochs + + # Move tensors to device + if self.device != 'cpu': + self.graph = self.graph.to(self.device) + self.features = self.features.to(self.device) + self.labels = self.labels.to(self.device) + self.train_mask = self.train_mask.to(self.device) + self.test_mask = self.test_mask.to(self.device) + + def _dgl_to_pyg(self, dgl_graph): + """Convert DGL graph to PyTorch Geometric Data object""" + # Extract edge indices + edge_index = torch.stack(dgl_graph.edges()) + x = dgl_graph.ndata.get('feat') + y = dgl_graph.ndata.get('label') + + train_mask = dgl_graph.ndata.get('train_mask') + val_mask = dgl_graph.ndata.get('val_mask') + test_mask = dgl_graph.ndata.get('test_mask') + + data = PyGData(x=x, edge_index=edge_index, y=y, + train_mask=train_mask, val_mask=val_mask, test_mask=test_mask) + + return data + + def _create_dataloaders(self, graph_data, batch_size=None, num_neighbors=None): + """ + Create train and test dataloaders for PyG data + + Parameters + ---------- + graph_data : PyG Data + The graph data to create loaders for + batch_size : int, optional + Batch size (defaults to self.batch_size) + num_neighbors : list, optional + Number of neighbors to sample (defaults to self.num_neighbors) + + Returns + ------- + train_loader : NeighborLoader + Training dataloader + test_loader : NeighborLoader + Test dataloader + """ + batch_size = batch_size or self.batch_size + num_neighbors = num_neighbors or self.num_neighbors + + train_loader = NeighborLoader( + graph_data, + num_neighbors=num_neighbors, + batch_size=batch_size, + shuffle=True, + input_nodes=graph_data.train_mask, + ) + + test_loader = NeighborLoader( + graph_data, + num_neighbors=num_neighbors, + batch_size=batch_size, + shuffle=False, + input_nodes=graph_data.test_mask, + ) + + return train_loader, test_loader + + def defend(self, attack_name=None): + """ + Main defense workflow for GNNFingers + """ + attack_name = attack_name or self.attack_name + AttackClass = self._get_attack_class(attack_name) + print(f"Using attack method: {attack_name}") + + # Step 1: Train target model + print("Training target GNN...") + self.target_gnn = self._train_gnn_model(self.pyg_data, "Target GNN") + + # Step 2: Prepare positive and negative GNNs + print("Preparing positive (pirated) GNNs...") + self.positive_gnns = self._prepare_positive_gnns(self.target_gnn, num_models=50) + + print("Preparing negative (irrelevant) GNNs...") + self.negative_gnns = self._prepare_negative_gnns(num_models=50) + + # Step 3: Initialize graph fingerprints + print("Initializing graph fingerprints...") + self.graph_fingerprints = self._initialize_graph_fingerprints() + + # Step 4: Initialize Univerifier + output_dim = self._get_output_dimension(self.target_gnn, self.graph_fingerprints[0]) + self.univerifier = Univerifier(input_dim=output_dim * self.num_fingerprints) + 
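+        # Dimensionality note (assuming the flattened-output convention of
+        # _get_output_dimension): each fingerprint contributes
+        # fingerprint_nodes * num_classes values, so with the defaults
+        # (64 fingerprints of 32 nodes) on a 7-class dataset such as Cora the
+        # verifier input width would be 64 * 32 * 7 = 14336.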
self.univerifier = self.univerifier.to(self.device) + + # Step 5: Joint learning of fingerprints and Univerifier + print("Joint learning of fingerprints and Univerifier...") + self._joint_learning_alternating() + + # Step 6: Attack target model + print("Attacking target model...") + attack = AttackClass(self.dataset, attack_node_fraction=0.2) + attack_results = attack.attack() + suspect_model = attack.net2 if hasattr(attack, 'net2') else None + + # Step 7: Verify ownership + if suspect_model is not None: + print("Verifying ownership of suspect model...") + verification_result = self._verify_ownership(suspect_model) + print(f"Ownership verification result: {verification_result}") + + return { + "attack_results": attack_results, + "verification_result": verification_result, + "target_accuracy": self._evaluate_model(self.target_gnn, self.pyg_data), + "suspect_accuracy": self._evaluate_model(suspect_model, self.pyg_data) + } + + return {"attack_results": attack_results, "verification_result": "No suspect model found"} + + def _train_gnn_model(self, data, model_name="GNN", epochs=100): + """Train a GNN model on the given data using batched training""" + model = GCNConvGNN( + in_channels=data.x.size(1), + hidden_channels=128, + out_channels=self.label_number, + num_layers=3 + ).to(self.device) + + optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) + criterion = nn.CrossEntropyLoss() + + # Create dataloaders using the helper function + train_loader, test_loader = self._create_dataloaders(data) + + best_acc = 0 + best_model = None + + for epoch in range(epochs): + model.train() + total_loss = 0 + + for batch in train_loader: + batch = batch.to(self.device) + optimizer.zero_grad() + + # Forward pass + out = model(batch.x, batch.edge_index) + loss = criterion(out[batch.train_mask], batch.y[batch.train_mask]) + + # Backward pass + loss.backward() + optimizer.step() + + total_loss += loss.item() + + # Evaluate on test set + if epoch % 10 == 0: + test_acc = self._evaluate_model_with_loader(model, test_loader) + + if test_acc > best_acc: + best_acc = test_acc + best_model = copy.deepcopy(model) + + print(f"{model_name} Epoch {epoch}: Loss={total_loss/len(train_loader):.4f}, Acc={test_acc:.4f}") + + print(f"{model_name} trained with best accuracy: {best_acc:.4f}") + return best_model + + def _evaluate_model_with_loader(self, model, test_loader): + """Evaluate model accuracy using a test dataloader""" + model.eval() + correct = 0 + total = 0 + + with torch.no_grad(): + for batch in test_loader: + batch = batch.to(self.device) + out = model(batch.x, batch.edge_index) + pred = out.argmax(dim=1) + + # Only count test nodes + test_mask = batch.test_mask if hasattr(batch, 'test_mask') else torch.ones(batch.num_nodes, dtype=bool) + correct += (pred[test_mask] == batch.y[test_mask]).sum().item() + total += test_mask.sum().item() + + return correct / total if total > 0 else 0 + + def _prepare_positive_gnns(self, target_model, num_models=50): + """Prepare pirated GNNs using obfuscation techniques""" + positive_models = [] + + for i in range(num_models): + # Apply different obfuscation techniques + if i % 3 == 0: + # Fine-tuning with batched training + layers_to_finetune = random.randint(1, 3) + model = self._fine_tune_model(copy.deepcopy(target_model), self.pyg_data, + epochs=10, num_layers_to_finetune=layers_to_finetune) + elif i % 3 == 1: + # Partial retraining with batched training + layers_to_retrain = random.randint(1, 3) + model = self._partial_retrain_model(copy.deepcopy(target_model), 
self.pyg_data, + epochs=15, num_layers_to_retrain=layers_to_retrain) + else: + # Distillation with batched training + temperature = random.choice([1.5, 2.0, 3.0, 4.0]) + model = self._distill_model(target_model, self.pyg_data, + epochs=30, temperature=temperature) + + positive_models.append(model) + + return positive_models + + def _prepare_negative_gnns(self, num_models=50): + """Prepare irrelevant GNNs""" + negative_models = [] + + for i in range(num_models): + # Train from scratch with different architectures or data + if i % 2 == 0: + # Different architecture + model = self._train_different_architecture(self.pyg_data) + else: + # Different training data (subset) + model = self._train_on_subset(self.pyg_data) + + negative_models.append(model) + + return negative_models + + def _fine_tune_model(self, model, data, epochs=10, num_layers_to_finetune=1): + """Fine-tune a model using batched training""" + # Freeze all layers initially + for param in model.parameters(): + param.requires_grad = False + + # Unfreeze the last K layers for fine-tuning + if hasattr(model, 'convs'): + total_layers = len(model.convs) + layers_to_finetune = min(num_layers_to_finetune, total_layers) + + for i in range(total_layers - layers_to_finetune, total_layers): + for param in model.convs[i].parameters(): + param.requires_grad = True + + # Create dataloader using helper function + train_loader, _ = self._create_dataloaders(data) + + # Only optimize parameters that require gradients + trainable_params = filter(lambda p: p.requires_grad, model.parameters()) + optimizer = optim.Adam(trainable_params, lr=0.001) + criterion = nn.CrossEntropyLoss() + + for epoch in range(epochs): + model.train() + total_loss = 0 + + for batch in train_loader: + batch = batch.to(self.device) + optimizer.zero_grad() + + out = model(batch.x, batch.edge_index) + loss = criterion(out[batch.train_mask], batch.y[batch.train_mask]) + + loss.backward() + optimizer.step() + + total_loss += loss.item() + + # Unfreeze all parameters for future use + for param in model.parameters(): + param.requires_grad = True + + return model + + def _partial_retrain_model(self, model, data, epochs=10, num_layers_to_retrain=2): + """Partially retrain a model with random initialization of K layers before resuming training""" + # Randomly initialize selected K layers + if hasattr(model, 'convs'): + # For models with convs attribute (like GCNConvGNN, GATConvGNN) + total_layers = len(model.convs) + layers_to_retrain = min(num_layers_to_retrain, total_layers) + + # Randomly select K layers to retrain + layer_indices = random.sample(range(total_layers), layers_to_retrain) + + print(f"Partially retraining layers: {layer_indices}") + + for idx in layer_indices: + model.convs[idx].reset_parameters() # Random reinitialization + + # Train the entire model (both retrained and original layers) + train_loader, test_loader = self._create_dataloaders(data) + optimizer = optim.Adam(model.parameters(), lr=0.01) + criterion = nn.CrossEntropyLoss() + + best_acc = 0 + best_model = None + + for epoch in range(epochs): + model.train() + total_loss = 0 + + for batch in train_loader: + batch = batch.to(self.device) + optimizer.zero_grad() + + out = model(batch.x, batch.edge_index) + loss = criterion(out[batch.train_mask], batch.y[batch.train_mask]) + + loss.backward() + optimizer.step() + + total_loss += loss.item() + + # Track best model + if epoch % 5 == 0: + acc = self._evaluate_model_with_loader(model, test_loader) + if acc > best_acc: + best_acc = acc + best_model = 
copy.deepcopy(model) + + print(f"Partial retraining completed. Best accuracy: {best_acc:.4f}") + return best_model if best_model is not None else model + + def _distill_model(self, teacher_model, data, epochs=30, temperature=2.0): + """Distill knowledge using batched training""" + # Create student model with different architecture + if isinstance(teacher_model, GCNConvGNN): + # If teacher is GCN, use GAT as student + student_model = GATConvGNN( + in_channels=data.x.size(1), + hidden_channels=96, # Different hidden size + out_channels=self.label_number, + num_layers=2, # Different number of layers + heads=3 # Different number of heads + ).to(self.device) + else: + # If teacher is GAT or other, use GCN as student + student_model = GCNConvGNN( + in_channels=data.x.size(1), + hidden_channels=64, # Different hidden size + out_channels=self.label_number, + num_layers=3 # Different number of layers + ).to(self.device) + + # Create dataloader using helper function + train_loader, test_loader = self._create_dataloaders(data) + + optimizer = optim.Adam(student_model.parameters(), lr=0.01, weight_decay=1e-4) + + # Combined loss: KL divergence for distillation + cross entropy for ground truth + kl_loss = nn.KLDivLoss(reduction='batchmean') + ce_loss = nn.CrossEntropyLoss() + + teacher_model.eval() + + best_acc = 0 + best_student = None + + for epoch in range(epochs): + student_model.train() + total_loss = 0 + + for batch in train_loader: + batch = batch.to(self.device) + optimizer.zero_grad() + + # Get teacher predictions (with temperature scaling) + with torch.no_grad(): + teacher_logits = teacher_model(batch.x, batch.edge_index) + teacher_probs = F.softmax(teacher_logits / temperature, dim=1) + + # Get student predictions + student_logits = student_model(batch.x, batch.edge_index) + student_log_probs = F.log_softmax(student_logits / temperature, dim=1) + + # Distillation loss (KL divergence between teacher and student) + distill_loss = kl_loss(student_log_probs[batch.train_mask], + teacher_probs[batch.train_mask]) * (temperature ** 2) + + # Student's own classification loss + class_loss = ce_loss(student_logits[batch.train_mask], + batch.y[batch.train_mask]) + + # Combined loss (weighted sum) + loss = 0.7 * distill_loss + 0.3 * class_loss + + loss.backward() + optimizer.step() + + total_loss += loss.item() + + # Track best student model + if epoch % 5 == 0: + student_model.eval() + test_acc = self._evaluate_model_with_loader(student_model, test_loader) + + if test_acc > best_acc: + best_acc = test_acc + best_student = copy.deepcopy(student_model) + + print(f"Distillation Epoch {epoch}: Loss={total_loss/len(train_loader):.4f}, Acc={test_acc:.4f}") + + print(f"Distillation completed. 
Best student accuracy: {best_acc:.4f}") + return best_student if best_student is not None else student_model + + def _train_different_architecture(self, data): + """Train a model with different architecture for negative GNNs""" + # Use opposite architecture of target model + if isinstance(self.target_gnn, GCNConvGNN): + model = GATConvGNN( + in_channels=data.x.size(1), + hidden_channels=64, + out_channels=self.label_number, + num_layers=2, + heads=4 + ).to(self.device) + else: + model = GCNConvGNN( + in_channels=data.x.size(1), + hidden_channels=64, + out_channels=self.label_number, + num_layers=3 + ).to(self.device) + + return self._train_gnn_model_with_data(model, data, epochs=50) + + def _train_on_subset(self, data, subset_ratio=0.7): + """Train on a subset of the data""" + # Create subset mask + num_train = int(data.train_mask.sum().item() * subset_ratio) + subset_mask = torch.zeros_like(data.train_mask) + train_indices = data.train_mask.nonzero(as_tuple=True)[0] + selected_indices = random.sample(range(len(train_indices)), min(num_train, len(train_indices))) + subset_mask[train_indices[selected_indices]] = True + + # Create subset data + subset_data = PyGData( + x=data.x, + edge_index=data.edge_index, + y=data.y, + train_mask=subset_mask, + test_mask=data.test_mask + ) + + # Use opposite architecture of target model + if isinstance(self.target_gnn, GCNConvGNN): + model = GCNConvGNN( + in_channels=data.x.size(1), + hidden_channels=64, + out_channels=self.label_number, + num_layers=2 + ).to(self.device) + else: + model = GATConvGNN( + in_channels=data.x.size(1), + hidden_channels=64, + out_channels=self.label_number, + num_layers=3, + heads=4 + ).to(self.device) + + return self._train_gnn_model_with_data(model, subset_data, epochs=50) + + def _train_gnn_model_with_data(self, model, data, epochs=100): + """Train a specific model on specific data using batched training""" + # Create dataloaders using helper function + train_loader, test_loader = self._create_dataloaders(data) + + optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) + criterion = nn.CrossEntropyLoss() + + best_acc = 0 + best_model = None + + for epoch in range(epochs): + model.train() + total_loss = 0 + + for batch in train_loader: + batch = batch.to(self.device) + optimizer.zero_grad() + + out = model(batch.x, batch.edge_index) + loss = criterion(out[batch.train_mask], batch.y[batch.train_mask]) + + loss.backward() + optimizer.step() + + total_loss += loss.item() + + if epoch % 10 == 0: + acc = self._evaluate_model_with_loader(model, test_loader) + if acc > best_acc: + best_acc = acc + best_model = copy.deepcopy(model) + + return best_model + + def _initialize_graph_fingerprints(self): + """Initialize graph fingerprints as learnable parameters from PyG Data""" + fingerprints = nn.ModuleList() # Use ModuleList to properly register parameters + feature_dim = self.features.size(1) if self.features is not None else 16 + + for _ in range(self.num_fingerprints): + # Initialize random graph using Erdos-Renyi model + edge_index = erdos_renyi_graph(self.fingerprint_nodes, 0.3) + + # Initialize node features + if self.features is not None: + x = torch.randn(self.fingerprint_nodes, self.features.size(1)) + else: + x = torch.randn(self.fingerprint_nodes, 16) + + # Convert to learnable fingerprint + fingerprint = LearnableGraphFingerprint.from_pyg_data( + x, edge_index, self.fingerprint_nodes, feature_dim + ).to(self.device) + fingerprints.append(fingerprint) + + return fingerprints + + def 
_get_output_dimension(self, model, fingerprint_data):
+        """Get the flattened output size of a model for a given fingerprint"""
+        # Accept either a PyG Data object or a LearnableGraphFingerprint
+        # (defend() passes the latter, which has no edge_index attribute)
+        if hasattr(fingerprint_data, 'to_pyg_data'):
+            fingerprint_data = fingerprint_data.to_pyg_data()
+        model.eval()
+        with torch.no_grad():
+            output = model(fingerprint_data.x.to(self.device),
+                           fingerprint_data.edge_index.to(self.device))
+        # The Univerifier consumes the flattened per-fingerprint output
+        # (fingerprint_nodes * num_classes values), not just the class count
+        return output.view(-1).size(0)
+
+    def _verify_ownership(self, suspect_model):
+        """Verify if a suspect model is pirated from the target model"""
+        # Get outputs for all fingerprints
+        target_outputs = []
+        suspect_outputs = []
+
+        for fingerprint in self.graph_fingerprints:
+            self.target_gnn.eval()
+            suspect_model.eval()
+            # Get fingerprint as PyG Data object (without gradient tracking)
+            fingerprint = fingerprint.to_pyg_data()
+
+            with torch.no_grad():
+                target_out = self.target_gnn(fingerprint.x.to(self.device),
+                                             fingerprint.edge_index.to(self.device))
+                suspect_out = suspect_model(fingerprint.x.to(self.device),
+                                            fingerprint.edge_index.to(self.device))
+
+            target_outputs.append(target_out)
+            suspect_outputs.append(suspect_out)
+
+        # Concatenate outputs
+        target_concat = torch.cat(target_outputs, dim=0).view(1, -1)
+        suspect_concat = torch.cat(suspect_outputs, dim=0).view(1, -1)
+
+        # Get Univerifier prediction
+        self.univerifier.eval()
+        with torch.no_grad():
+            prediction = self.univerifier(suspect_concat)
+            confidence = prediction[0, 1].item()  # Probability of being pirated
+
+        return confidence > self.lambda_threshold, confidence
+
+    def _evaluate_model(self, model, data):
+        """Evaluate model accuracy"""
+        model.eval()
+        with torch.no_grad():
+            out = model(data.x.to(self.device), data.edge_index.to(self.device))
+            pred = out.argmax(dim=1)
+            correct = (pred[data.test_mask] == data.y[data.test_mask]).sum().item()
+            total = data.test_mask.sum().item()
+            return correct / total if total > 0 else 0
+
+    def _update_adjacency_discrete(self, fingerprint, grad_adj):
+        """
+        Update the discrete adjacency matrix based on gradients.
+
+        `grad_adj` is expected to be the gradient of the objective being
+        maximised: a non-negative gradient on a missing edge adds it, and a
+        non-positive gradient on an existing edge removes it.
+        """
+        # Get current discrete adjacency
+        current_adj = fingerprint.get_discrete_adjacency()
+
+        # Get absolute gradient values and flatten
+        grad_abs = torch.abs(grad_adj)
+        grad_abs_flat = grad_abs.view(-1)
+
+        # Determine top-K edges to consider for flipping
+        k = int(self.top_k_ratio * self.fingerprint_nodes * self.fingerprint_nodes)
+        topk_values, topk_indices = torch.topk(grad_abs_flat, k)
+
+        # Convert flat indices to row, col indices
+        rows = topk_indices // self.fingerprint_nodes
+        cols = topk_indices % self.fingerprint_nodes
+
+        # Update edges based on gradient signs
+        with torch.no_grad():
+            for idx in range(k):
+                row, col = rows[idx], cols[idx]
+                grad_val = grad_adj[row, col]
+
+                # Current edge existence (0 or 1)
+                current_edge = current_adj[row, col]
+
+                # Apply update rules:
+                if current_edge > 0.5 and grad_val <= 0:
+                    # Edge exists and gradient is negative → remove edge
+                    fingerprint.adj_matrix.data[row, col] = 0.0
+                elif current_edge < 0.5 and grad_val >= 0:
+                    # Edge doesn't exist and gradient is positive → add edge
+                    fingerprint.adj_matrix.data[row, col] = 1.0
+
+    def _update_fingerprints_discrete(self, loss, top_k_ratio=0.1):
+        """
+        Update graph fingerprints using gradients
+        """
+        # Compute gradients for all fingerprints
+        gradients_adj = []
+        gradients_x = []
+
+        for fingerprint in self.graph_fingerprints:
+            # Compute gradients for the adjacency matrix; allow_unused=True
+            # because the models consume only edge_index, so the adjacency may
+            # not appear in the autograd graph (the gradient is then None)
+            grad_adj = torch.autograd.grad(
+                loss, fingerprint.adj_matrix,
+                retain_graph=True, create_graph=False, allow_unused=True
+            )[0]
+
+            # Compute gradients for node features
+            grad_x = torch.autograd.grad(
+                loss, fingerprint.x,
+                retain_graph=True, create_graph=False
+            )[0]
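+            # Sign convention (as assumed throughout this file): `loss` is the
+            # negative log-likelihood built in _joint_learning_alternating, so
+            # these gradients point uphill on the loss; the update below
+            # subtracts them, and flips the sign before the discrete adjacency
+            # rules, in order to ascend the verification objective.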
+
+            gradients_adj.append(grad_adj)
+            gradients_x.append(grad_x)
+
+        # Update each fingerprint
+        for i, fingerprint in enumerate(self.graph_fingerprints):
+            grad_adj = gradients_adj[i]
+            grad_x = gradients_x[i]
+
+            # Update node features with clipping. The joint loss is a negative
+            # log-likelihood to be minimized, so features move against the
+            # gradient (descent). The learning rate falls back to the
+            # Joint_Learning example's default if this constructor did not set it.
+            lr = getattr(self, "fingerprint_lr", 0.01)
+            with torch.no_grad():
+                fingerprint.x.data -= lr * grad_x
+
+                # Clip node features to a reasonable range
+                if self.features is not None:
+                    min_val = self.features.min().item()
+                    max_val = self.features.max().item()
+                    fingerprint.x.data = torch.clamp(fingerprint.x.data, min_val, max_val)
+                else:
+                    fingerprint.x.data = torch.clamp(fingerprint.x.data, -3, 3)
+
+            # Update the adjacency matrix using the discrete strategy
+            self._update_adjacency_discrete(fingerprint, grad_adj, top_k_ratio)
+
+    def visualize_fingerprint_evolution(self, epoch):
+        """Visualize how fingerprints evolve during training"""
+        if epoch % 20 == 0:  # Visualize every 20 epochs
+            print(f"\n=== Fingerprint Evolution at Epoch {epoch} ===")
+
+            for i, fingerprint in enumerate(self.graph_fingerprints[:2]):  # First 2 only
+                x, edge_index = fingerprint.get_original_components()
+                current_adj = fingerprint.get_discrete_adjacency()
+
+                # Edge statistics
+                num_edges = current_adj.sum().item()
+                sparsity = 1 - (num_edges / (self.fingerprint_nodes * self.fingerprint_nodes))
+                print(f"Fingerprint {i}: {num_edges} edges, sparsity: {sparsity:.3f}")
+
+                # Feature statistics
+                feature_mean = x.mean().item()
+                feature_std = x.std().item()
+                print(f"  Features: mean={feature_mean:.3f}, std={feature_std:.3f}")
+
+    def _joint_learning_alternating(self):
+        """
+        Joint learning of fingerprints and Univerifier with alternating optimization
+        """
+        # Prepare all models and labels
+        all_models = [self.target_gnn] + self.positive_gnns + self.negative_gnns
+        num_positive = len(self.positive_gnns) + 1  # target + pirated models
+        labels = torch.cat([
+            torch.ones(num_positive),             # Target + positive models
+            torch.zeros(len(self.negative_gnns))  # Negative models
+        ]).long().to(self.device)
+
+        # Hyperparameters; fall back to the Joint_Learning example's defaults
+        # if this constructor did not set them.
+        epochs = getattr(self, "epochs", 100)
+        e1 = getattr(self, "fingerprint_update_epochs", 5)
+        e2 = getattr(self, "univerifier_update_epochs", 3)
+        top_k_ratio = getattr(self, "top_k_ratio", 0.1)
+
+        # Create the optimizer once so Adam's moment estimates persist across phases
+        univerifier_optimizer = optim.Adam(self.univerifier.parameters(),
+                                           lr=getattr(self, "univerifier_lr", 0.001))
+
+        # Flag to alternate between fingerprint and Univerifier updates
+        update_fingerprints = True
+
+        for epoch in range(epochs):
+            # Forward pass through all models using the actual discrete structure.
+            # Gradients must be kept here: the loss has to stay connected to the
+            # fingerprint parameters (via the straight-through estimator) and to
+            # the Univerifier, otherwise the updates below have nothing to use.
+            all_outputs = []
+            for model in all_models:
+                model.eval()
+                model_outputs = []
+                for fingerprint in self.graph_fingerprints:
+                    # Get the fingerprint as PyG Data (uses the straight-through estimator)
+                    fingerprint_data = fingerprint(return_pyg_data=True)
+                    output = model(fingerprint_data.x, fingerprint_data.edge_index)
+                    model_outputs.append(output)
+
+                # Concatenate all fingerprint outputs into one row
+                concatenated = torch.cat(model_outputs, dim=0).view(1, -1)
+                all_outputs.append(concatenated)
+
+            # Stack all outputs
+            all_outputs = torch.cat(all_outputs, dim=0)
+
+            # Univerifier prediction
+            univerifier_out = self.univerifier(all_outputs)
+
+            # Joint negative log-likelihood: log o_+(F) and log o_+(F_+) terms for
+            # the target and pirated models, log o_-(F_-) for irrelevant models
+            loss = -(
+                torch.log(univerifier_out[:num_positive, 1] + 1e-10).sum()
+                + torch.log(1 - univerifier_out[num_positive:, 1] + 1e-10).sum()
+            )
+
+            # Alternating optimization
+            if update_fingerprints:
+                # Phase 1: update fingerprints for e1 steps. The gradients are
+                # taken on the same retained graph, so the e1 steps reuse one
+                # linearization of the loss (a simplification).
+                for e in range(e1):
+                    self._update_fingerprints_discrete(loss, top_k_ratio)
+
+                update_fingerprints = False
+                print(f"Epoch {epoch}: Updated fingerprints, Loss: {loss.item():.4f}")
+
+            else:
+                # Phase 2: update the Univerifier for e2 steps. The loss is
+                # recomputed each step: after an in-place optimizer step the old
+                # graph is stale and a second backward through it would fail.
+                detached_outputs = all_outputs.detach()
+                for e in range(e2):
+                    univerifier_optimizer.zero_grad()
+                    step_out = self.univerifier(detached_outputs)
+                    step_loss = -(
+                        torch.log(step_out[:num_positive, 1] + 1e-10).sum()
+                        + torch.log(1 - step_out[num_positive:, 1] + 1e-10).sum()
+                    )
+                    step_loss.backward()
+                    univerifier_optimizer.step()
+
+                update_fingerprints = True
+                print(f"Epoch {epoch}: Updated Univerifier, Loss: {loss.item():.4f}")
+
+            # Report accuracy every 10 epochs
+            if epoch % 10 == 0:
+                with torch.no_grad():
+                    preds = univerifier_out.argmax(dim=1)
+                    acc = (preds == labels).float().mean().item()
+
+                    # True positive / true negative rates, normalized by the
+                    # number of positive and negative models respectively
+                    pos_total = (labels == 1).sum().item()
+                    neg_total = (labels == 0).sum().item()
+                    tp_rate = ((preds == 1) & (labels == 1)).sum().item() / pos_total if pos_total > 0 else 0
+                    tn_rate = ((preds == 0) & (labels == 0)).sum().item() / neg_total if neg_total > 0 else 0
+
+                    print(f"Epoch {epoch}, Acc: {acc:.4f}, TP: {tp_rate:.4f}, TN: {tn_rate:.4f}")
+
+            # Visualize fingerprint evolution
+            if epoch % 20 == 0:
+                self.visualize_fingerprint_evolution(epoch)
+
+    def _verify_ownership_detailed(self, suspect_model):
+        """Detailed verification, for debugging purposes only"""
+        suspect_outputs = []
+        target_outputs = []
+
+        suspect_model.eval()
+        self.target_gnn.eval()
+
+        for fingerprint in self.graph_fingerprints:
+            fingerprint_data = fingerprint.to_pyg_data()
+
+            with torch.no_grad():
+                suspect_out = suspect_model(
+                    fingerprint_data.x.to(self.device),
+                    fingerprint_data.edge_index.to(self.device)
+                )
+                suspect_outputs.append(suspect_out)
+
+                target_out = self.target_gnn(
+                    fingerprint_data.x.to(self.device),
+                    fingerprint_data.edge_index.to(self.device)
+                )
+                target_outputs.append(target_out)
+
+        suspect_concat = torch.cat(suspect_outputs, dim=0).view(1, -1)
+        target_concat = torch.cat(target_outputs, dim=0).view(1, -1)
+
+        self.univerifier.eval()
+        with torch.no_grad():
+            suspect_prediction = self.univerifier(suspect_concat)
+            suspect_confidence = suspect_prediction[0, 1].item()
+
+            target_prediction = self.univerifier(target_concat)
+            target_confidence = target_prediction[0, 1].item()
+
+        output_similarity = F.cosine_similarity(suspect_concat, target_concat).item()
+        is_pirated = suspect_confidence > self.lambda_threshold
+
+        # Return detailed info for debugging; _verify_ownership keeps the original interface
+        return {
+            'is_pirated': is_pirated,
+            'confidence': suspect_confidence,
+            'target_confidence': target_confidence,
+            'output_similarity': output_similarity,
+            'lambda_threshold': self.lambda_threshold
+        }
+
+
+# GNN Model Definitions
+class GCNConvGNN(nn.Module):
+    """GCN-based GNN model"""
+    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3):
+        super(GCNConvGNN, self).__init__()
+        self.convs = nn.ModuleList()
+
+        # Input layer
+        self.convs.append(GCNConv(in_channels, hidden_channels))
+
+        # Hidden layers
+        for _ in range(num_layers - 2):
+            self.convs.append(GCNConv(hidden_channels, hidden_channels))
+
+        # Output layer
+        self.convs.append(GCNConv(hidden_channels, out_channels))
+
+    def forward(self, x, edge_index):
+        for i, conv in enumerate(self.convs):
+            x = conv(x, edge_index)
+            if i < len(self.convs) - 1:
+                x = F.relu(x)
+                x = F.dropout(x, training=self.training, p=0.5)
+        return x
+
+class GATConvGNN(nn.Module):
+    """GAT-based GNN model"""
+    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, heads=4):
+        super(GATConvGNN, self).__init__()
+        self.convs = nn.ModuleList()
+
+        # Input layer
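+        # Width bookkeeping sketch (illustrative): with hidden_channels=64 and
+        # heads=4, each multi-head GATConv below concatenates its head outputs
+        # and so emits 64 * 4 = 256 features; hence the hidden layers take
+        # hidden_channels * heads as input, and the output layer uses heads=1
+        # to collapse back to out_channels.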
+        self.convs.append(GATConv(in_channels, hidden_channels, heads=heads))
+
+        # Hidden layers
+        for _ in range(num_layers - 2):
+            self.convs.append(GATConv(hidden_channels * heads, hidden_channels, heads=heads))
+
+        # Output layer
+        self.convs.append(GATConv(hidden_channels * heads, out_channels, heads=1))
+
+    def forward(self, x, edge_index):
+        for i, conv in enumerate(self.convs):
+            x = conv(x, edge_index)
+            if i < len(self.convs) - 1:
+                x = F.elu(x)
+                x = F.dropout(x, training=self.training, p=0.6)
+        return x
+
+
+if __name__ == "__main__":
+    # The module-level import above only provides `dataset`; import the
+    # concrete dataset class explicitly for the demo run.
+    from datasets import Cora
+
+    # Set device
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"Using device: {device}")
+
+    # Load dataset
+    dataset = Cora(api_type="dgl", path="./data")
+    print(f"Loaded dataset: {dataset}")
+
+    # Initialize defense. Note: this constructor takes no `epochs` argument
+    # (the Joint_Learning example's extended constructor does).
+    defense = GNNFingers(
+        dataset=dataset,
+        device=device,
+        num_fingerprints=32,
+        fingerprint_nodes=64
+    )
+
+    # Run defense
+    results = defense.defend()
+
+    # Print results
+    print("\n=== Defense Results ===")
+    print(f"Target Accuracy: {results.get('target_accuracy', 0):.4f}")
+    print(f"Suspect Accuracy: {results.get('suspect_accuracy', 0):.4f}")
+    print(f"Verification Result: {results.get('verification_result', 'Unknown')}")
\ No newline at end of file
diff --git a/models/defense/GNN_Fingers_Examples/GNN_Fingers_Joint_Learning.py b/models/defense/GNN_Fingers_Examples/GNN_Fingers_Joint_Learning.py
new file mode 100644
index 0000000..1952e3e
--- /dev/null
+++ b/models/defense/GNN_Fingers_Examples/GNN_Fingers_Joint_Learning.py
@@ -0,0 +1,1050 @@
+import importlib
+import numpy as np
+from tqdm import tqdm
+import copy
+import random
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from datasets import dataset
+#from models.nn import GraphSAGE
+#from dgl.dataloading import NeighborSampler, NodeCollator
+from torch_geometric.nn import GCNConv, GATConv
+from torch_geometric.utils import erdos_renyi_graph, to_dense_adj, dense_to_sparse
+#from torch_geometric.loader import NeighborLoader, DataLoader
+from torch_geometric.data import Data as PyGData
+from models.defense.base import BaseDefense
+
+
+class LearnableGraphFingerprint(nn.Module):
+    """
+    A learnable graph fingerprint that converts PyG Data components to learnable parameters
+    """
+    def __init__(self, num_nodes, feature_dim):
+        super(LearnableGraphFingerprint, self).__init__()
+        self.num_nodes = num_nodes
+        self.feature_dim = feature_dim
+
+        # Initialize node features as learnable parameters
+        self.x = nn.Parameter(torch.randn(num_nodes, feature_dim))
+
+        # Initialize the adjacency matrix as learnable parameters (dense representation)
+        self.adj_matrix = nn.Parameter(torch.zeros(num_nodes, num_nodes))
+
+    @classmethod
+    def from_pyg_data(cls, x, edge_index, num_nodes, feature_dim):
+        """Create a learnable fingerprint from PyG Data components"""
+        fingerprint = cls(num_nodes, feature_dim)
+
+        # Initialize node features
+        fingerprint.x.data = x.clone()
+
+        # Initialize the adjacency matrix from edge_index
+        with torch.no_grad():
+            # Convert sparse edge_index to a dense adjacency matrix
+            dense_adj = to_dense_adj(edge_index, max_num_nodes=num_nodes)[0]
+            fingerprint.adj_matrix.data = dense_adj
+        return fingerprint
+
+    def forward(self, return_pyg_data=True):
+        """Return the graph structure using a straight-through estimator"""
+        # Discrete adjacency (0.0 or 1.0) with a straight-through gradient:
+        # the forward value is the binarized matrix, while the backward pass
+        # treats the binarization as identity so adj_matrix receives gradients.
+        adj_binary = (self.adj_matrix > 0.5).float()
+        adj_binary_st = adj_binary + (self.adj_matrix -
self.adj_matrix.detach()) + + # Convert dense adjacency to sparse edge_index + edge_index, edge_attr = dense_to_sparse(adj_binary_st) + + if return_pyg_data: + # Return as PyG Data object + return PyGData(x=self.x, edge_index=edge_index, edge_attr=edge_attr) + else: + # Return raw components + return self.x, edge_index, adj_binary + + def get_discrete_adjacency(self): + """Get the actual discrete adjacency matrix (for verification)""" + with torch.no_grad(): + return (self.adj_matrix > 0.5).float() + + def to_pyg_data(self): + """Convert to PyG Data object (without gradient tracking)""" + with torch.no_grad(): + adj_binary = (self.adj_matrix > 0.5).float() + edge_index, edge_attr = dense_to_sparse(adj_binary) + return PyGData(x=self.x.detach(), edge_index=edge_index, edge_attr=edge_attr) + + def get_original_components(self): + """Get the original PyG Data components (for debugging)""" + with torch.no_grad(): + adj_binary = (self.adj_matrix > 0.5).float() + edge_index, edge_attr = dense_to_sparse(adj_binary) + return self.x.detach(), edge_index + + +class Univerifier(nn.Module): + """ + Unified Verification Mechanism - Binary classifier that takes concatenated outputs + from suspect models and predicts whether they are pirated or irrelevant. + """ + def __init__(self, input_dim, hidden_dims=[128, 64,32,16,8,4]): + super(Univerifier, self).__init__() + layers = [] + prev_dim = input_dim + + for hidden_dim in hidden_dims: + layers.append(nn.Linear(prev_dim, hidden_dim)) + layers.append(nn.LeakyReLU()) + layers.append(nn.Dropout(0.1)) + prev_dim = hidden_dim + + layers.append(nn.Linear(prev_dim, 2)) + self.classifier = nn.Sequential(*layers) + + def forward(self, x): + return F.softmax(self.classifier(x), dim=1) + +class GNNFingers(BaseDefense): + """ + GNNFingers: A Fingerprinting Framework for Verifying Ownerships of Graph Neural Networks + Implementation based on the paper by You et al. 
(2024) + """ + supported_api_types = {"dgl"} + + def __init__(self, dataset, attack_node_fraction=0.2, device=None, attack_name=None, + num_fingerprints=64, fingerprint_nodes=32, lambda_threshold=0.7, fingerprint_update_epochs=5, + univerifier_update_epochs=3, fingerprint_lr=0.01, + univerifier_lr=0.001, top_k_ratio=0.1, epochs=100, + batch_size=32, num_neighbors=[5, 5]): + """ + Initialize GNNFingers defense framework + + Parameters + ---------- + dataset : Dataset + The original dataset containing the graph to defend + attack_node_fraction : float + Fraction of nodes to consider for attack + device : torch.device + Device to run computations on + attack_name : str + Name of the attack class to use + num_fingerprints : int + Number of graph fingerprints to generate + fingerprint_nodes : int + Number of nodes in each fingerprint graph + lambda_threshold : float + Threshold for Univerifier classification + batch_size : int + Batch size for training + num_neighbors : list + Number of neighbors to sample at each layer + """ + super().__init__(dataset, attack_node_fraction, device) + self.attack_name = attack_name or "ModelExtractionAttack0" + self.dataset = dataset + self.graph = dataset.graph_data + + # Extract dataset properties + self.node_number = dataset.num_nodes + self.feature_number = dataset.num_features + self.label_number = dataset.num_classes + self.attack_node_number = int(self.node_number * attack_node_fraction) + + # Training parameters + # self.batch_size = batch_size + # self.num_neighbors = num_neighbors + + # Convert DGL to PyG data + self.pyg_data = self._dgl_to_pyg(self.graph) + + # Extract features and labels + self.features = self.pyg_data.x + self.labels = self.pyg_data.y + + # Extract masks + self.train_mask = self.pyg_data.train_mask + self.test_mask = self.pyg_data.test_mask + + # GNNFingers parameters + self.num_fingerprints = num_fingerprints + self.fingerprint_nodes = fingerprint_nodes + self.lambda_threshold = lambda_threshold + + # Initialize components + self.target_gnn = None + self.positive_gnns = [] # Pirated GNNs + self.negative_gnns = [] # Irrelevant GNNs + self.graph_fingerprints = None + self.univerifier = None + + self.fingerprint_lr= fingerprint_lr + self.fingerprint_update_epochs = fingerprint_update_epochs + self.univerifier_update_epochs = univerifier_update_epochs + self.univerifier_lr = univerifier_lr + self.top_k_ratio = top_k_ratio + self.epochs= epochs + + # Move tensors to device + if self.device != 'cpu': + self.graph = self.graph.to(self.device) + self.features = self.features.to(self.device) + self.labels = self.labels.to(self.device) + self.train_mask = self.train_mask.to(self.device) + self.test_mask = self.test_mask.to(self.device) + + def _dgl_to_pyg(self, dgl_graph): + + """Convert DGL graph to PyTorch Geometric Data object""" + # Extract edge indices + edge_index = torch.stack(dgl_graph.edges()) + x = dgl_graph.ndata.get('feat') + y = dgl_graph.ndata.get('label') + + train_mask = dgl_graph.ndata.get('train_mask') + val_mask = dgl_graph.ndata.get('val_mask') + test_mask = dgl_graph.ndata.get('test_mask') + + data = PyGData(x=x, edge_index=edge_index, y=y, + train_mask=train_mask, val_mask=val_mask, test_mask=test_mask) + + return data + + + + # def _create_dataloaders(self, graph_data): + # """Create train and test dataloaders with neighbor sampling""" + # # For DGL graphs + # if hasattr(graph_data, 'ndata'): + # # DGL graph + # sampler = NeighborSampler(self.num_neighbors) + # train_nids = 
graph_data.ndata['train_mask'].nonzero(as_tuple=True)[0].to(self.device) + # test_nids = graph_data.ndata['test_mask'].nonzero(as_tuple=True)[0].to(self.device) + + # train_collator = NodeCollator(graph_data, train_nids, sampler) + # test_collator = NodeCollator(graph_data, test_nids, sampler) + + # train_dataloader = DataLoader( + # train_collator.dataset, + # batch_size=self.batch_size, + # shuffle=True, + # collate_fn=train_collator.collate, + # drop_last=False + # ) + + # test_dataloader = DataLoader( + # test_collator.dataset, + # batch_size=self.batch_size, + # shuffle=False, + # collate_fn=test_collator.collate, + # drop_last=False + # ) + + # return train_dataloader, test_dataloader + + # else: + # # PyG data + # train_loader = NeighborLoader( + # graph_data, + # num_neighbors=self.num_neighbors, + # batch_size=self.batch_size, + # shuffle=True, + # input_nodes=graph_data.train_mask, + # ) + + # test_loader = NeighborLoader( + # graph_data, + # num_neighbors=self.num_neighbors, + # batch_size=self.batch_size, + # shuffle=False, + # input_nodes=graph_data.test_mask, + # ) + + # return train_loader, test_loader + + def _get_attack_class(self, attack_name): + """Dynamically import and return the specified attack class""" + try: + attack_module = importlib.import_module('models.attack') + attack_class = getattr(attack_module, attack_name) + return attack_class + except (ImportError, AttributeError) as e: + print(f"Error loading attack class '{attack_name}': {e}") + print("Falling back to ModelExtractionAttack0") + attack_module = importlib.import_module('models.attack') + return getattr(attack_module, "ModelExtractionAttack0") + + def defend(self, attack_name=None): + """ + Main defense workflow for GNNFingers + """ + attack_name = attack_name or self.attack_name + AttackClass = self._get_attack_class(attack_name) + print(f"Using attack method: {attack_name}") + + # Step 1: Train target model + print("Training target GNN...") + self.target_gnn = self._train_gnn_model(self.pyg_data, "Target GNN") + + # Step 2: Prepare positive and negative GNNs + print("Preparing positive (pirated) GNNs...") + self.positive_gnns = self._prepare_positive_gnns(self.target_gnn, num_models=50) + + print("Preparing negative (irrelevant) GNNs...") + self.negative_gnns = self._prepare_negative_gnns(num_models=50) + + # Step 3: Initialize graph fingerprints + print("Initializing graph fingerprints...") + self.graph_fingerprints = self._initialize_graph_fingerprints() + + # Step 4: Initialize Univerifier + output_dim = self._get_output_dimension(self.target_gnn, self.graph_fingerprints[0]) + self.univerifier = Univerifier(input_dim=output_dim * self.num_fingerprints) + self.univerifier = self.univerifier.to(self.device) + + # Step 5: Joint learning of fingerprints and Univerifier + print("Joint learning of fingerprints and Univerifier...") + self._joint_learning_alternating() + + # Step 6: Attack target model + print("Attacking target model...") + attack = AttackClass(self.dataset, attack_node_fraction=0.2) + attack_results = attack.attack() + suspect_model = attack.net2 if hasattr(attack, 'net2') else None + + # Step 7: Verify ownership + if suspect_model is not None: + print("Verifying ownership of suspect model...") + verification_result = self._verify_ownership(suspect_model) + print(f"Ownership verification result: {verification_result}") + + return { + "attack_results": attack_results, + "verification_result": verification_result, + "target_accuracy": self._evaluate_model(self.target_gnn, self.pyg_data), + 
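+                # verification_result is the (is_pirated, confidence) pair from
+                # _verify_ownership: confidence is the Univerifier's pirated-class
+                # probability, compared against lambda_threshold.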
"suspect_accuracy": self._evaluate_model(suspect_model, self.pyg_data) + } + + return {"attack_results": attack_results, "verification_result": "No suspect model found"} + + def _train_gnn_model(self, data, model_name="GNN", epochs=100): + """Train a GNN model on the given data""" + model = GCNConvGNN( + in_channels=data.x.size(1), + hidden_channels=128, + out_channels=self.label_number, + num_layers=3 + ).to(self.device) + + optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) + criterion = nn.CrossEntropyLoss() + + best_acc = 0 + best_model = None + + for epoch in range(epochs): + model.train() + optimizer.zero_grad() + + # Forward pass + out = model(data.x, data.edge_index) + loss = criterion(out[data.train_mask], data.y[data.train_mask]) + + # Backward pass + loss.backward() + optimizer.step() + + # Evaluate ONLY on test nodes + if epoch % 10 == 0: + model.eval() + with torch.no_grad(): + out = model(data.x, data.edge_index) + pred = out[data.test_mask].argmax(dim=1) + correct = (pred == data.y[data.test_mask]).sum().item() + total = data.test_mask.sum().item() + acc = correct / total if total > 0 else 0 + + if acc > best_acc: + best_acc = acc + best_model = copy.deepcopy(model) + + print(f"{model_name} trained with accuracy: {best_acc:.4f}") + return best_model + + def _prepare_positive_gnns(self, target_model, num_models=50): + """Prepare pirated GNNs using obfuscation techniques""" + positive_models = [] + + for i in range(num_models): + # Apply different obfuscation techniques + if i % 3 == 0: + # Fine-tuning - fine-tune different numbers of layers + layers_to_finetune = random.randint(1, 3) + model = self._fine_tune_model(copy.deepcopy(target_model), self.pyg_data, + epochs=10, num_layers_to_finetune=layers_to_finetune) + elif i % 3 == 1: + # Partial retraining - retrain different numbers of layers + layers_to_retrain = random.randint(1, 3) + model = self._partial_retrain_model(copy.deepcopy(target_model), self.pyg_data, + epochs=15, num_layers_to_retrain=layers_to_retrain) + else: + # Distillation - use different temperatures and architectures + temperature = random.choice([1.5, 2.0, 3.0, 4.0]) + model = self._distill_model(target_model, self.pyg_data, + epochs=30, temperature=temperature) + + positive_models.append(model) + + return positive_models + + + def _prepare_negative_gnns(self, num_models=50): + """Prepare irrelevant GNNs""" + negative_models = [] + + for i in range(num_models): + # Train from scratch with different architectures or data + if i % 2 == 0: + # Different architecture + model = self._train_different_architecture(self.pyg_data) + else: + # Different training data (subset) + model = self._train_on_subset(self.pyg_data) + + negative_models.append(model) + + return negative_models + + def _fine_tune_model(self, model, data, epochs=10, num_layers_to_finetune=1): + """Fine-tune a model on the same data, but only update the last K layers""" + # Freeze all layers initially + for param in model.parameters(): + param.requires_grad = False + + # Unfreeze the last K layers for fine-tuning + if hasattr(model, 'convs'): + # For models with convs attribute (like GCNConvGNN, GATConvGNN) + total_layers = len(model.convs) + layers_to_finetune = min(num_layers_to_finetune, total_layers) + + for i in range(total_layers - layers_to_finetune, total_layers): + for param in model.convs[i].parameters(): + param.requires_grad = True + else: + # For other model types, try to find the last layers + all_params = list(model.parameters()) + layers_to_finetune = 
min(num_layers_to_finetune, len(all_params)) + + for param in all_params[-layers_to_finetune:]: + param.requires_grad = True + + # Only optimize parameters that require gradients + trainable_params = filter(lambda p: p.requires_grad, model.parameters()) + optimizer = optim.Adam(trainable_params, lr=0.001) + criterion = nn.CrossEntropyLoss() + + # Count how many parameters are being fine-tuned + num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + num_total = sum(p.numel() for p in model.parameters()) + print(f"Fine-tuning {num_trainable}/{num_total} parameters ({num_trainable/num_total*100:.1f}%)") + + for epoch in range(epochs): + model.train() + optimizer.zero_grad() + + out = model(data.x, data.edge_index) + loss = criterion(out[data.train_mask], data.y[data.train_mask]) + + loss.backward() + optimizer.step() + + # Unfreeze all parameters for future use + for param in model.parameters(): + param.requires_grad = True + + return model + + def _partial_retrain_model(self, model, data, epochs=10, num_layers_to_retrain=2): + """Partially retrain a model with random initialization of K layers before resuming training""" + # Randomly initialize selected K layers + if hasattr(model, 'convs'): + # For models with convs attribute (like GCNConvGNN, GATConvGNN) + total_layers = len(model.convs) + layers_to_retrain = min(num_layers_to_retrain, total_layers) + + # Randomly select K layers to retrain (not necessarily the last ones) + layer_indices = random.sample(range(total_layers), layers_to_retrain) + + print(f"Partially retraining layers: {layer_indices}") + + for idx in layer_indices: + model.convs[idx].reset_parameters() # Random reinitialization + else: + # For other model types, try to find and reset random layers + # This is a fallback approach + all_layers = list(model.children()) + total_layers = len(all_layers) + layers_to_retrain = min(num_layers_to_retrain, total_layers) + + layer_indices = random.sample(range(total_layers), layers_to_retrain) + + for idx in layer_indices: + if hasattr(all_layers[idx], 'reset_parameters'): + all_layers[idx].reset_parameters() + + # Train the entire model (both retrained and original layers) + optimizer = optim.Adam(model.parameters(), lr=0.01) + criterion = nn.CrossEntropyLoss() + + best_acc = 0 + best_model = None + + for epoch in range(epochs): + model.train() + optimizer.zero_grad() + + out = model(data.x, data.edge_index) + loss = criterion(out[data.train_mask], data.y[data.train_mask]) + + loss.backward() + optimizer.step() + + # Track best model + if epoch % 5 == 0: + acc = self._evaluate_model(model, data) + if acc > best_acc: + best_acc = acc + best_model = copy.deepcopy(model) + + print(f"Partial retraining completed. 
Best accuracy: {best_acc:.4f}") + return best_model if best_model is not None else model + def _distill_model(self, teacher_model, data, epochs=30, temperature=2.0): + """Distill knowledge from teacher to student model with different architecture""" + # Create student model with different architecture + if isinstance(teacher_model, GCNConvGNN): + # If teacher is GCN, use GAT as student + student_model = GATConvGNN( + in_channels=data.x.size(1), + hidden_channels=96, # Different hidden size + out_channels=self.label_number, + num_layers=2, # Different number of layers + heads=3 # Different number of heads + ).to(self.device) + else: + # If teacher is GAT or other, use GCN as student + student_model = GCNConvGNN( + in_channels=data.x.size(1), + hidden_channels=64, # Different hidden size + out_channels=self.label_number, + num_layers=3 # Different number of layers + ).to(self.device) + + optimizer = optim.Adam(student_model.parameters(), lr=0.01, weight_decay=1e-4) + + # Combined loss: KL divergence for distillation + cross entropy for ground truth + kl_loss = nn.KLDivLoss(reduction='batchmean') + ce_loss = nn.CrossEntropyLoss() + + teacher_model.eval() + + best_acc = 0 + best_student = None + + for epoch in range(epochs): + student_model.train() + optimizer.zero_grad() + + # Get teacher predictions (with temperature scaling) + with torch.no_grad(): + teacher_logits = teacher_model(data.x, data.edge_index) + teacher_probs = F.softmax(teacher_logits / temperature, dim=1) + + # Get student predictions + student_logits = student_model(data.x, data.edge_index) + student_log_probs = F.log_softmax(student_logits / temperature, dim=1) + + # Distillation loss (KL divergence between teacher and student) + distill_loss = kl_loss(student_log_probs[data.train_mask], + teacher_probs[data.train_mask]) * (temperature ** 2) + + # Student's own classification loss + class_loss = ce_loss(student_logits[data.train_mask], + data.y[data.train_mask]) + + # Combined loss (weighted sum) + loss = 0.7 * distill_loss + 0.3 * class_loss + + loss.backward() + optimizer.step() + + # Track best student model + if epoch % 5 == 0: + student_model.eval() + with torch.no_grad(): + out = student_model(data.x, data.edge_index) + pred = out[data.test_mask].argmax(dim=1) + correct = (pred == data.y[data.test_mask]).sum().item() + total = data.test_mask.sum().item() + acc = correct / total if total > 0 else 0 + + if acc > best_acc: + best_acc = acc + best_student = copy.deepcopy(student_model) + + print(f"Distillation Epoch {epoch}: Loss={loss.item():.4f}, " + f"Distill={distill_loss.item():.4f}, Class={class_loss.item():.4f}, " + f"Acc={acc:.4f}") + + print(f"Distillation completed. 
Best student accuracy: {best_acc:.4f}")
+        return best_student if best_student is not None else student_model
+
+    def _train_different_architecture(self, data):
+        """Train a model with a different architecture for negative GNNs"""
+        # Use the opposite architecture of the target model (compare against the
+        # model class GCNConvGNN, not the GCNConv layer class)
+        if isinstance(self.target_gnn, GCNConvGNN):
+            model = GATConvGNN(
+                in_channels=data.x.size(1),
+                hidden_channels=64,
+                out_channels=self.label_number,
+                num_layers=2,
+                heads=4
+            ).to(self.device)
+        else:
+            model = GCNConvGNN(
+                in_channels=data.x.size(1),
+                hidden_channels=64,
+                out_channels=self.label_number,
+                num_layers=3
+            ).to(self.device)
+
+        return self._train_gnn_model_with_data(model, data, epochs=50)
+
+    def _train_on_subset(self, data, subset_ratio=0.7):
+        """Train on a subset of the training data"""
+        # Create subset mask
+        num_train = int(data.train_mask.sum().item() * subset_ratio)
+        subset_mask = torch.zeros_like(data.train_mask)
+        train_indices = data.train_mask.nonzero(as_tuple=True)[0]
+        selected_indices = random.sample(range(len(train_indices)), min(num_train, len(train_indices)))
+        subset_mask[train_indices[selected_indices]] = True
+
+        # Create subset data
+        subset_data = PyGData(
+            x=data.x,
+            edge_index=data.edge_index,
+            y=data.y,
+            train_mask=subset_mask,
+            test_mask=data.test_mask
+        )
+        # Use the opposite architecture of the target model. Note that
+        # GCNConvGNN takes no `heads` argument, so the branches must not mix
+        # the two constructors.
+        if isinstance(self.target_gnn, GCNConvGNN):
+            model = GATConvGNN(
+                in_channels=data.x.size(1),
+                hidden_channels=64,
+                out_channels=self.label_number,
+                num_layers=2,
+                heads=4
+            ).to(self.device)
+        else:
+            model = GCNConvGNN(
+                in_channels=data.x.size(1),
+                hidden_channels=64,
+                out_channels=self.label_number,
+                num_layers=3
+            ).to(self.device)
+
+        return self._train_gnn_model_with_data(model, subset_data, epochs=50)
+
+    def _train_gnn_model_with_data(self, model, data, epochs=100):
+        """Train a specific model on specific data"""
+        optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
+        criterion = nn.CrossEntropyLoss()
+
+        best_acc = 0
+        best_model = None
+
+        for epoch in range(epochs):
+            model.train()
+            optimizer.zero_grad()
+
+            out = model(data.x, data.edge_index)
+            loss = criterion(out[data.train_mask], data.y[data.train_mask])
+
+            loss.backward()
+            optimizer.step()
+
+            if epoch % 10 == 0:
+                acc = self._evaluate_model(model, data)
+                if acc > best_acc:
+                    best_acc = acc
+                    best_model = copy.deepcopy(model)
+
+        # Fall back to the final model if no checkpoint ever improved on 0 accuracy
+        return best_model if best_model is not None else model
+
+    def _initialize_graph_fingerprints(self):
+        """Initialize graph fingerprints as learnable parameters from PyG Data"""
+        fingerprints = nn.ModuleList()  # Use ModuleList to properly register parameters
+        feature_dim = self.features.size(1) if self.features is not None else 16
+
+        for _ in range(self.num_fingerprints):
+            # Initialize a random graph using the Erdos-Renyi model
+            edge_index = erdos_renyi_graph(self.fingerprint_nodes, 0.3)
+
+            # Initialize node features
+            x = torch.randn(self.fingerprint_nodes, feature_dim)
+
+            # Convert to a learnable fingerprint
+            fingerprint = LearnableGraphFingerprint.from_pyg_data(
+                x, edge_index, self.fingerprint_nodes, feature_dim
+            ).to(self.device)
+            fingerprints.append(fingerprint)
+
+        return fingerprints
+
+    def _get_output_dimension(self, model, fingerprint_data):
+        """Return the flattened per-fingerprint output size expected by the Univerifier"""
+        # Accept either a learnable fingerprint module or a plain PyG Data object
+        if hasattr(fingerprint_data, "to_pyg_data"):
+            fingerprint_data = fingerprint_data.to_pyg_data()
+        model.eval()
+        with torch.no_grad():
+            output = model(fingerprint_data.x.to(self.device),
+                           fingerprint_data.edge_index.to(self.device))
+        # Match the .view(1, -1) concatenation used in joint learning and
+        # verification: each fingerprint contributes num_nodes * num_classes values
+        return output.view(1, -1).size(1)
+    def _verify_ownership(self, suspect_model):
+        """Verify if a suspect model is pirated from the target model"""
+        self.target_gnn.eval()
+        suspect_model.eval()
+
+        # Get outputs for all fingerprints
+        target_outputs = []
+        suspect_outputs = []
+
+        for fingerprint in self.graph_fingerprints:
+            # Get the fingerprint as a PyG Data object (without gradient tracking)
+            fp_data = fingerprint.to_pyg_data()
+
+            with torch.no_grad():
+                target_out = self.target_gnn(fp_data.x.to(self.device),
+                                             fp_data.edge_index.to(self.device))
+                suspect_out = suspect_model(fp_data.x.to(self.device),
+                                            fp_data.edge_index.to(self.device))
+
+            target_outputs.append(target_out)
+            suspect_outputs.append(suspect_out)
+
+        # Concatenate outputs into one row per model
+        target_concat = torch.cat(target_outputs, dim=0).view(1, -1)
+        suspect_concat = torch.cat(suspect_outputs, dim=0).view(1, -1)
+
+        # Get Univerifier prediction
+        self.univerifier.eval()
+        with torch.no_grad():
+            prediction = self.univerifier(suspect_concat)
+            confidence = prediction[0, 1].item()  # Probability of being pirated
+
+        return confidence > self.lambda_threshold, confidence
+
+    def _evaluate_model(self, model, data):
+        """Evaluate model accuracy on the test mask"""
+        model.eval()
+        with torch.no_grad():
+            out = model(data.x.to(self.device), data.edge_index.to(self.device))
+            pred = out.argmax(dim=1)
+            # Keep labels and mask on the same device as the predictions
+            y = data.y.to(self.device)
+            mask = data.test_mask.to(self.device)
+            correct = (pred[mask] == y[mask]).sum().item()
+            total = mask.sum().item()
+            return correct / total if total > 0 else 0
+
+    def _update_adjacency_discrete(self, fingerprint, grad_adj, top_k_ratio=None):
+        """
+        Update the discrete adjacency matrix by flipping the top-K entries with
+        the largest gradient magnitude, moving against the gradient sign (the
+        loss passed in is a quantity to be minimized).
+        """
+        if grad_adj is None:
+            # The adjacency may be absent from the autograd graph because the
+            # forward pass only consumes discrete edge indices; skip the update.
+            return
+        if top_k_ratio is None:
+            top_k_ratio = self.top_k_ratio
+
+        # Get current discrete adjacency
+        current_adj = fingerprint.get_discrete_adjacency()
+
+        # Get absolute gradient values and flatten
+        grad_abs_flat = torch.abs(grad_adj).view(-1)
+
+        # Determine the top-K entries to consider for flipping
+        k = int(top_k_ratio * self.fingerprint_nodes * self.fingerprint_nodes)
+        _, topk_indices = torch.topk(grad_abs_flat, k)
+
+        # Convert flat indices to row, col indices
+        rows = topk_indices // self.fingerprint_nodes
+        cols = topk_indices % self.fingerprint_nodes
+
+        # Flip edges against the gradient sign (descent direction)
+        with torch.no_grad():
+            for idx in range(k):
+                row, col = rows[idx], cols[idx]
+                grad_val = grad_adj[row, col]
+
+                # Current edge existence (0 or 1)
+                current_edge = current_adj[row, col]
+
+                if current_edge > 0.5 and grad_val > 0:
+                    # Edge exists and increasing it would raise the loss -> remove it
+                    fingerprint.adj_matrix.data[row, col] = 0.0
+                elif current_edge < 0.5 and grad_val < 0:
+                    # Edge is absent and adding it would lower the loss -> add it
+                    fingerprint.adj_matrix.data[row, col] = 1.0
+
+    def _update_fingerprints_discrete(self, loss, top_k_ratio=0.1):
+        """
+        Update graph fingerprints using gradients of the joint loss
+        """
+        # Compute gradients for all fingerprints
+        gradients_adj = []
+        gradients_x = []
+
+        for fingerprint in self.graph_fingerprints:
+            # Gradients for the adjacency matrix. allow_unused=True because the
+            # adjacency enters the forward pass only through discrete edge
+            # indices and may not appear in the autograd graph.
+            grad_adj = torch.autograd.grad(
+                loss, fingerprint.adj_matrix,
+                retain_graph=True, create_graph=False, allow_unused=True
+            )[0]
+
+            # Gradients for the node features
+            grad_x = torch.autograd.grad(
+                loss, fingerprint.x,
+                retain_graph=True, create_graph=False
+            )[0]
+
+            gradients_adj.append(grad_adj)
+            gradients_x.append(grad_x)
+
+        # Update each fingerprint
+        for i, fingerprint in enumerate(self.graph_fingerprints):
+            grad_adj = gradients_adj[i]
+            grad_x = gradients_x[i]
+
+            # Update node features with clipping. The joint loss is minimized,
+            # so features move against the gradient (descent).
+            with torch.no_grad():
+                fingerprint.x.data -= self.fingerprint_lr * grad_x
+
+                # Clip node features to a reasonable range
+                if self.features is not None:
+                    min_val = self.features.min().item()
+                    max_val = self.features.max().item()
+                    fingerprint.x.data = torch.clamp(fingerprint.x.data, min_val, max_val)
+                else:
+                    fingerprint.x.data = torch.clamp(fingerprint.x.data, -3, 3)
+
+            # Update the adjacency matrix using the discrete strategy
+            self._update_adjacency_discrete(fingerprint, grad_adj, top_k_ratio)
+
+
+    def visualize_fingerprint_evolution(self, epoch):
+        """Visualize how fingerprints evolve during training"""
+        if epoch % 20 == 0:  # Visualize every 20 epochs
+            print(f"\n=== Fingerprint Evolution at Epoch {epoch} ===")
+
+            for i, fingerprint in enumerate(self.graph_fingerprints[:2]):  # First 2 only
+                x, edge_index = fingerprint.get_original_components()
+                current_adj = fingerprint.get_discrete_adjacency()
+
+                # Edge statistics
+                num_edges = current_adj.sum().item()
+                sparsity = 1 - (num_edges / (self.fingerprint_nodes * self.fingerprint_nodes))
+                print(f"Fingerprint {i}: {num_edges} edges, sparsity: {sparsity:.3f}")
+
+                # Feature statistics
+                feature_mean = x.mean().item()
+                feature_std = x.std().item()
+                print(f"  Features: mean={feature_mean:.3f}, std={feature_std:.3f}")
+
+    def _joint_learning_alternating(self):
+        """
+        Joint learning of fingerprints and Univerifier with alternating optimization
+        """
+        # Prepare all models and labels
+        all_models = [self.target_gnn] + self.positive_gnns + self.negative_gnns
+        num_positive = len(self.positive_gnns) + 1  # target + pirated models
+        labels = torch.cat([
+            torch.ones(num_positive),             # Target + positive models
+            torch.zeros(len(self.negative_gnns))  # Negative models
+        ]).long().to(self.device)
+
+        # Create the optimizer once so Adam's moment estimates persist across phases
+        univerifier_optimizer = optim.Adam(self.univerifier.parameters(), lr=self.univerifier_lr)
+
+        # Flag to alternate between fingerprint and Univerifier updates
+        update_fingerprints = True
+
+        for epoch in range(self.epochs):
+            # Forward pass through all models using the actual discrete structure.
+            # Gradients are kept so the loss stays connected to the fingerprint
+            # parameters (via the straight-through estimator) and the Univerifier.
+            all_outputs = []
+            for model in all_models:
+                model.eval()
+                model_outputs = []
+                for fingerprint in self.graph_fingerprints:
+                    # Get the fingerprint as PyG Data (uses the straight-through estimator)
+                    fingerprint_data = fingerprint(return_pyg_data=True)
+                    output = model(fingerprint_data.x, fingerprint_data.edge_index)
+                    model_outputs.append(output)
+
+                # Concatenate all fingerprint outputs into one row
+                concatenated = torch.cat(model_outputs, dim=0).view(1, -1)
+                all_outputs.append(concatenated)
+
+            # Stack all outputs
+            all_outputs = torch.cat(all_outputs, dim=0)
+
+            # Univerifier prediction
+            univerifier_out = self.univerifier(all_outputs)
+
+            # Joint negative log-likelihood: log o_+(F) and log o_+(F_+) terms for
+            # the target and pirated models, log o_-(F_-) for irrelevant models
+            loss = -(
+                torch.log(univerifier_out[:num_positive, 1] + 1e-10).sum()
+                + torch.log(1 - univerifier_out[num_positive:, 1] + 1e-10).sum()
+            )
+
+            # Alternating optimization
+            if update_fingerprints:
+                # Phase 1: update fingerprints for e1 steps (the gradients are
+                # taken on the same retained graph, i.e. one linearization)
+                for e in range(self.fingerprint_update_epochs):
+                    self._update_fingerprints_discrete(loss, self.top_k_ratio)
+
+                update_fingerprints = False
+                print(f"Epoch {epoch}: Updated fingerprints, Loss: {loss.item():.4f}")
+
+            else:
+                # Phase 2: update the Univerifier for e2 steps. The loss is
+                # recomputed each step: after an in-place optimizer step the old
+                # graph is stale and a second backward through it would fail.
+                detached_outputs = all_outputs.detach()
+                for e in range(self.univerifier_update_epochs):
+                    univerifier_optimizer.zero_grad()
+                    step_out = self.univerifier(detached_outputs)
+                    step_loss = -(
+                        torch.log(step_out[:num_positive, 1] + 1e-10).sum()
+                        + torch.log(1 - step_out[num_positive:, 1] + 1e-10).sum()
+                    )
+                    step_loss.backward()
+                    univerifier_optimizer.step()
+
+                update_fingerprints = True
+                print(f"Epoch {epoch}: Updated Univerifier, Loss: {loss.item():.4f}")
+
+            # Report accuracy every 10 epochs
+            if epoch % 10 == 0:
+                with torch.no_grad():
+                    preds = univerifier_out.argmax(dim=1)
+                    acc = (preds == labels).float().mean().item()
+
+                    # True positive / true negative rates, normalized by the
+                    # number of positive and negative models respectively
+                    pos_total = (labels == 1).sum().item()
+                    neg_total = (labels == 0).sum().item()
+                    tp_rate = ((preds == 1) & (labels == 1)).sum().item() / pos_total if pos_total > 0 else 0
+                    tn_rate = ((preds == 0) & (labels == 0)).sum().item() / neg_total if neg_total > 0 else 0
+
+                    print(f"Epoch {epoch}, Acc: {acc:.4f}, TP: {tp_rate:.4f}, TN: {tn_rate:.4f}")
+
+            # Visualize fingerprint evolution
+            if epoch % 20 == 0:
+                self.visualize_fingerprint_evolution(epoch)
+
+
+    def _verify_ownership_detailed(self, suspect_model):
+        """Detailed verification, for debugging purposes only"""
+        suspect_outputs = []
+        target_outputs = []
+
+        suspect_model.eval()
+        self.target_gnn.eval()
+
+        for fingerprint in self.graph_fingerprints:
+            fingerprint_data = fingerprint.to_pyg_data()
+
+            with torch.no_grad():
+                suspect_out = suspect_model(
+                    fingerprint_data.x.to(self.device),
+                    fingerprint_data.edge_index.to(self.device)
+                )
+                suspect_outputs.append(suspect_out)
+
+                target_out = self.target_gnn(
+                    fingerprint_data.x.to(self.device),
+                    fingerprint_data.edge_index.to(self.device)
+                )
+                target_outputs.append(target_out)
+
+        suspect_concat = torch.cat(suspect_outputs, dim=0).view(1, -1)
+        target_concat = torch.cat(target_outputs, dim=0).view(1, -1)
+
+        self.univerifier.eval()
+        with torch.no_grad():
+            suspect_prediction = self.univerifier(suspect_concat)
+            suspect_confidence = suspect_prediction[0, 1].item()
+
+            target_prediction = self.univerifier(target_concat)
+            target_confidence = target_prediction[0, 1].item()
+
+        output_similarity = F.cosine_similarity(suspect_concat, target_concat).item()
+        is_pirated = suspect_confidence > self.lambda_threshold
+
+        # Return detailed info for debugging; _verify_ownership keeps the original interface
+        return {
+            'is_pirated': is_pirated,
+            'confidence': suspect_confidence,
+            'target_confidence': target_confidence,
+            'output_similarity': output_similarity,
+            'lambda_threshold': self.lambda_threshold
+        }
+
+
+# GNN Model Definitions
+class GCNConvGNN(nn.Module):
+    """GCN-based GNN model"""
+    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3):
+        super(GCNConvGNN, self).__init__()
+        self.convs = nn.ModuleList()
+
+        # Input layer
+        self.convs.append(GCNConv(in_channels, hidden_channels))
+
+        # Hidden layers
+        for _ in range(num_layers - 2):
+            self.convs.append(GCNConv(hidden_channels, hidden_channels))
+
+        # Output layer
+        self.convs.append(GCNConv(hidden_channels, out_channels))
+
+    def forward(self, x, edge_index):
+        for i, conv in enumerate(self.convs):
+            x = conv(x, edge_index)
+            if i < len(self.convs) - 1:
+                x = F.relu(x)
+                x = F.dropout(x, training=self.training, p=0.5)
+        return x
+
+class GATConvGNN(nn.Module):
+    """GAT-based GNN model"""
+    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, heads=4):
+        super(GATConvGNN, self).__init__()
+        self.convs = nn.ModuleList()
+
+        # Input layer
+        self.convs.append(GATConv(in_channels, hidden_channels, heads=heads))
+
+        # Hidden layers
+        for _ in range(num_layers - 2):
+            self.convs.append(GATConv(hidden_channels * heads, hidden_channels, heads=heads))
+
+        # Output layer
+        self.convs.append(GATConv(hidden_channels * heads, out_channels,
heads=1)) + + def forward(self, x, edge_index): + for i, conv in enumerate(self.convs): + x = conv(x, edge_index) + if i < len(self.convs) - 1: + x = F.elu(x) + x = F.dropout(x, training=self.training, p=0.6) + return x \ No newline at end of file diff --git a/models/defense/GNN_Fingers_Examples/GNN_Fingers_Multi_Graph_Support.py b/models/defense/GNN_Fingers_Examples/GNN_Fingers_Multi_Graph_Support.py new file mode 100644 index 0000000..3fceaa1 --- /dev/null +++ b/models/defense/GNN_Fingers_Examples/GNN_Fingers_Multi_Graph_Support.py @@ -0,0 +1,1376 @@ +import sys +sys.path.append('.') +import importlib +import numpy as np +from tqdm import tqdm +import copy +import random +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from datasets import Cora, PubMed +from torch_geometric.nn import GCNConv, GATConv, global_mean_pool, global_add_pool +from torch_geometric.utils import erdos_renyi_graph, to_dense_adj, dense_to_sparse +from torch_geometric.loader import NeighborLoader, DataLoader +from torch_geometric.data import Data as PyGData, Batch +from models.defense.base import BaseDefense + + +class BaseGraphFingerprint(nn.Module): + """Base class for all graph fingerprint types""" + def __init__(self, task_type, num_nodes, feature_dim): + super(BaseGraphFingerprint, self).__init__() + self.task_type = task_type + self.num_nodes = num_nodes + self.feature_dim = feature_dim + + def forward(self): + raise NotImplementedError + + def to_pyg_data(self): + raise NotImplementedError + + def get_sampled_outputs(self, model_output): + """Sample outputs based on task type""" + raise NotImplementedError + +class NodeLevelFingerprint(BaseGraphFingerprint): + """Fingerprint for node-level tasks (node classification)""" + def __init__(self, num_nodes, feature_dim): + super(NodeLevelFingerprint, self).__init__('node_level', num_nodes, feature_dim) + + # Initialize node features and adjacency + self.x = nn.Parameter(torch.randn(num_nodes, feature_dim)) + self.adj_matrix = nn.Parameter(torch.zeros(num_nodes, num_nodes)) + + # For node-level tasks, sample outputs from m nodes + self.sample_indices = nn.Parameter( + torch.randint(0, num_nodes, (min(10, num_nodes),)), + requires_grad=False + ) + + def forward(self, return_pyg_data=True): + # Use straight-through estimator for discrete adjacency + adj_binary = (self.adj_matrix > 0.5).float() + adj_binary_st = adj_binary + (self.adj_matrix - self.adj_matrix.detach()) + edge_index, edge_attr = dense_to_sparse(adj_binary_st) + + if return_pyg_data: + return PyGData(x=self.x, edge_index=edge_index, edge_attr=edge_attr) + return self.x, edge_index, adj_binary + + def to_pyg_data(self): + """Convert to PyG Data object without gradient tracking""" + with torch.no_grad(): + adj_binary = (self.adj_matrix > 0.5).float() + edge_index, edge_attr = dense_to_sparse(adj_binary) + return PyGData(x=self.x.detach(), edge_index=edge_index, edge_attr=edge_attr) + + def get_sampled_outputs(self, model_output): + """Sample outputs from specific nodes for verification""" + return model_output[self.sample_indices] + +class EdgeLevelFingerprint(BaseGraphFingerprint): + """Fingerprint for edge-level tasks (link prediction)""" + def __init__(self, num_nodes, feature_dim): + super(EdgeLevelFingerprint, self).__init__('edge_level', num_nodes, feature_dim) + + self.x = nn.Parameter(torch.randn(num_nodes, feature_dim)) + self.adj_matrix = nn.Parameter(torch.zeros(num_nodes, num_nodes)) + + # For edge-level tasks, sample m node pairs + 
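+        # Sampling sketch (illustrative): with num_nodes=32 and 8 pairs, each
+        # pair (u, v) is drawn without replacement from range(32), and
+        # verification later reads model_output[u, v] for every stored pair.
+        # The pairs live in a non-trainable Parameter so they follow .to(device);
+        # a registered buffer would work equally well.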
self.sample_pairs = self._initialize_sample_pairs(num_nodes, 8) + + def _initialize_sample_pairs(self, num_nodes, num_pairs): + """Initialize node pairs to sample for edge outputs""" + pairs = [] + for _ in range(num_pairs): + u, v = random.sample(range(num_nodes), 2) + pairs.append([u, v]) + return nn.Parameter(torch.tensor(pairs), requires_grad=False) + + def forward(self, return_pyg_data=True): + adj_binary = (self.adj_matrix > 0.5).float() + adj_binary_st = adj_binary + (self.adj_matrix - self.adj_matrix.detach()) + edge_index, edge_attr = dense_to_sparse(adj_binary_st) + + if return_pyg_data: + return PyGData(x=self.x, edge_index=edge_index, edge_attr=edge_attr) + return self.x, edge_index, adj_binary + + def to_pyg_data(self): + """Convert to PyG Data object without gradient tracking""" + with torch.no_grad(): + adj_binary = (self.adj_matrix > 0.5).float() + edge_index, edge_attr = dense_to_sparse(adj_binary) + return PyGData(x=self.x.detach(), edge_index=edge_index, edge_attr=edge_attr) + + def get_sampled_outputs(self, model_output): + """Sample edge outputs for verification""" + # For link prediction, model_output is an adjacency probability matrix + sampled_outputs = [] + for u, v in self.sample_pairs: + sampled_outputs.append(model_output[u, v]) + return torch.stack(sampled_outputs) + +class GraphLevelFingerprint(BaseGraphFingerprint): + """Fingerprint for graph-level tasks (graph classification, matching)""" + def __init__(self, num_nodes, feature_dim, num_graphs=64): + super(GraphLevelFingerprint, self).__init__('graph_level', num_nodes, feature_dim) + + # We have Multiple Independent Graphs for Graph Level Task + self.graphs = nn.ModuleList([ + SingleGraphFingerprint(num_nodes, feature_dim) for _ in range(num_graphs) + ]) + + def forward(self, return_pyg_data=True): + return [graph(return_pyg_data) for graph in self.graphs] + + def to_pyg_data(self): + """Convert all graphs to PyG Data objects""" + return [graph.to_pyg_data() for graph in self.graphs] + + def get_sampled_outputs(self, model_outputs): + """Return graph-level outputs directly""" + # For graph-level tasks, outputs are already at the graph level + return torch.cat([output.unsqueeze(0) for output in model_outputs]) + +class SingleGraphFingerprint(nn.Module): + """Single graph component for graph-level fingerprints""" + def __init__(self, num_nodes, feature_dim): + super(SingleGraphFingerprint, self).__init__() + self.x = nn.Parameter(torch.randn(num_nodes, feature_dim)) + self.adj_matrix = nn.Parameter(torch.zeros(num_nodes, num_nodes)) + + def forward(self, return_pyg_data=True): + adj_binary = (self.adj_matrix > 0.5).float() + adj_binary_st = adj_binary + (self.adj_matrix - self.adj_matrix.detach()) + edge_index, edge_attr = dense_to_sparse(adj_binary_st) + + if return_pyg_data: + return PyGData(x=self.x, edge_index=edge_index, edge_attr=edge_attr) + return self.x, edge_index, adj_binary + + def to_pyg_data(self): + """Convert to PyG Data object without gradient tracking""" + with torch.no_grad(): + adj_binary = (self.adj_matrix > 0.5).float() + edge_index, edge_attr = dense_to_sparse(adj_binary) + return PyGData(x=self.x.detach(), edge_index=edge_index, edge_attr=edge_attr) + + +class Univerifier(nn.Module): + """ + Unified Verification Mechanism - Binary classifier that takes concatenated outputs + from suspect models and predicts whether they are pirated or irrelevant. 
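+
+    Sizing example (illustrative): if each of 64 node-level fingerprints
+    contributes the outputs of 10 sampled nodes from a 7-class model, the
+    concatenated input row holds 64 * 10 * 7 = 4480 values; input_dim must
+    match however the caller flattens and concatenates the suspect model's
+    outputs.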
+ """ + def __init__(self, input_dim, hidden_dims=[128, 64, 32, 16, 8, 4]): + super(Univerifier, self).__init__() + layers = [] + prev_dim = input_dim + + for hidden_dim in hidden_dims: + layers.append(nn.Linear(prev_dim, hidden_dim)) + layers.append(nn.LeakyReLU()) + layers.append(nn.Dropout(0.1)) + prev_dim = hidden_dim + + layers.append(nn.Linear(prev_dim, 2)) + self.classifier = nn.Sequential(*layers) + + def forward(self, x): + return F.softmax(self.classifier(x), dim=1) + + +class GNNFingers(BaseDefense): + """ + GNNFingers: A Fingerprinting Framework for Verifying Ownerships of Graph Neural Networks + with multi-task support for node-level, edge-level, and graph-level tasks + """ + supported_api_types = {"dgl"} + + def __init__(self, dataset, attack_node_fraction=0.2, device=None, attack_name=None, + num_fingerprints=64, fingerprint_nodes=32, lambda_threshold=0.7, + fingerprint_update_epochs=5, univerifier_update_epochs=3, + fingerprint_lr=0.01, univerifier_lr=0.001, top_k_ratio=0.1, + epochs=100, batch_size=32, num_neighbors=[10, 5], + task_type='node_level'): + """ + Initialize GNNFingers defense framework with multi-task support + + Parameters + ---------- + dataset : Dataset + The original dataset containing the graph to defend + attack_node_fraction : float + Fraction of nodes to consider for attack + device : torch.device + Device to run computations on + attack_name : str + Name of the attack class to use + num_fingerprints : int + Number of graph fingerprints to generate + fingerprint_nodes : int + Number of nodes in each fingerprint graph + lambda_threshold : float + Threshold for Univerifier classification + fingerprint_update_epochs: int + Number of Epochs to update Fingerprint + univerifier_update_epochs: int + Number of Epochs to update Univerifier + fingerprint_lr: float + Learning rate for fingerprint update + univerifier_lr: float + Learning rate for Univerifier update + top_k_ratio: float + top k gradients of fingerprint adjacency matrix to select + epochs: int + total number of epochs to run experiment + batch_size : int + Batch size for training + num_neighbors : list + Number of neighbors to sample at each layer + task_type : str + Type of GNN task: 'node_level', 'edge_level', or 'graph_level' + """ + super().__init__(dataset, attack_node_fraction, device) + self.attack_name = attack_name or "ModelExtractionAttack0" + self.dataset = dataset + self.graph = dataset.graph_data + + # Extract dataset properties + self.node_number = dataset.num_nodes + self.feature_number = dataset.num_features + self.label_number = dataset.num_classes + self.attack_node_number = int(self.node_number * attack_node_fraction) + + # Training parameters + self.batch_size = batch_size + self.num_neighbors = num_neighbors + + #Task type + self.task_type = task_type + + # Convert DGL to PyG data + self.pyg_data = self._dgl_to_pyg(self.graph) + + # Extract features and labels + self.features = self.pyg_data.x + self.labels = self.pyg_data.y + + # Extract masks + self.train_mask = self.pyg_data.train_mask + self.test_mask = self.pyg_data.test_mask + + # GNNFingers parameters + self.num_fingerprints = num_fingerprints + self.fingerprint_nodes = fingerprint_nodes + self.lambda_threshold = lambda_threshold + + # Initialize components + self.target_gnn = None + self.positive_gnns = [] # Pirated GNNs + self.negative_gnns = [] # Irrelevant GNNs + self.graph_fingerprints = None + self.univerifier = None + + self.fingerprint_lr = fingerprint_lr + self.fingerprint_update_epochs = 
fingerprint_update_epochs + self.univerifier_update_epochs = univerifier_update_epochs + self.univerifier_lr = univerifier_lr + self.top_k_ratio = top_k_ratio + self.epochs = epochs + + # Move tensors to device + if self.device != 'cpu': + self.graph = self.graph.to(self.device) + self.features = self.features.to(self.device) + self.labels = self.labels.to(self.device) + self.train_mask = self.train_mask.to(self.device) + self.test_mask = self.test_mask.to(self.device) + + def _dgl_to_pyg(self, dgl_graph): + """Convert DGL graph to PyTorch Geometric Data object""" + # Extract edge indices + edge_index = torch.stack(dgl_graph.edges()) + x = dgl_graph.ndata.get('feat') + y = dgl_graph.ndata.get('label') + + train_mask = dgl_graph.ndata.get('train_mask') + val_mask = dgl_graph.ndata.get('val_mask') + test_mask = dgl_graph.ndata.get('test_mask') + + data = PyGData(x=x, edge_index=edge_index, y=y, + train_mask=train_mask, val_mask=val_mask, test_mask=test_mask) + + return data + + def _create_dataloaders(self, graph_data, batch_size=None, num_neighbors=None): + """ + Create train and test dataloaders for PyG data + + Parameters + ---------- + graph_data : PyG Data + The graph data to create loaders for + batch_size : int, optional + Batch size (defaults to self.batch_size) + num_neighbors : list, optional + Number of neighbors to sample (defaults to self.num_neighbors) + + Returns + ------- + train_loader : NeighborLoader + Training dataloader + test_loader : NeighborLoader + Test dataloader + """ + batch_size = batch_size or self.batch_size + num_neighbors = num_neighbors or self.num_neighbors + + #Different dataloader for graph-level tasks + if self.task_type == 'graph_level': + # For graph-level tasks, use standard DataLoader + train_loader = DataLoader([graph_data], batch_size=self.batch_size, shuffle=True) + test_loader = DataLoader([graph_data], batch_size=self.batch_size, shuffle=False) + else: + # For node/edge level tasks, use NeighborLoader + train_loader = NeighborLoader( + graph_data, + num_neighbors=num_neighbors, + batch_size=batch_size, + shuffle=True, + input_nodes=graph_data.train_mask, + ) + + test_loader = NeighborLoader( + graph_data, + num_neighbors=num_neighbors, + batch_size=batch_size, + shuffle=False, + input_nodes=graph_data.test_mask, + ) + + return train_loader, test_loader + + def defend(self, attack_name=None): + """ + Main defense workflow for GNNFingers with multi-task support + """ + #Validate task-dataset compatibility before starting + #self._validate_task_dataset_compatibility() + attack_name = attack_name or self.attack_name + AttackClass = self._get_attack_class(attack_name) + print(f"Using attack method: {attack_name}") + print(f"Task type: {self.task_type}") + + # Step 1: Train target model + print("Training target GNN...") + self.target_gnn = self._train_gnn_model(self.pyg_data, "Target GNN") + + # Step 2: Prepare positive and negative GNNs + print("Preparing positive (pirated) GNNs...") + self.positive_gnns = self._prepare_positive_gnns(self.target_gnn, num_models=50) + + print("Preparing negative (irrelevant) GNNs...") + self.negative_gnns = self._prepare_negative_gnns(num_models=50) + + # Step 3: Initialize graph fingerprints + print("Initializing graph fingerprints...") + self.graph_fingerprints = self._initialize_graph_fingerprints() + + # Step 4: Initialize Univerifier + output_dim = self._get_output_dimension(self.target_gnn, self.graph_fingerprints[0]) + self.univerifier = Univerifier(input_dim=output_dim * self.num_fingerprints) + 
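+        # Consistency note (see the sizing example in the Univerifier docstring):
+        # output_dim must be the flattened per-fingerprint size so that input_dim
+        # matches the concatenated rows built during joint learning and verification.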
self.univerifier = self.univerifier.to(self.device) + + # Step 5: Joint learning of fingerprints and Univerifier + print("Joint learning of fingerprints and Univerifier...") + self._joint_learning_alternating() + + # Step 6: Attack target model + print("Attacking target model...") + attack = AttackClass(self.dataset, attack_node_fraction=0.2) + attack_results = attack.attack() + suspect_model = attack.net2 if hasattr(attack, 'net2') else None + + # Step 7: Verify ownership + if suspect_model is not None: + print("Verifying ownership of suspect model...") + verification_result = self._verify_ownership(suspect_model) + print(f"Ownership verification result: {verification_result}") + + return { + "attack_results": attack_results, + "verification_result": verification_result, + "target_accuracy": self._evaluate_model(self.target_gnn, self.pyg_data), + "suspect_accuracy": self._evaluate_model(suspect_model, self.pyg_data) + } + + return {"attack_results": attack_results, "verification_result": "No suspect model found"} + + + def _validate_task_dataset_compatibility(self): + """ + Validate that the selected task type is compatible with the dataset + """ + print(f"Validating task type '{self.task_type}' with dataset {type(self.dataset).__name__}...") + + if self.task_type == 'graph_level': + # For graph-level tasks, we need to check if this is a graph dataset + if not hasattr(self.pyg_data, 'graph_y') or self.pyg_data.graph_y is None: + raise ValueError( + f"Graph-level task selected but dataset {type(self.dataset).__name__} " + f"appears to be a single-graph dataset. Use node-level or edge-level task instead, " + f"or use a multi-graph dataset like TUDataset for graph classification." + ) + print(" Graph-level task compatible with dataset") + + elif self.task_type == 'edge_level': + # For edge-level tasks, check if we have sufficient edge information + if not hasattr(self.pyg_data, 'edge_index') or self.pyg_data.edge_index.size(1) == 0: + print("⚠ Warning: Edge-level task selected but dataset has limited edge information") + else: + print(f"Edge-level task compatible with dataset ({self.pyg_data.edge_index.size(1)} edges)") + + elif self.task_type == 'node_level': + # For node-level tasks, ensure we have node labels + if not hasattr(self.pyg_data, 'y') or self.pyg_data.y is None: + raise ValueError( + f"Node-level task selected but dataset {type(self.dataset).__name__} " + f"does not have node labels." + ) + print(f"Node-level task compatible with dataset ({self.pyg_data.y.unique().size(0)} classes)") + + else: + raise ValueError(f"Unknown task type: {self.task_type}. 
+    def _train_gnn_model(self, data, model_name="GNN", epochs=100):
+        """Train a GNN model on the given data using batched training"""
+        # Different model architectures for different tasks
+        if self.task_type == 'graph_level':
+            model = GraphLevelGNN(
+                in_channels=data.x.size(1),
+                hidden_channels=128,
+                out_channels=self.label_number,
+                num_layers=3
+            ).to(self.device)
+        else:
+            model = GCNConvGNN(
+                in_channels=data.x.size(1),
+                hidden_channels=128,
+                out_channels=self.label_number,
+                num_layers=3
+            ).to(self.device)
+
+        optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
+        criterion = nn.CrossEntropyLoss()
+
+        # Create dataloaders using the helper function
+        train_loader, test_loader = self._create_dataloaders(data)
+
+        best_acc = 0
+        best_model = None
+
+        for epoch in range(epochs):
+            model.train()
+            total_loss = 0
+
+            for batch in train_loader:
+                batch = batch.to(self.device)
+                optimizer.zero_grad()
+
+                # Forward pass and loss with task-specific handling
+                if self.task_type == 'graph_level':
+                    out = model(batch.x, batch.edge_index, batch.batch)
+                    loss = criterion(out, batch.y)
+                else:
+                    out = model(batch.x, batch.edge_index)
+                    loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
+
+                # Backward pass
+                loss.backward()
+                optimizer.step()
+
+                total_loss += loss.item()
+
+            # Evaluate on the test set every 10 epochs and keep the best model
+            if epoch % 10 == 0:
+                test_acc = self._evaluate_model_with_loader(model, test_loader)
+
+                if test_acc > best_acc:
+                    best_acc = test_acc
+                    best_model = copy.deepcopy(model)
+
+                print(f"{model_name} Epoch {epoch}: Loss={total_loss/len(train_loader):.4f}, Acc={test_acc:.4f}")
+
+        print(f"{model_name} trained with best accuracy: {best_acc:.4f}")
+        return best_model if best_model is not None else model
+
+    def _evaluate_model_with_loader(self, model, test_loader):
+        """Evaluate model accuracy using a test dataloader"""
+        model.eval()
+        correct = 0
+        total = 0
+
+        with torch.no_grad():
+            for batch in test_loader:
+                batch = batch.to(self.device)
+
+                if self.task_type == 'graph_level':
+                    out = model(batch.x, batch.edge_index, batch.batch)
+                    pred = out.argmax(dim=1)
+                    correct += (pred == batch.y).sum().item()
+                    total += batch.y.size(0)
+                else:
+                    out = model(batch.x, batch.edge_index)
+                    pred = out.argmax(dim=1)
+
+                    # Only count test nodes; fall back to all nodes if no mask
+                    if hasattr(batch, 'test_mask') and batch.test_mask is not None:
+                        test_mask = batch.test_mask
+                    else:
+                        test_mask = torch.ones(batch.num_nodes, dtype=torch.bool, device=pred.device)
+                    correct += (pred[test_mask] == batch.y[test_mask]).sum().item()
+                    total += test_mask.sum().item()
+
+        return correct / total if total > 0 else 0
+
+    def _prepare_positive_gnns(self, target_model, num_models=50):
+        """Prepare pirated GNNs using obfuscation techniques"""
+        positive_models = []
+
+        for i in range(num_models):
+            # Apply a different obfuscation technique to every third model
+            if i % 3 == 0:
+                # Fine-tuning with batched training
+                layers_to_finetune = random.randint(1, 3)
+                model = self._fine_tune_model(copy.deepcopy(target_model), self.pyg_data,
+                                              epochs=10, num_layers_to_finetune=layers_to_finetune)
+            elif i % 3 == 1:
+                # Partial retraining with batched training
+                layers_to_retrain = random.randint(1, 3)
+                model = self._partial_retrain_model(copy.deepcopy(target_model), self.pyg_data,
+                                                    epochs=15, num_layers_to_retrain=layers_to_retrain)
+            else:
+                # Distillation with batched training
+                temperature = random.choice([1.5, 2.0, 3.0, 4.0])
+                model = self._distill_model(target_model, self.pyg_data,
+                                            epochs=30, temperature=temperature)
+
+            positive_models.append(model)
+
+        return positive_models
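+    # The three branches above are intended to mirror the paper's
+    # positive-suspect construction: last-K fine-tuning, re-initialising and
+    # retraining K random layers, and cross-architecture distillation; each
+    # family contributes roughly a third of the num_models pirated copies.
+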
+    def _prepare_negative_gnns(self, num_models=50):
+        """Prepare irrelevant GNNs"""
+        negative_models = []
+
+        for i in range(num_models):
+            # Train from scratch with different architectures or data
+            if i % 2 == 0:
+                # Different architecture
+                model = self._train_different_architecture(self.pyg_data)
+            else:
+                # Different training data (subset)
+                model = self._train_on_subset(self.pyg_data)
+
+            negative_models.append(model)
+
+        return negative_models
+
+    def _fine_tune_model(self, model, data, epochs=10, num_layers_to_finetune=1):
+        """Fine-tune a model using batched training"""
+        # Freeze all layers initially
+        for param in model.parameters():
+            param.requires_grad = False
+
+        # Unfreeze the last K layers for fine-tuning
+        if hasattr(model, 'convs'):
+            total_layers = len(model.convs)
+            layers_to_finetune = min(num_layers_to_finetune, total_layers)
+
+            for i in range(total_layers - layers_to_finetune, total_layers):
+                for param in model.convs[i].parameters():
+                    param.requires_grad = True
+
+        # Create dataloader using the helper function
+        train_loader, _ = self._create_dataloaders(data)
+
+        # Only optimize parameters that require gradients
+        trainable_params = filter(lambda p: p.requires_grad, model.parameters())
+        optimizer = optim.Adam(trainable_params, lr=0.001)
+        criterion = nn.CrossEntropyLoss()
+
+        for epoch in range(epochs):
+            model.train()
+            total_loss = 0
+
+            for batch in train_loader:
+                batch = batch.to(self.device)
+                optimizer.zero_grad()
+
+                # Task-specific forward pass
+                if self.task_type == 'graph_level':
+                    out = model(batch.x, batch.edge_index, batch.batch)
+                    loss = criterion(out, batch.y)
+                else:
+                    out = model(batch.x, batch.edge_index)
+                    loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
+
+                loss.backward()
+                optimizer.step()
+
+                total_loss += loss.item()
+
+        # Unfreeze all parameters for future use
+        for param in model.parameters():
+            param.requires_grad = True
+
+        return model
+
+    def _partial_retrain_model(self, model, data, epochs=10, num_layers_to_retrain=2):
+        """Partially retrain a model with random initialization of K layers before resuming training"""
+        # Randomly initialize selected K layers
+        if hasattr(model, 'convs'):
+            # For models with a convs attribute (like GCNConvGNN, GATConvGNN)
+            total_layers = len(model.convs)
+            layers_to_retrain = min(num_layers_to_retrain, total_layers)
+
+            # Randomly select K layers to retrain
+            layer_indices = random.sample(range(total_layers), layers_to_retrain)
+
+            print(f"Partially retraining layers: {layer_indices}")
+
+            for idx in layer_indices:
+                model.convs[idx].reset_parameters()  # Random reinitialization
+
+        # Train the entire model (both retrained and original layers)
+        train_loader, test_loader = self._create_dataloaders(data)
+        optimizer = optim.Adam(model.parameters(), lr=0.01)
+        criterion = nn.CrossEntropyLoss()
+
+        best_acc = 0
+        best_model = None
+
+        for epoch in range(epochs):
+            model.train()
+            total_loss = 0
+
+            for batch in train_loader:
+                batch = batch.to(self.device)
+                optimizer.zero_grad()
+
+                # Task-specific forward pass
+                if self.task_type == 'graph_level':
+                    out = model(batch.x, batch.edge_index, batch.batch)
+                    loss = criterion(out, batch.y)
+                else:
+                    out = model(batch.x, batch.edge_index)
+                    loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
+
+                loss.backward()
+                optimizer.step()
+
+                total_loss += loss.item()
+
+            # Track best model
+            if epoch % 5 == 0:
+                acc = self._evaluate_model_with_loader(model, test_loader)
+                if acc > best_acc:
+                    best_acc = acc
+                    best_model = copy.deepcopy(model)
+
+        print(f"Partial retraining completed. Best accuracy: {best_acc:.4f}")
+        return best_model if best_model is not None else model
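+    # Note: GCNConv and GATConv expose reset_parameters(), so the partial
+    # retraining above re-initialises the sampled layers in place and then
+    # trains the whole network, which tends to yield harder positive suspects
+    # than plain last-K fine-tuning.
+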
+    def _distill_model(self, teacher_model, data, epochs=30, temperature=2.0):
+        """Distill knowledge using batched training"""
+        # Create a student model with a different architecture
+        if isinstance(teacher_model, (GCNConvGNN, GraphLevelGNN)):
+            # If the teacher is GCN or GraphLevel, use GAT as the student
+            if self.task_type == 'graph_level':
+                student_model = GraphLevelGNN(
+                    in_channels=data.x.size(1),
+                    hidden_channels=96,
+                    out_channels=self.label_number,
+                    num_layers=2
+                ).to(self.device)
+            else:
+                student_model = GATConvGNN(
+                    in_channels=data.x.size(1),
+                    hidden_channels=96,
+                    out_channels=self.label_number,
+                    num_layers=2,
+                    heads=3
+                ).to(self.device)
+        else:
+            # If the teacher is GAT or other, use GCN as the student
+            if self.task_type == 'graph_level':
+                student_model = GraphLevelGNN(
+                    in_channels=data.x.size(1),
+                    hidden_channels=64,
+                    out_channels=self.label_number,
+                    num_layers=3
+                ).to(self.device)
+            else:
+                student_model = GCNConvGNN(
+                    in_channels=data.x.size(1),
+                    hidden_channels=64,
+                    out_channels=self.label_number,
+                    num_layers=3
+                ).to(self.device)
+
+        # Create dataloaders using the helper function
+        train_loader, test_loader = self._create_dataloaders(data)
+
+        optimizer = optim.Adam(student_model.parameters(), lr=0.01, weight_decay=1e-4)
+
+        # Combined loss: KL divergence for distillation + cross entropy for ground truth
+        kl_loss = nn.KLDivLoss(reduction='batchmean')
+        ce_loss = nn.CrossEntropyLoss()
+
+        teacher_model.eval()
+
+        best_acc = 0
+        best_student = None
+
+        for epoch in range(epochs):
+            student_model.train()
+            total_loss = 0
+
+            for batch in train_loader:
+                batch = batch.to(self.device)
+                optimizer.zero_grad()
+
+                # Get teacher predictions (with temperature scaling)
+                with torch.no_grad():
+                    if self.task_type == 'graph_level':
+                        teacher_logits = teacher_model(batch.x, batch.edge_index, batch.batch)
+                    else:
+                        teacher_logits = teacher_model(batch.x, batch.edge_index)
+                    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)
+
+                # Get student predictions
+                if self.task_type == 'graph_level':
+                    student_logits = student_model(batch.x, batch.edge_index, batch.batch)
+                else:
+                    student_logits = student_model(batch.x, batch.edge_index)
+                student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
+
+                # Distillation loss (KL divergence between teacher and student)
+                if self.task_type == 'graph_level':
+                    distill_loss = kl_loss(student_log_probs, teacher_probs) * (temperature ** 2)
+                    class_loss = ce_loss(student_logits, batch.y)
+                else:
+                    distill_loss = kl_loss(student_log_probs[batch.train_mask],
+                                           teacher_probs[batch.train_mask]) * (temperature ** 2)
+                    class_loss = ce_loss(student_logits[batch.train_mask],
+                                         batch.y[batch.train_mask])
+
+                # Combined loss (weighted sum)
+                loss = 0.7 * distill_loss + 0.3 * class_loss
+
+                loss.backward()
+                optimizer.step()
+
+                total_loss += loss.item()
+
+            # Track the best student model
+            if epoch % 5 == 0:
+                student_model.eval()
+                test_acc = self._evaluate_model_with_loader(student_model, test_loader)
+
+                if test_acc > best_acc:
+                    best_acc = test_acc
+                    best_student = copy.deepcopy(student_model)
+
+                print(f"Distillation Epoch {epoch}: Loss={total_loss/len(train_loader):.4f}, Acc={test_acc:.4f}")
+
+        print(f"Distillation completed. Best student accuracy: {best_acc:.4f}")
+        return best_student if best_student is not None else student_model
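+    # Distillation objective used above (Hinton-style soft targets):
+    #     L = 0.7 * T^2 * KL( softmax(z_teacher / T) || softmax(z_student / T) )
+    #       + 0.3 * CE(z_student, y)
+    # The T^2 factor keeps the soft-target gradients on the same scale as the
+    # hard-label term when the temperature T changes.
+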
+    def _train_different_architecture(self, data):
+        """Train a model with a different architecture for negative GNNs"""
+        # Use the opposite conv family where possible; for graph-level tasks
+        # only the depth differs, since GraphLevelGNN is the single graph backbone
+        if isinstance(self.target_gnn, (GCNConvGNN, GraphLevelGNN)):
+            if self.task_type == 'graph_level':
+                model = GraphLevelGNN(
+                    in_channels=data.x.size(1),
+                    hidden_channels=64,
+                    out_channels=self.label_number,
+                    num_layers=2
+                ).to(self.device)
+            else:
+                model = GATConvGNN(
+                    in_channels=data.x.size(1),
+                    hidden_channels=64,
+                    out_channels=self.label_number,
+                    num_layers=2,
+                    heads=4
+                ).to(self.device)
+        else:
+            if self.task_type == 'graph_level':
+                model = GraphLevelGNN(
+                    in_channels=data.x.size(1),
+                    hidden_channels=64,
+                    out_channels=self.label_number,
+                    num_layers=3
+                ).to(self.device)
+            else:
+                model = GCNConvGNN(
+                    in_channels=data.x.size(1),
+                    hidden_channels=64,
+                    out_channels=self.label_number,
+                    num_layers=3
+                ).to(self.device)
+
+        return self._train_gnn_model_with_data(model, data, epochs=50)
+
+    def _train_on_subset(self, data, subset_ratio=0.7):
+        """Train on a subset of the data"""
+        if self.task_type == 'graph_level':
+            # A single graph-level Data object cannot easily be subset,
+            # so fall back to a different architecture instead
+            return self._train_different_architecture(data)
+        else:
+            # For node/edge-level tasks, keep a random subset of training nodes
+            num_train = int(data.train_mask.sum().item() * subset_ratio)
+            subset_mask = torch.zeros_like(data.train_mask)
+            train_indices = data.train_mask.nonzero(as_tuple=True)[0]
+            selected_indices = random.sample(range(len(train_indices)), min(num_train, len(train_indices)))
+            subset_mask[train_indices[selected_indices]] = True
+
+            # Create subset data
+            subset_data = PyGData(
+                x=data.x,
+                edge_index=data.edge_index,
+                y=data.y,
+                train_mask=subset_mask,
+                test_mask=data.test_mask
+            )
+
+            # Keep the target's conv family: these negatives differ from the
+            # target by their training data, not by their architecture
+            if isinstance(self.target_gnn, (GCNConvGNN, GraphLevelGNN)):
+                model = GCNConvGNN(
+                    in_channels=data.x.size(1),
+                    hidden_channels=64,
+                    out_channels=self.label_number,
+                    num_layers=2
+                ).to(self.device)
+            else:
+                model = GATConvGNN(
+                    in_channels=data.x.size(1),
+                    hidden_channels=64,
+                    out_channels=self.label_number,
+                    num_layers=3,
+                    heads=4
+                ).to(self.device)
+
+            return self._train_gnn_model_with_data(model, subset_data, epochs=50)
+
+    def _train_gnn_model_with_data(self, model, data, epochs=100):
+        """Train a specific model on specific data using batched training"""
+        # Create dataloaders using the helper function
+        train_loader, test_loader = self._create_dataloaders(data)
+
+        optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
+        criterion = nn.CrossEntropyLoss()
+
+        best_acc = 0
+        best_model = None
+
+        for epoch in range(epochs):
+            model.train()
+            total_loss = 0
+
+            for batch in train_loader:
+                batch = batch.to(self.device)
+                optimizer.zero_grad()
+
+                # Task-specific forward pass
+                if self.task_type == 'graph_level':
+                    out = model(batch.x, batch.edge_index, batch.batch)
+                    loss = criterion(out, batch.y)
+                else:
+                    out = model(batch.x, batch.edge_index)
+                    loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
+
+                loss.backward()
+                optimizer.step()
+
+                total_loss += loss.item()
+
+            if epoch % 10 == 0:
+                acc = self._evaluate_model_with_loader(model, test_loader)
+                if acc > best_acc:
+                    best_acc = acc
+                    best_model = copy.deepcopy(model)
+
+        return best_model if best_model is not None else model
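+    # Negative suspects therefore differ from the target either by conv family
+    # (_train_different_architecture) or by training distribution (a ~70%
+    # subset of the training nodes in _train_on_subset); these are the models
+    # the Univerifier must learn to tell apart from genuinely pirated copies.
+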
+    # Task-Specific Fingerprint Initialization
+    def _initialize_graph_fingerprints(self):
+        """Initialize task-specific graph fingerprints"""
+        fingerprints = nn.ModuleList()
+        feature_dim = self.features.size(1) if self.features is not None else 16
+
+        for _ in range(self.num_fingerprints):
+            if self.task_type == 'node_level':
+                fingerprint = NodeLevelFingerprint(
+                    self.fingerprint_nodes, feature_dim
+                ).to(self.device)
+
+            elif self.task_type == 'edge_level':
+                fingerprint = EdgeLevelFingerprint(
+                    self.fingerprint_nodes, feature_dim
+                ).to(self.device)
+
+            elif self.task_type == 'graph_level':
+                # Graph-level fingerprints bundle several small graphs
+                fingerprint = GraphLevelFingerprint(
+                    self.fingerprint_nodes, feature_dim, num_graphs=3
+                ).to(self.device)
+
+            else:
+                raise ValueError(f"Unknown task type: {self.task_type}")
+
+            fingerprints.append(fingerprint)
+
+        return fingerprints
+
+    # Task-Specific Output Handling
+    def _get_model_outputs(self, model, fingerprint):
+        """Get model outputs on a fingerprint based on task type"""
+        model.eval()
+
+        if self.task_type == 'node_level':
+            fingerprint_data = fingerprint.to_pyg_data()
+            output = model(fingerprint_data.x.to(self.device),
+                           fingerprint_data.edge_index.to(self.device))
+            return fingerprint.get_sampled_outputs(output)
+
+        elif self.task_type == 'edge_level':
+            fingerprint_data = fingerprint.to_pyg_data()
+            # For link prediction, derive edge probabilities from the
+            # dot products of the node embeddings
+            node_embeddings = model(fingerprint_data.x.to(self.device),
+                                    fingerprint_data.edge_index.to(self.device))
+            adj_probs = torch.sigmoid(torch.mm(node_embeddings, node_embeddings.t()))
+            return fingerprint.get_sampled_outputs(adj_probs)
+
+        elif self.task_type == 'graph_level':
+            # For graph-level tasks, process each graph in the fingerprint
+            graph_outputs = []
+            for graph_data in fingerprint.to_pyg_data():
+                # Single-graph batch vector (all nodes belong to graph 0)
+                batch = torch.zeros(graph_data.num_nodes, dtype=torch.long, device=self.device)
+                output = model(graph_data.x.to(self.device),
+                               graph_data.edge_index.to(self.device),
+                               batch)
+                graph_outputs.append(output)
+            return fingerprint.get_sampled_outputs(graph_outputs)
+
+        raise ValueError(f"Unknown task type: {self.task_type}")
+
+    def _get_output_dimension(self, model, fingerprint):
+        """Get the output dimension for a given fingerprint"""
+        model.eval()
+        with torch.no_grad():
+            output = self._get_model_outputs(model, fingerprint)
+        return output.numel()  # Total number of elements
+
+    def _verify_ownership(self, suspect_model):
+        """Verify if a suspect model is pirated from the target model"""
+        # Only the suspect's fingerprint responses are scored here: the
+        # target's behaviour is already baked into the trained Univerifier
+        suspect_model.eval()
+        suspect_outputs = []
+
+        with torch.no_grad():
+            for fingerprint in self.graph_fingerprints:
+                suspect_outputs.append(self._get_model_outputs(suspect_model, fingerprint))
+
+        # Concatenate the responses on all fingerprints into one row vector
+        suspect_concat = torch.cat(suspect_outputs, dim=0).view(1, -1)
+
+        # Get Univerifier prediction
+        self.univerifier.eval()
+        with torch.no_grad():
+            prediction = self.univerifier(suspect_concat)
+            confidence = prediction[0, 1].item()  # Probability of being pirated
+
+        return confidence > self.lambda_threshold, confidence
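+    # Decision rule: the suspect's responses on all fingerprints are flattened
+    # into a single vector, the Univerifier's softmax output o_+ is read as the
+    # probability that the suspect is pirated, and ownership is claimed when
+    # o_+ > lambda_threshold (0.7 by default).
+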
+    def _evaluate_model(self, model, data):
+        """Evaluate model accuracy on the full graph"""
+        model.eval()
+        with torch.no_grad():
+            # Keep labels and masks on the same device as the predictions
+            y = data.y.to(self.device)
+            if self.task_type == 'graph_level':
+                batch = torch.zeros(data.num_nodes, dtype=torch.long, device=self.device)
+                out = model(data.x.to(self.device), data.edge_index.to(self.device), batch)
+                pred = out.argmax(dim=1)
+                correct = (pred == y).sum().item()
+                total = y.size(0)
+            else:
+                out = model(data.x.to(self.device), data.edge_index.to(self.device))
+                pred = out.argmax(dim=1)
+                test_mask = data.test_mask.to(self.device)
+                correct = (pred[test_mask] == y[test_mask]).sum().item()
+                total = test_mask.sum().item()
+        return correct / total if total > 0 else 0
+
+    def _update_adjacency_discrete(self, fingerprint, grad_adj):
+        """
+        Update the discrete adjacency matrix/matrices of a fingerprint
+        based on loss gradients
+        """
+        if self.task_type == 'graph_level':
+            # Graph-level fingerprints hold several graphs; update each one
+            for graph_idx, graph in enumerate(fingerprint.graphs):
+                self._update_single_adjacency(graph, grad_adj[graph_idx] if grad_adj.dim() > 2 else grad_adj)
+        else:
+            # Node/edge-level fingerprints hold a single graph
+            self._update_single_adjacency(fingerprint, grad_adj)
+
+    def _update_single_adjacency(self, fingerprint, grad_adj):
+        """Update the binary adjacency of a single graph from loss gradients"""
+        current_adj = (fingerprint.adj_matrix > 0.5).float()
+
+        # Get absolute gradient values and flatten
+        grad_abs_flat = torch.abs(grad_adj).view(-1)
+
+        # Determine the top-K entries to consider for flipping
+        k = int(self.top_k_ratio * self.fingerprint_nodes * self.fingerprint_nodes)
+        _, topk_indices = torch.topk(grad_abs_flat, k)
+
+        # Convert flat indices to (row, col) indices
+        rows = torch.div(topk_indices, self.fingerprint_nodes, rounding_mode='floor')
+        cols = topk_indices % self.fingerprint_nodes
+
+        # Flip selected entries against their gradient sign so each flip
+        # decreases the verification loss
+        with torch.no_grad():
+            for idx in range(k):
+                row, col = rows[idx], cols[idx]
+                grad_val = grad_adj[row, col]
+
+                # Current edge existence (0 or 1)
+                current_edge = current_adj[row, col]
+
+                if current_edge > 0.5 and grad_val > 0:
+                    # Edge exists and raising it would raise the loss -> remove it
+                    fingerprint.adj_matrix.data[row, col] = 0.0
+                elif current_edge < 0.5 and grad_val < 0:
+                    # Edge absent and adding it would lower the loss -> add it
+                    fingerprint.adj_matrix.data[row, col] = 1.0
+
+    def _update_fingerprints_discrete(self, loss):
+        """
+        Update graph fingerprints from the gradients of the verification loss
+        """
+        # Compute gradients for all fingerprints
+        gradients_adj = []
+        gradients_x = []
+
+        for fingerprint in self.graph_fingerprints:
+            if self.task_type == 'graph_level':
+                # Graph-level fingerprints hold several graphs
+                grad_adj_list = []
+                grad_x_list = []
+                for graph in fingerprint.graphs:
+                    grad_adj = torch.autograd.grad(
+                        loss, graph.adj_matrix,
+                        retain_graph=True, create_graph=False
+                    )[0]
+                    grad_x = torch.autograd.grad(
+                        loss, graph.x,
+                        retain_graph=True, create_graph=False
+                    )[0]
+                    grad_adj_list.append(grad_adj)
+                    grad_x_list.append(grad_x)
+                gradients_adj.append(torch.stack(grad_adj_list))
+                gradients_x.append(torch.stack(grad_x_list))
+            else:
+                # Node/edge-level fingerprints hold a single graph
+                grad_adj = torch.autograd.grad(
+                    loss, fingerprint.adj_matrix,
+                    retain_graph=True, create_graph=False
+                )[0]
+                grad_x = torch.autograd.grad(
+                    loss, fingerprint.x,
+                    retain_graph=True, create_graph=False
+                )[0]
+                gradients_adj.append(grad_adj)
+                gradients_x.append(grad_x)
+
+        # Update each fingerprint
+        for i, fingerprint in enumerate(self.graph_fingerprints):
+            grad_adj = gradients_adj[i]
+            grad_x = gradients_x[i]
+
+            # Feature range used for clipping
+            if self.features is not None:
+                min_val = self.features.min().item()
+                max_val = self.features.max().item()
+            else:
+                min_val, max_val = -3, 3
+
+            if self.task_type == 'graph_level':
+                # Update each graph in the fingerprint
+                for graph_idx, graph in enumerate(fingerprint.graphs):
+                    with torch.no_grad():
+                        # Gradient descent step on the verification loss,
+                        # clipped back into the observed feature range
+                        graph.x.data -= self.fingerprint_lr * grad_x[graph_idx]
+                        graph.x.data = torch.clamp(graph.x.data, min_val, max_val)
+
+                # Update adjacency matrices using the discrete strategy
+                self._update_adjacency_discrete(fingerprint, grad_adj)
+            else:
+                with torch.no_grad():
+                    # Gradient descent step on the verification loss,
+                    # clipped back into the observed feature range
+                    fingerprint.x.data -= self.fingerprint_lr * grad_x
+                    fingerprint.x.data = torch.clamp(fingerprint.x.data, min_val, max_val)
+
+                # Update adjacency matrix using the discrete strategy
+                self._update_adjacency_discrete(fingerprint, grad_adj)
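+    # Fingerprint update summary: node features take a continuous gradient
+    # step and are clipped back into the observed feature range, while the
+    # adjacency stays binary and only the top_k_ratio fraction of entries with
+    # the largest |gradient| is considered for flipping at each step.
+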
+    def visualize_fingerprint_evolution(self, epoch):
+        """Print summary statistics showing how fingerprints evolve during training"""
+        # Called every 20 epochs from the joint-learning loop
+        print(f"\n=== Fingerprint Evolution at Epoch {epoch} ===")
+        max_edges = self.fingerprint_nodes * self.fingerprint_nodes
+
+        for i, fingerprint in enumerate(self.graph_fingerprints[:2]):  # First 2 only
+            if self.task_type == 'graph_level':
+                print(f"Graph-Level Fingerprint {i}: {len(fingerprint.graphs)} graphs")
+                for graph_idx, graph in enumerate(fingerprint.graphs):
+                    num_edges = (graph.adj_matrix > 0.5).float().sum().item()
+                    sparsity = 1 - (num_edges / max_edges)
+                    print(f"  Graph {graph_idx}: {num_edges} edges, sparsity: {sparsity:.3f}")
+            else:
+                level = 'Node' if self.task_type == 'node_level' else 'Edge'
+                num_edges = (fingerprint.adj_matrix > 0.5).float().sum().item()
+                sparsity = 1 - (num_edges / max_edges)
+                print(f"{level}-Level Fingerprint {i}: {num_edges} edges, sparsity: {sparsity:.3f}")
+
+    def _joint_learning_alternating(self):
+        """
+        Joint learning of fingerprints and Univerifier with an alternating
+        optimization schedule (e1 fingerprint steps, then e2 verifier steps)
+        """
+        # Prepare all models and labels (1 = pirated, 0 = irrelevant)
+        all_models = [self.target_gnn] + self.positive_gnns + self.negative_gnns
+        num_positive = len(self.positive_gnns) + 1  # the target counts as positive
+        labels = torch.cat([
+            torch.ones(num_positive),
+            torch.zeros(len(self.negative_gnns))
+        ]).long().to(self.device)
+
+        # Create the Univerifier optimizer once so its Adam state persists
+        # across phases instead of being reset every other epoch
+        univerifier_optimizer = optim.Adam(self.univerifier.parameters(), lr=self.univerifier_lr)
+
+        def compute_outputs():
+            """One row per model: its concatenated responses on all fingerprints"""
+            rows = []
+            for model in all_models:
+                model.eval()
+                outs = [self._get_model_outputs(model, fp) for fp in self.graph_fingerprints]
+                rows.append(torch.cat(outs, dim=0).view(1, -1))
+            return torch.cat(rows, dim=0)
+
+        def compute_loss(univerifier_out):
+            """Negative log-likelihood: log o_+ for pirated rows, log o_- for the rest"""
+            pos_term = torch.log(univerifier_out[:num_positive, 1] + 1e-10).sum()
+            neg_term = torch.log(1 - univerifier_out[num_positive:, 1] + 1e-10).sum()
+            return -(pos_term + neg_term)
+
+        # Flag to alternate between fingerprint and Univerifier updates
+        update_fingerprints = True
+
+        for epoch in range(self.epochs):
+            if update_fingerprints:
+                # Phase 1: update fingerprints for e1 steps; the responses are
+                # recomputed each step so gradients reflect the latest graphs
+                for _ in range(self.fingerprint_update_epochs):
+                    loss = compute_loss(self.univerifier(compute_outputs()))
+                    self._update_fingerprints_discrete(loss)
+
+                update_fingerprints = False
+                print(f"Epoch {epoch}: Updated fingerprints, Loss: {loss.item():.4f}")
+            else:
+                # Phase 2: update the Univerifier for e2 steps on fixed
+                # fingerprints (responses are detached from the models)
+                for _ in range(self.univerifier_update_epochs):
+                    loss = compute_loss(self.univerifier(compute_outputs().detach()))
+                    univerifier_optimizer.zero_grad()
+                    loss.backward()
+                    univerifier_optimizer.step()
+
+                update_fingerprints = True
+                print(f"Epoch {epoch}: Updated Univerifier, Loss: {loss.item():.4f}")
+
+            # Report verification metrics every 10 epochs
+            if epoch % 10 == 0:
+                with torch.no_grad():
+                    univerifier_out = self.univerifier(compute_outputs())
+                    preds = univerifier_out.argmax(dim=1)
+                    acc = (preds == labels).float().mean().item()
+
+                    # True-positive / true-negative rates over their own classes
+                    pos_mask = labels == 1
+                    neg_mask = labels == 0
+                    tp_rate = (preds[pos_mask] == 1).float().mean().item() if pos_mask.any() else 0
+                    tn_rate = (preds[neg_mask] == 0).float().mean().item() if neg_mask.any() else 0
+
+                    print(f"Epoch {epoch}, Acc: {acc:.4f}, TP: {tp_rate:.4f}, TN: {tn_rate:.4f}")
+
+            # Visualize fingerprint evolution
+            if epoch % 20 == 0:
+                self.visualize_fingerprint_evolution(epoch)
+
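+# Alternating optimisation used by _joint_learning_alternating (sketch):
+#   repeat for `epochs` outer rounds, alternating between e1 fingerprint steps
+#   and e2 Univerifier steps, both minimising the same objective
+#       L = -[ sum_{pirated f} log o_+(f) + sum_{irrelevant f} log(1 - o_+(f)) ]
+#   where o_+(f) is the Univerifier's pirated-probability for a model's
+#   concatenated fingerprint responses.
+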
+
+# Graph-Level GNN Model
+# global_mean_pool / global_add_pool are needed by GraphLevelGNN but are not
+# part of the torch_geometric.nn imports at the top of this file
+from torch_geometric.nn import global_mean_pool, global_add_pool
+
+class GraphLevelGNN(nn.Module):
+    """GNN model for graph-level tasks with global pooling"""
+    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3, pool_type='mean'):
+        super(GraphLevelGNN, self).__init__()
+        self.convs = nn.ModuleList()
+        self.pool_type = pool_type
+
+        # Input layer
+        self.convs.append(GCNConv(in_channels, hidden_channels))
+
+        # Hidden layers
+        for _ in range(num_layers - 2):
+            self.convs.append(GCNConv(hidden_channels, hidden_channels))
+
+        # Final graph conv stays in hidden space; classification happens after pooling
+        self.convs.append(GCNConv(hidden_channels, hidden_channels))
+
+        # Final classification layer
+        self.classifier = nn.Linear(hidden_channels, out_channels)
+
+    def forward(self, x, edge_index, batch):
+        for i, conv in enumerate(self.convs):
+            x = conv(x, edge_index)
+            if i < len(self.convs) - 1:
+                x = F.relu(x)
+                x = F.dropout(x, training=self.training, p=0.5)
+
+        # Global pooling over each graph in the batch
+        if self.pool_type == 'mean':
+            x = global_mean_pool(x, batch)
+        else:  # sum pooling
+            x = global_add_pool(x, batch)
+
+        # Final classification
+        return self.classifier(x)
+
+class GCNConvGNN(nn.Module):
+    """GCN-based GNN model"""
+    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3):
+        super(GCNConvGNN, self).__init__()
+        self.convs = nn.ModuleList()
+
+        # Input layer
+        self.convs.append(GCNConv(in_channels, hidden_channels))
+
+        # Hidden layers
+        for _ in range(num_layers - 2):
+            self.convs.append(GCNConv(hidden_channels, hidden_channels))
+
+        # Output layer
+        self.convs.append(GCNConv(hidden_channels, out_channels))
+
+    def forward(self, x, edge_index):
+        for i, conv in enumerate(self.convs):
+            x = conv(x, edge_index)
+            if i < len(self.convs) - 1:
+                x = F.relu(x)
+                x = F.dropout(x, training=self.training, p=0.5)
+        return x
+
+class GATConvGNN(nn.Module):
+    """GAT-based GNN model"""
+    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2, heads=4):
+        super(GATConvGNN, self).__init__()
+        self.convs = nn.ModuleList()
+
+        # Input layer
+        self.convs.append(GATConv(in_channels, hidden_channels, heads=heads))
+
+        # Hidden layers (inputs are hidden_channels * heads wide after concat)
+        for _ in range(num_layers - 2):
+            self.convs.append(GATConv(hidden_channels * heads, hidden_channels, heads=heads))
+
+        # Output layer (single head, no concatenation)
+        self.convs.append(GATConv(hidden_channels * heads, out_channels, heads=1))
+
+    def forward(self, x, edge_index):
+        for i, conv in enumerate(self.convs):
+            x = conv(x, edge_index)
+            if i < len(self.convs) - 1:
+                x = F.elu(x)
+                x = F.dropout(x, training=self.training, p=0.6)
+        return x
+
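+
+# Minimal smoke test for the backbone models above (illustrative sketch; the
+# helper name is ours and nothing in the pipeline calls it). It relies only on
+# erdos_renyi_graph, which is already imported at the top of this file.
+def _smoke_test_backbones():
+    x = torch.randn(10, 5)
+    edge_index = erdos_renyi_graph(10, 0.3)
+    # Both backbones map 5-dim node features to 3 output classes
+    assert GCNConvGNN(5, 8, 3)(x, edge_index).shape == (10, 3)
+    assert GATConvGNN(5, 8, 3, heads=2)(x, edge_index).shape == (10, 3)
+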
+
+if __name__ == "__main__":
+    # Cora is assumed to live in this repository's datasets package (next to
+    # the `dataset` module imported at the top of the file); adjust if needed
+    from datasets.datasets import Cora
+
+    # Set device
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    print(f"Using device: {device}")
+
+    # Load dataset
+    dataset = Cora(api_type="dgl", path="./data")
+    print(f"Loaded dataset: {dataset}")
+
+    # Initialize defense with multi-task support
+    defense = GNNFingers(
+        dataset=dataset,
+        device=device,
+        num_fingerprints=32,
+        fingerprint_nodes=64,
+        epochs=100,
+        task_type='node_level'  # Change to 'edge_level' or 'graph_level' for other tasks
+    )
+
+    results = defense.defend()
+
+    # Print results
+    print("\n=== Defense Results ===")
+    print(f"Target Accuracy: {results.get('target_accuracy', 0):.4f}")
+    print(f"Suspect Accuracy: {results.get('suspect_accuracy', 0):.4f}")
+    print(f"Verification Result: {results.get('verification_result', 'Unknown')}")
\ No newline at end of file