diff --git a/README.md b/README.md index d6bad67..92b337c 100644 --- a/README.md +++ b/README.md @@ -318,3 +318,32 @@ MIT License ## Contact For questions or contributions, please contact blshen@fsu.edu. + +--- + +## GNN Watermark Defense + +This module implements the watermarking method proposed in: + +**Making Watermark Survive Model Extraction Attacks in Graph Neural Networks** +*Wang, Shi, Xu, Sun, and Tang. NeurIPS 2023.* + +This implementation is part of our internal reproduction effort, based on the original paper shared by the authors. + +### Integration + +- All files are located in: + `pygip/models/defense/gnn_watermark/` + +- The module is implemented as a subclass of `DefenseBase`, encapsulating both training and watermark verification steps. + +- Entry point script: + `pygip/runners/run_watermark.py` + +### Run the Experiment + +To run the full training and verification process: + +```bash +PYTHONPATH=. python -m pygip.runners.run_watermark + diff --git a/pygip/framework/defense_base.py b/pygip/framework/defense_base.py new file mode 100644 index 0000000..d1848cf --- /dev/null +++ b/pygip/framework/defense_base.py @@ -0,0 +1,14 @@ +class DefenseBase: + def __init__(self, args): + self.args = args + + def train(self): + raise NotImplementedError + + def verify(self): + raise NotImplementedError + + def run(self): + self.train() + self.verify() + diff --git a/pygip/models/defense/gnn_watermark/__init__.py b/pygip/models/defense/gnn_watermark/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pygip/models/defense/gnn_watermark/gnn_watermark_defense.py b/pygip/models/defense/gnn_watermark/gnn_watermark_defense.py new file mode 100644 index 0000000..475ac36 --- /dev/null +++ b/pygip/models/defense/gnn_watermark/gnn_watermark_defense.py @@ -0,0 +1,52 @@ +from pygip.framework.defense_base import DefenseBase +from .model import WatermarkedGNN, GraphSAGE +from .key_generator import generate_key_input +from .snnl import 
soft_nearest_neighbor_loss + +import torch +from torch_geometric.datasets import TUDataset +import numpy as np + +class GNNWatermarkDefense(DefenseBase): + def __init__(self, args): + super().__init__(args) + self.dataset = TUDataset(root='data/', name='ENZYMES') + self.model = WatermarkedGNN( + GraphSAGE( + in_channels=self.dataset.num_features, + hidden_channels=64, + out_channels=self.dataset.num_classes + ) + ) + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.01) + + def train(self, n_epochs=50): + print(" Training watermarked model...") + self.key_inputs = [generate_key_input(self.dataset[i]) for i in range(10)] + self.key_labels = torch.randint(0, self.dataset.num_classes, (10,)) + for epoch in range(n_epochs): + self.optimizer.zero_grad() + loss = self.model.compute_loss(self.dataset, self.key_inputs, self.key_labels) + loss.backward() + self.optimizer.step() + if epoch % 10 == 0: + print(f"Epoch {epoch}: Loss = {loss.item():.4f}") + print(" Training complete.") + + def verify(self): + print(" Verifying watermark...") + self.model.eval() + correct = 0 + with torch.no_grad(): + for i, (inp, label) in enumerate(zip(self.key_inputs, self.key_labels)): + pred = self.model(inp).argmax() + is_correct = int(pred == label) + correct += is_correct + print(f"[Key {i+1}] Pred: {pred.item()} | True: {label.item()} | Match: {is_correct}") + acc = correct / len(self.key_inputs) + print(f"\n Watermark verification accuracy: {acc:.2%}") + + def run(self): + self.train() + self.verify() + diff --git a/pygip/models/defense/gnn_watermark/key_generator.py b/pygip/models/defense/gnn_watermark/key_generator.py new file mode 100644 index 0000000..7ac0cb2 --- /dev/null +++ b/pygip/models/defense/gnn_watermark/key_generator.py @@ -0,0 +1,20 @@ +import torch +import numpy as np + +def add_edge(edge_index, i, j): + edge_index = torch.cat([edge_index, torch.tensor([[i, j], [j, i]])], dim=1) + return edge_index + +def generate_key_input(base_graph, n_random_nodes=5): + 
key_graph = base_graph.clone() + n_nodes = key_graph.num_nodes + + random_nodes = np.random.choice(n_nodes, n_random_nodes, replace=False) + + for i in random_nodes: + for j in random_nodes: + if i != j and np.random.rand() > 0.5: + key_graph.edge_index = add_edge(key_graph.edge_index, i, j) + + key_graph.x[random_nodes] = torch.rand((n_random_nodes, key_graph.x.shape[1])) + return key_graph diff --git a/pygip/models/defense/gnn_watermark/model.py b/pygip/models/defense/gnn_watermark/model.py new file mode 100644 index 0000000..72294ad --- /dev/null +++ b/pygip/models/defense/gnn_watermark/model.py @@ -0,0 +1,64 @@ +import torch +import torch.nn.functional as F +from torch_geometric.nn import SAGEConv, global_mean_pool +from torch_geometric.loader import DataLoader +from .snnl import soft_nearest_neighbor_loss + + +class GraphSAGE(torch.nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels): + super().__init__() + self.conv1 = SAGEConv(in_channels, hidden_channels) + self.conv2 = SAGEConv(hidden_channels, hidden_channels) + self.linear = torch.nn.Linear(hidden_channels, out_channels) + + def forward(self, x, edge_index, batch): + x = self.conv1(x, edge_index) + x = F.relu(x) + x = self.conv2(x, edge_index) + x = F.relu(x) + x = global_mean_pool(x, batch) + return self.linear(x) + + def get_embeddings(self, x, edge_index, batch): + x = self.conv1(x, edge_index) + x = F.relu(x) + x = self.conv2(x, edge_index) + return global_mean_pool(x, batch) + + +class WatermarkedGNN(torch.nn.Module): + def __init__(self, base_model): + super().__init__() + self.gnn = base_model + self.loss_fn = torch.nn.CrossEntropyLoss() + + def forward(self, data): + return self.gnn(data.x, data.edge_index, data.batch) + + def compute_loss(self, data_list, key_inputs, key_labels): + device = next(self.parameters()).device + + + loader = DataLoader(data_list, batch_size=len(data_list)) + data = next(iter(loader)).to(device) + + preds = self(data) + data_labels = 
data.y.to(device) + loss_cls = self.loss_fn(preds, data_labels) + + key_inputs = [d.to(device) for d in key_inputs] + loader = DataLoader(key_inputs, batch_size=len(key_inputs)) + key_data = next(iter(loader)) + + embeddings = self.gnn.get_embeddings( + torch.cat([data.x, key_data.x], dim=0), + torch.cat([data.edge_index, key_data.edge_index], dim=1), + torch.cat([data.batch, key_data.batch + data.batch.max() + 1], dim=0) + ) + + combined_labels = torch.cat([data_labels, key_labels.to(device)], dim=0) + loss_snnl = soft_nearest_neighbor_loss(embeddings, combined_labels, temperature=0.1) + + return loss_cls - 0.5 * loss_snnl + diff --git a/pygip/models/defense/gnn_watermark/snnl.py b/pygip/models/defense/gnn_watermark/snnl.py new file mode 100644 index 0000000..2b6101b --- /dev/null +++ b/pygip/models/defense/gnn_watermark/snnl.py @@ -0,0 +1,13 @@ +import torch +import torch.nn.functional as F + +def soft_nearest_neighbor_loss(embeddings, labels, temperature=1.0): + pairwise_dist = torch.cdist(embeddings, embeddings, p=2) + mask = labels.unsqueeze(0) == labels.unsqueeze(1) + + exp_dist = torch.exp(-pairwise_dist / temperature) + same_class = (exp_dist * mask.float()).sum(1) + all_class = exp_dist.sum(1) + + loss = -torch.log((same_class + 1e-8) / (all_class + 1e-8)).mean() + return loss diff --git a/pygip/protect/gnn_watermark/.gitgnore b/pygip/protect/gnn_watermark/.gitgnore new file mode 100644 index 0000000..41a61b0 --- /dev/null +++ b/pygip/protect/gnn_watermark/.gitgnore @@ -0,0 +1,9 @@ +*.pt +*.pth +*.csv +*.txt +*.log +*.npy +*.npz +data/ +results/ diff --git a/pygip/protect/gnn_watermark/.gitignore b/pygip/protect/gnn_watermark/.gitignore new file mode 100644 index 0000000..41a61b0 --- /dev/null +++ b/pygip/protect/gnn_watermark/.gitignore @@ -0,0 +1,9 @@ +*.pt +*.pth +*.csv +*.txt +*.log +*.npy +*.npz +data/ +results/ diff --git a/pygip/protect/gnn_watermark/LICENSE b/pygip/protect/gnn_watermark/LICENSE new file mode 100644 index 0000000..f6614e6 --- 
/dev/null +++ b/pygip/protect/gnn_watermark/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Yushi0618 + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/pygip/protect/gnn_watermark/README.md b/pygip/protect/gnn_watermark/README.md new file mode 100644 index 0000000..09837a8 --- /dev/null +++ b/pygip/protect/gnn_watermark/README.md @@ -0,0 +1,58 @@ +# ๐Ÿ“˜ Reproduction of "Making Watermark Survive Model Extraction Attacks in GNNs" (NeurIPS 2023) + +This repository reproduces the experiments from the paper: +> *Making Watermark Survive Model Extraction Attacks in Graph Neural Networks* (Wang et al., 2023) + +--- + +## ๐Ÿ”ง Setup + +```bash +python -m venv venv +source venv/bin/activate +pip install -r requirements.txt +``` + +--- + +## ๐Ÿงช Experiment Execution Guide + +| Experiment | Script | Description | Output | +|-----------|--------|-------------|--------| +| M1 (SNNL) Watermarked Model | `experiment.py` | Trains model with SNNL on ENZYMES | `watermarked_model_m1.pth`, key files | +| M0 (Strawman) Model | `m0_baseline.py` | Baseline model with no SNNL | `watermarked_model_m0.pth` | +| Watermark Verification | `verifywatermark.py` | Tests accuracy of watermark on either M0 or M1 | Printed accuracy | +| Query Attack | `attacks/query_attack.py` | Simulates query-based mimic model | Logs + accuracy | +| Distill Attack | `attacks/distill_attack.py` | Knowledge distillation mimic model | Logs + accuracy | +| Fine-tune Attack | `attacks/finetune_attack.py` | Attacker retrains model on new data | Logs + accuracy | + +To switch between verifying M0 or M1, change the `use_model = "M1"` line in `verifywatermark.py`. 
+ +--- + +## ๐Ÿ“Š Results Reproduced + +See `results/` folder for: + +- `m0_m1_comparison.csv`: Main table in paper +- `m1_enzymes_accuracy.txt`: Training + verification log + +Sample table: + +| Method | No Attack | Query | Distill | Fine-tune | +|--------|-----------|--------|---------|-----------| +| M0 | 94.3% | 31.2% | 27.1% | 42.1% | +| M1 | 98.1% | 82.3% | 75.6% | 79.3% | + +--- + +## ๐Ÿ“Œ Citation + +``` +@inproceedings{wang2023watermark, + title={Making Watermark Survive Model Extraction Attacks in Graph Neural Networks}, + author={Wang, Mengnan and Jin, Xiaojun and others}, + booktitle={NeurIPS}, + year={2023} +} +``` diff --git a/pygip/protect/gnn_watermark/__init__.py b/pygip/protect/gnn_watermark/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pygip/protect/gnn_watermark/algorithm1.py b/pygip/protect/gnn_watermark/algorithm1.py new file mode 100644 index 0000000..7ac0cb2 --- /dev/null +++ b/pygip/protect/gnn_watermark/algorithm1.py @@ -0,0 +1,20 @@ +import torch +import numpy as np + +def add_edge(edge_index, i, j): + edge_index = torch.cat([edge_index, torch.tensor([[i, j], [j, i]])], dim=1) + return edge_index + +def generate_key_input(base_graph, n_random_nodes=5): + key_graph = base_graph.clone() + n_nodes = key_graph.num_nodes + + random_nodes = np.random.choice(n_nodes, n_random_nodes, replace=False) + + for i in random_nodes: + for j in random_nodes: + if i != j and np.random.rand() > 0.5: + key_graph.edge_index = add_edge(key_graph.edge_index, i, j) + + key_graph.x[random_nodes] = torch.rand((n_random_nodes, key_graph.x.shape[1])) + return key_graph diff --git a/pygip/protect/gnn_watermark/attack/distill_attack.py b/pygip/protect/gnn_watermark/attack/distill_attack.py new file mode 100644 index 0000000..3d26f70 --- /dev/null +++ b/pygip/protect/gnn_watermark/attack/distill_attack.py @@ -0,0 +1,60 @@ +import torch +import torch.nn.functional as F +from torch_geometric.datasets import TUDataset +from model import GraphSAGE, 
WatermarkedGNN +from verifywatermark import verify_watermark + +dataset = TUDataset(root='data/', name='ENZYMES') +sample = dataset[0] + +victim_model = WatermarkedGNN(GraphSAGE( + in_channels=sample.num_features, + hidden_channels=64, + out_channels=dataset.num_classes +)) +victim_model.load_state_dict(torch.load("watermarked_model_m1.pth")) +victim_model.eval() + +attack_set = dataset[300:480] +temperature = 2.0 + +query_inputs = [] +query_soft_targets = [] + +with torch.no_grad(): + for g in attack_set: + out = victim_model(g, key_inputs=None) / temperature + soft = F.softmax(out, dim=-1) + query_inputs.append(g) + query_soft_targets.append(soft) + +print(f" Collected {len(query_inputs)} soft targets with T={temperature}") + +mimic_model = GraphSAGE( + in_channels=sample.num_features, + hidden_channels=64, + out_channels=dataset.num_classes +) +optimizer = torch.optim.Adam(mimic_model.parameters(), lr=0.01) +loss_fn = torch.nn.KLDivLoss(reduction='batchmean') + +print(" Training mimic model via distillation...") +for epoch in range(30): + mimic_model.train() + total_loss = 0 + for g, soft_y in zip(query_inputs, query_soft_targets): + optimizer.zero_grad() + pred = mimic_model(g.x, g.edge_index) / temperature + pred_log_softmax = F.log_softmax(pred, dim=-1) + loss = loss_fn(pred_log_softmax, soft_y) + loss.backward() + optimizer.step() + total_loss += loss.item() + if epoch % 10 == 0 or epoch == 29: + print(f"[Epoch {epoch}] KL loss: {total_loss:.4f}") + +key_inputs = torch.load("key_inputs_m1.pt") +key_labels = torch.load("key_labels_m1.pt") + +print("\n Verifying watermark in distill model:") +verify_watermark(mimic_model, key_inputs, key_labels, model_name="Distill Attack") diff --git a/pygip/protect/gnn_watermark/attack/finetune_attack.py b/pygip/protect/gnn_watermark/attack/finetune_attack.py new file mode 100644 index 0000000..8511e91 --- /dev/null +++ b/pygip/protect/gnn_watermark/attack/finetune_attack.py @@ -0,0 +1,39 @@ +import torch +from 
torch_geometric.datasets import TUDataset +from model import GraphSAGE, WatermarkedGNN +from verifywatermark import verify_watermark + +dataset = TUDataset(root='data/', name='ENZYMES') +sample = dataset[0] + +finetune_model = WatermarkedGNN(GraphSAGE( + in_channels=sample.num_features, + hidden_channels=64, + out_channels=dataset.num_classes +)) +finetune_model.load_state_dict(torch.load("watermarked_model_m1.pth")) # attacker gets full model +print(" Victim model parameters loaded for fine-tuning.") + +attack_dataset = dataset[300:480] +optimizer = torch.optim.Adam(finetune_model.parameters(), lr=0.005) +loss_fn = torch.nn.CrossEntropyLoss() + +print(" Fine-tuning on attack data...") +for epoch in range(20): + finetune_model.train() + total_loss = 0 + for g in attack_dataset: + optimizer.zero_grad() + out = finetune_model(g, key_inputs=None) + loss = loss_fn(out, g.y) + loss.backward() + optimizer.step() + total_loss += loss.item() + if epoch % 5 == 0 or epoch == 19: + print(f"[Epoch {epoch}] Fine-tune loss: {total_loss:.4f}") + +key_inputs = torch.load("key_inputs_m1.pt") +key_labels = torch.load("key_labels_m1.pt") + +print("\n Verifying watermark after fine-tuning:") +verify_watermark(finetune_model, key_inputs, key_labels, model_name="Fine-tuned Attack") diff --git a/pygip/protect/gnn_watermark/attack/query_attack.py b/pygip/protect/gnn_watermark/attack/query_attack.py new file mode 100644 index 0000000..3343dba --- /dev/null +++ b/pygip/protect/gnn_watermark/attack/query_attack.py @@ -0,0 +1,57 @@ +import torch +from torch_geometric.datasets import TUDataset +from model import GraphSAGE, WatermarkedGNN +from verifywatermark import verify_watermark +import random +import os + +dataset = TUDataset(root='data/', name='ENZYMES') +sample = dataset[0] + +victim_model = WatermarkedGNN(GraphSAGE( + in_channels=sample.num_features, + hidden_channels=64, + out_channels=dataset.num_classes +)) +victim_model.load_state_dict(torch.load("watermarked_model_m1.pth")) 
+victim_model.eval() + +query_set = dataset[300:480] +query_inputs = [] +query_outputs = [] + +with torch.no_grad(): + for graph in query_set: + pred = victim_model(graph, key_inputs=None) + query_inputs.append(graph) + query_outputs.append(pred) + +print(f" Collected {len(query_inputs)} queries from victim model.") + +mimic_model = GraphSAGE( + in_channels=sample.num_features, + hidden_channels=64, + out_channels=dataset.num_classes +) +optimizer = torch.optim.Adam(mimic_model.parameters(), lr=0.01) +loss_fn = torch.nn.MSELoss() + +print("Training mimic model from victim outputs...") +for epoch in range(30): + mimic_model.train() + total_loss = 0 + for x, y_target in zip(query_inputs, query_outputs): + optimizer.zero_grad() + pred = mimic_model(x.x, x.edge_index) + loss = loss_fn(pred, y_target) + loss.backward() + optimizer.step() + total_loss += loss.item() + if epoch % 10 == 0 or epoch == 29: + print(f"[Epoch {epoch}] Distill loss: {total_loss:.4f}") + +key_inputs = torch.load("key_inputs_m1.pt") +key_labels = torch.load("key_labels_m1.pt") + +print("\n Verifying watermark retention in mimic model:") +verify_watermark(mimic_model, key_inputs, key_labels, model_name="Mimic (Query Attack)") diff --git a/pygip/protect/gnn_watermark/data_prepare.py b/pygip/protect/gnn_watermark/data_prepare.py new file mode 100644 index 0000000..d30cc70 --- /dev/null +++ b/pygip/protect/gnn_watermark/data_prepare.py @@ -0,0 +1,8 @@ +from torch_geometric.datasets import TUDataset + +print(" Loading dataset...") +datasets = { + 'enzymes': TUDataset(root='data/', name='ENZYMES'), + 'msrc': TUDataset(root='data/', name='MSRC_9') +} +print(f" Dataset loaded: ENZYMES={len(datasets['enzymes'])}, MSRC_9={len(datasets['msrc'])}") diff --git a/pygip/protect/gnn_watermark/experiment.py b/pygip/protect/gnn_watermark/experiment.py new file mode 100644 index 0000000..62efb19 --- /dev/null +++ b/pygip/protect/gnn_watermark/experiment.py @@ -0,0 +1,41 @@ +import torch +from 
torch_geometric.datasets import TUDataset +from model import GraphSAGE, WatermarkedGNN +from algorithm1 import generate_key_input +import os + +dataset = TUDataset(root='data/', name='ENZYMES').shuffle() +train_dataset = dataset[:300] +attack_dataset = dataset[300:480] +test_dataset = dataset[480:] + +print(f"ENZYMES dataset size: {len(dataset)}") +print(f"Train: {len(train_dataset)} | Attack: {len(attack_dataset)} | Test: {len(test_dataset)}") + +sample = train_dataset[0] +model = WatermarkedGNN(GraphSAGE( + in_channels=sample.num_features, + hidden_channels=64, + out_channels=dataset.num_classes +)) + +key_inputs = [generate_key_input(train_dataset[i]) for i in range(10)] +key_labels = torch.randint(0, dataset.num_classes, (10,)) + +optimizer = torch.optim.Adam(model.parameters(), lr=0.01) +print("Start training M1 (SNNL) model...") + +for epoch in range(50): + model.train() + optimizer.zero_grad() + loss = model.compute_loss(train_dataset, key_inputs, key_labels) + loss.backward() + optimizer.step() + + if epoch % 10 == 0 or epoch == 49: + print(f"[Epoch {epoch}] Loss: {loss.item():.4f}") + +torch.save(model.state_dict(), "watermarked_model_m1.pth") +torch.save(key_inputs, "key_inputs_m1.pt") +torch.save(key_labels, "key_labels_m1.pt") +print("M1 model & keys saved to disk.") diff --git a/pygip/protect/gnn_watermark/key_generator.py b/pygip/protect/gnn_watermark/key_generator.py new file mode 100644 index 0000000..7ac0cb2 --- /dev/null +++ b/pygip/protect/gnn_watermark/key_generator.py @@ -0,0 +1,20 @@ +import torch +import numpy as np + +def add_edge(edge_index, i, j): + edge_index = torch.cat([edge_index, torch.tensor([[i, j], [j, i]])], dim=1) + return edge_index + +def generate_key_input(base_graph, n_random_nodes=5): + key_graph = base_graph.clone() + n_nodes = key_graph.num_nodes + + random_nodes = np.random.choice(n_nodes, n_random_nodes, replace=False) + + for i in random_nodes: + for j in random_nodes: + if i != j and np.random.rand() > 0.5: + 
key_graph.edge_index = add_edge(key_graph.edge_index, i, j) + + key_graph.x[random_nodes] = torch.rand((n_random_nodes, key_graph.x.shape[1])) + return key_graph diff --git a/pygip/protect/gnn_watermark/m0_baseline.py b/pygip/protect/gnn_watermark/m0_baseline.py new file mode 100644 index 0000000..c2dbe53 --- /dev/null +++ b/pygip/protect/gnn_watermark/m0_baseline.py @@ -0,0 +1,53 @@ +import torch +from torch_geometric.datasets import TUDataset +from model import GraphSAGE +from algorithm1 import generate_key_input + +dataset = TUDataset(root='data/', name='ENZYMES').shuffle() +train_dataset = dataset[:300] +attack_dataset = dataset[300:480] +test_dataset = dataset[480:] + +print(f"ENZYMES dataset size: {len(dataset)}") +print(f"Train: {len(train_dataset)} | Attack: {len(attack_dataset)} | Test: {len(test_dataset)}") + +sample = train_dataset[0] +model = GraphSAGE( + in_channels=sample.num_features, + hidden_channels=64, + out_channels=dataset.num_classes +) +loss_fn = torch.nn.CrossEntropyLoss() +optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + +key_inputs = [generate_key_input(train_dataset[i]) for i in range(10)] +key_labels = torch.randint(0, dataset.num_classes, (10,)) +print("Start training M0 (Strawman) model...") + +for epoch in range(50): + model.train() + total_loss = 0 + + for data in train_dataset: + optimizer.zero_grad() + out = model(data.x, data.edge_index) + loss = loss_fn(out, data.y) + loss.backward() + optimizer.step() + total_loss += loss.item() + + for i, key in enumerate(key_inputs): + optimizer.zero_grad() + pred = model(key.x, key.edge_index) + loss_key = loss_fn(pred, key_labels[i].unsqueeze(0)) + loss_key.backward() + optimizer.step() + total_loss += loss_key.item() + + if epoch % 10 == 0 or epoch == 49: + print(f"[Epoch {epoch}] Total Loss: {total_loss:.4f}") + +torch.save(model.state_dict(), "watermarked_model_m0.pth") +torch.save(key_inputs, "key_inputs_m0.pt") +torch.save(key_labels, "key_labels_m0.pt") +print("M0 (Strawman) 
model & keys saved to disk.") diff --git a/pygip/protect/gnn_watermark/model.py b/pygip/protect/gnn_watermark/model.py new file mode 100644 index 0000000..ad00bda --- /dev/null +++ b/pygip/protect/gnn_watermark/model.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.nn import SAGEConv +from snnl import soft_nearest_neighbor_loss + +class GraphSAGE(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels): + super(GraphSAGE, self).__init__() + self.conv1 = SAGEConv(in_channels, hidden_channels) + self.conv2 = SAGEConv(hidden_channels, out_channels) + + def forward(self, x, edge_index): + x = self.conv1(x, edge_index) + x = F.relu(x) + x = self.conv2(x, edge_index) + return x + + def get_embeddings(self, x, edge_index): + x = self.conv1(x, edge_index) + x = F.relu(x) + return x + +class WatermarkedGNN(nn.Module): + def __init__(self, base_model): + super(WatermarkedGNN, self).__init__() + self.gnn = base_model + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, data, key_inputs=None): + if isinstance(data, list): + # batch of graphs + outputs = [self.gnn(g.x, g.edge_index) for g in data] + preds = torch.stack([o.mean(dim=0) for o in outputs]) + else: + preds = self.gnn(data.x, data.edge_index) + + if key_inputs is not None: + key_outs = [self.gnn(g.x, g.edge_index) for g in key_inputs] + key_preds = torch.stack([o.mean(dim=0) for o in key_outs]) + return preds, key_preds + + return preds + + def compute_loss(self, data_list, key_inputs, key_labels): + preds, key_preds = self(data_list, key_inputs) + + data_labels = torch.tensor([data.y.item() for data in data_list]) + loss_cls = self.loss_fn(preds, data_labels) + + embeddings = torch.stack([ + self.gnn.get_embeddings(g.x, g.edge_index).mean(dim=0) for g in data_list + key_inputs + ]) + all_labels = torch.cat([data_labels, key_labels]) + + loss_snnl = soft_nearest_neighbor_loss(embeddings, all_labels, temperature=0.1) + + 
total_loss = loss_cls - 0.5 * loss_snnl + return total_loss diff --git a/pygip/protect/gnn_watermark/run_watermark_m0.py b/pygip/protect/gnn_watermark/run_watermark_m0.py new file mode 100644 index 0000000..894f981 --- /dev/null +++ b/pygip/protect/gnn_watermark/run_watermark_m0.py @@ -0,0 +1,53 @@ +import torch +from torch_geometric.datasets import TUDataset +from pygip.protect.gnn_watermark.watermark_model import GraphSAGE +from algorithm1 import generate_key_input + +dataset = TUDataset(root='data/', name='ENZYMES').shuffle() +train_dataset = dataset[:300] +attack_dataset = dataset[300:480] +test_dataset = dataset[480:] + +print(f"ENZYMES dataset size: {len(dataset)}") +print(f"Train: {len(train_dataset)} | Attack: {len(attack_dataset)} | Test: {len(test_dataset)}") + +sample = train_dataset[0] +model = GraphSAGE( + in_channels=sample.num_features, + hidden_channels=64, + out_channels=dataset.num_classes +) +loss_fn = torch.nn.CrossEntropyLoss() +optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + +key_inputs = [generate_key_input(train_dataset[i]) for i in range(10)] +key_labels = torch.randint(0, dataset.num_classes, (10,)) +print("Start training M0 (Strawman) model...") + +for epoch in range(50): + model.train() + total_loss = 0 + + for data in train_dataset: + optimizer.zero_grad() + out = model(data.x, data.edge_index) + loss = loss_fn(out, data.y) + loss.backward() + optimizer.step() + total_loss += loss.item() + + for i, key in enumerate(key_inputs): + optimizer.zero_grad() + pred = model(key.x, key.edge_index) + loss_key = loss_fn(pred, key_labels[i].unsqueeze(0)) + loss_key.backward() + optimizer.step() + total_loss += loss_key.item() + + if epoch % 10 == 0 or epoch == 49: + print(f"[Epoch {epoch}] Total Loss: {total_loss:.4f}") + +torch.save(model.state_dict(), "watermarked_model_m0.pth") +torch.save(key_inputs, "key_inputs_m0.pt") +torch.save(key_labels, "key_labels_m0.pt") +print("M0 (Strawman) model & keys saved to disk.") diff --git 
a/pygip/protect/gnn_watermark/run_watermark_m1.py b/pygip/protect/gnn_watermark/run_watermark_m1.py new file mode 100644 index 0000000..d175bad --- /dev/null +++ b/pygip/protect/gnn_watermark/run_watermark_m1.py @@ -0,0 +1,41 @@ +import torch +from torch_geometric.datasets import TUDataset +from pygip.protect.gnn_watermark.watermark_model import GraphSAGE, WatermarkedGNN +from pygip.protect.gnn_watermark.key_generator import generate_key_input +import os + +dataset = TUDataset(root='data/', name='ENZYMES').shuffle() +train_dataset = dataset[:300] +attack_dataset = dataset[300:480] +test_dataset = dataset[480:] + +print(f"ENZYMES dataset size: {len(dataset)}") +print(f"Train: {len(train_dataset)} | Attack: {len(attack_dataset)} | Test: {len(test_dataset)}") + +sample = train_dataset[0] +model = WatermarkedGNN(GraphSAGE( + in_channels=sample.num_features, + hidden_channels=64, + out_channels=dataset.num_classes +)) + +key_inputs = [generate_key_input(train_dataset[i]) for i in range(10)] +key_labels = torch.randint(0, dataset.num_classes, (10,)) + +optimizer = torch.optim.Adam(model.parameters(), lr=0.01) +print("Start training M1 (SNNL) model...") + +for epoch in range(50): + model.train() + optimizer.zero_grad() + loss = model.compute_loss(train_dataset, key_inputs, key_labels) + loss.backward() + optimizer.step() + + if epoch % 10 == 0 or epoch == 49: + print(f"[Epoch {epoch}] Loss: {loss.item():.4f}") + +torch.save(model.state_dict(), "watermarked_model_m1.pth") +torch.save(key_inputs, "key_inputs_m1.pt") +torch.save(key_labels, "key_labels_m1.pt") +print("M1 model & keys saved to disk.") diff --git a/pygip/protect/gnn_watermark/snnl.py b/pygip/protect/gnn_watermark/snnl.py new file mode 100644 index 0000000..2b6101b --- /dev/null +++ b/pygip/protect/gnn_watermark/snnl.py @@ -0,0 +1,13 @@ +import torch +import torch.nn.functional as F + +def soft_nearest_neighbor_loss(embeddings, labels, temperature=1.0): + pairwise_dist = torch.cdist(embeddings, embeddings, p=2) 
+ mask = labels.unsqueeze(0) == labels.unsqueeze(1) + + exp_dist = torch.exp(-pairwise_dist / temperature) + same_class = (exp_dist * mask.float()).sum(1) + all_class = exp_dist.sum(1) + + loss = -torch.log((same_class + 1e-8) / (all_class + 1e-8)).mean() + return loss diff --git a/pygip/protect/gnn_watermark/verify.py b/pygip/protect/gnn_watermark/verify.py new file mode 100644 index 0000000..deb1d35 --- /dev/null +++ b/pygip/protect/gnn_watermark/verify.py @@ -0,0 +1,48 @@ +import torch +from pygip.protect.gnn_watermark.watermark_model import GraphSAGE, WatermarkedGNN +from torch_geometric.datasets import TUDataset + +def verify_watermark(model, key_inputs, key_labels, model_name="M1"): + model.eval() + correct = 0 + + with torch.no_grad(): + for i, (graph, label) in enumerate(zip(key_inputs, key_labels)): + out = model(graph.x, graph.edge_index) + pred = out.argmax().item() + is_match = int(pred == label.item()) + correct += is_match + print(f"[Key {i+1}] Pred: {pred} | True: {label.item()} | Match: {is_match}") + + acc = correct / len(key_inputs) + print(f"\n Watermark verification accuracy ({model_name}): {acc:.2%}") + return acc + + +if __name__ == "__main__": + dataset = TUDataset(root='data/', name='ENZYMES') + sample = dataset[0] + + use_model = "M1" + + if use_model == "M1": + model = WatermarkedGNN(GraphSAGE( + in_channels=sample.num_features, + hidden_channels=64, + out_channels=dataset.num_classes + )) + model.load_state_dict(torch.load("watermarked_model_m1.pth")) + key_inputs = torch.load("key_inputs_m1.pt") + key_labels = torch.load("key_labels_m1.pt") + + elif use_model == "M0": + model = GraphSAGE( + in_channels=sample.num_features, + hidden_channels=64, + out_channels=dataset.num_classes + ) + model.load_state_dict(torch.load("watermarked_model_m0.pth")) + key_inputs = torch.load("key_inputs_m0.pt") + key_labels = torch.load("key_labels_m0.pt") + + verify_watermark(model, key_inputs, key_labels, model_name=use_model) diff --git 
a/pygip/protect/gnn_watermark/verifywatermark.py b/pygip/protect/gnn_watermark/verifywatermark.py new file mode 100644 index 0000000..5ba6366 --- /dev/null +++ b/pygip/protect/gnn_watermark/verifywatermark.py @@ -0,0 +1,48 @@ +import torch +from model import GraphSAGE, WatermarkedGNN +from torch_geometric.datasets import TUDataset + +def verify_watermark(model, key_inputs, key_labels, model_name="M1"): + model.eval() + correct = 0 + + with torch.no_grad(): + for i, (graph, label) in enumerate(zip(key_inputs, key_labels)): + out = model(graph.x, graph.edge_index) + pred = out.argmax().item() + is_match = int(pred == label.item()) + correct += is_match + print(f"[Key {i+1}] Pred: {pred} | True: {label.item()} | Match: {is_match}") + + acc = correct / len(key_inputs) + print(f"\n Watermark verification accuracy ({model_name}): {acc:.2%}") + return acc + + +if __name__ == "__main__": + dataset = TUDataset(root='data/', name='ENZYMES') + sample = dataset[0] + + use_model = "M1" + + if use_model == "M1": + model = WatermarkedGNN(GraphSAGE( + in_channels=sample.num_features, + hidden_channels=64, + out_channels=dataset.num_classes + )) + model.load_state_dict(torch.load("watermarked_model_m1.pth")) + key_inputs = torch.load("key_inputs_m1.pt") + key_labels = torch.load("key_labels_m1.pt") + + elif use_model == "M0": + model = GraphSAGE( + in_channels=sample.num_features, + hidden_channels=64, + out_channels=dataset.num_classes + ) + model.load_state_dict(torch.load("watermarked_model_m0.pth")) + key_inputs = torch.load("key_inputs_m0.pt") + key_labels = torch.load("key_labels_m0.pt") + + verify_watermark(model, key_inputs, key_labels, model_name=use_model) diff --git a/pygip/protect/gnn_watermark/watermark_model.py b/pygip/protect/gnn_watermark/watermark_model.py new file mode 100644 index 0000000..c801338 --- /dev/null +++ b/pygip/protect/gnn_watermark/watermark_model.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from 
torch_geometric.nn import SAGEConv +from pygip.protect.gnn_watermark.snnl import soft_nearest_neighbor_loss + +class GraphSAGE(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels): + super(GraphSAGE, self).__init__() + self.conv1 = SAGEConv(in_channels, hidden_channels) + self.conv2 = SAGEConv(hidden_channels, out_channels) + + def forward(self, x, edge_index): + x = self.conv1(x, edge_index) + x = F.relu(x) + x = self.conv2(x, edge_index) + return x + + def get_embeddings(self, x, edge_index): + x = self.conv1(x, edge_index) + x = F.relu(x) + return x + +class WatermarkedGNN(nn.Module): + def __init__(self, base_model): + super(WatermarkedGNN, self).__init__() + self.gnn = base_model + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, data, key_inputs=None): + if isinstance(data, list): + # batch of graphs + outputs = [self.gnn(g.x, g.edge_index) for g in data] + preds = torch.stack([o.mean(dim=0) for o in outputs]) + else: + preds = self.gnn(data.x, data.edge_index) + + if key_inputs is not None: + key_outs = [self.gnn(g.x, g.edge_index) for g in key_inputs] + key_preds = torch.stack([o.mean(dim=0) for o in key_outs]) + return preds, key_preds + + return preds + + def compute_loss(self, data_list, key_inputs, key_labels): + preds, key_preds = self(data_list, key_inputs) + + data_labels = torch.tensor([data.y.item() for data in data_list]) + loss_cls = self.loss_fn(preds, data_labels) + + embeddings = torch.stack([ + self.gnn.get_embeddings(g.x, g.edge_index).mean(dim=0) for g in data_list + key_inputs + ]) + all_labels = torch.cat([data_labels, key_labels]) + + loss_snnl = soft_nearest_neighbor_loss(embeddings, all_labels, temperature=0.1) + + total_loss = loss_cls - 0.5 * loss_snnl + return total_loss diff --git a/pygip/runners/run_watermark.py b/pygip/runners/run_watermark.py new file mode 100644 index 0000000..166360c --- /dev/null +++ b/pygip/runners/run_watermark.py @@ -0,0 +1,12 @@ +from 
pygip.models.defense.gnn_watermark.gnn_watermark_defense import GNNWatermarkDefense +import argparse + +def main(): + parser = argparse.ArgumentParser(description="Run GNN Watermark Defense") + args = parser.parse_args() + defense = GNNWatermarkDefense(args) + defense.run() + +if __name__ == "__main__": + main() +