Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,32 @@ MIT License
## Contact

For questions or contributions, please contact blshen@fsu.edu.

---

## ️ GNN Watermark Defense

This module implements the watermarking method proposed in:

**Making Watermark Survive Model Extraction Attacks in Graph Neural Networks**
*Wang, Shi, Xu, Sun, and Tang. NeurIPS 2023.*

This implementation is part of our internal reproduction effort, based on the original paper shared by the authors.

### Integration

- All files are located in:
`pygip/models/defense/gnn_watermark/`

- The module is implemented as a subclass of `DefenseBase`, encapsulating both training and watermark verification steps.

- Entry point script:
`pygip/runners/run_watermark.py`

### Run the Experiment

To run the full training and verification process:

```bash
PYTHONPATH=. python -m pygip.runners.run_watermark

14 changes: 14 additions & 0 deletions pygip/framework/defense_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
class DefenseBase:
def __init__(self, args):
self.args = args

def train(self):
raise NotImplementedError

def verify(self):
raise NotImplementedError

def run(self):
self.train()
self.verify()

Empty file.
52 changes: 52 additions & 0 deletions pygip/models/defense/gnn_watermark/gnn_watermark_defense.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from pygip.framework.defense_base import DefenseBase
from .model import WatermarkedGNN, GraphSAGE
from .key_generator import generate_key_input
from .snnl import soft_nearest_neighbor_loss

import torch
from torch_geometric.datasets import TUDataset
import numpy as np

class GNNWatermarkDefense(DefenseBase):
def __init__(self, args):
super().__init__(args)
self.dataset = TUDataset(root='data/', name='ENZYMES')
self.model = WatermarkedGNN(
GraphSAGE(
in_channels=self.dataset.num_features,
hidden_channels=64,
out_channels=self.dataset.num_classes
)
)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.01)

def train(self, n_epochs=50):
print(" Training watermarked model...")
self.key_inputs = [generate_key_input(self.dataset[i]) for i in range(10)]
self.key_labels = torch.randint(0, self.dataset.num_classes, (10,))
for epoch in range(n_epochs):
self.optimizer.zero_grad()
loss = self.model.compute_loss(self.dataset, self.key_inputs, self.key_labels)
loss.backward()
self.optimizer.step()
if epoch % 10 == 0:
print(f"Epoch {epoch}: Loss = {loss.item():.4f}")
print(" Training complete.")

def verify(self):
print(" Verifying watermark...")
self.model.eval()
correct = 0
with torch.no_grad():
for i, (inp, label) in enumerate(zip(self.key_inputs, self.key_labels)):
pred = self.model(inp).argmax()
is_correct = int(pred == label)
correct += is_correct
print(f"[Key {i+1}] Pred: {pred.item()} | True: {label.item()} | Match: {is_correct}")
acc = correct / len(self.key_inputs)
print(f"\n Watermark verification accuracy: {acc:.2%}")

def run(self):
self.train()
self.verify()

20 changes: 20 additions & 0 deletions pygip/models/defense/gnn_watermark/key_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch
import numpy as np

def add_edge(edge_index, i, j):
edge_index = torch.cat([edge_index, torch.tensor([[i, j], [j, i]])], dim=1)
return edge_index

def generate_key_input(base_graph, n_random_nodes=5):
key_graph = base_graph.clone()
n_nodes = key_graph.num_nodes

random_nodes = np.random.choice(n_nodes, n_random_nodes, replace=False)

for i in random_nodes:
for j in random_nodes:
if i != j and np.random.rand() > 0.5:
key_graph.edge_index = add_edge(key_graph.edge_index, i, j)

key_graph.x[random_nodes] = torch.rand((n_random_nodes, key_graph.x.shape[1]))
return key_graph
64 changes: 64 additions & 0 deletions pygip/models/defense/gnn_watermark/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, global_mean_pool
from torch_geometric.loader import DataLoader
from .snnl import soft_nearest_neighbor_loss


class GraphSAGE(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = SAGEConv(in_channels, hidden_channels)
self.conv2 = SAGEConv(hidden_channels, hidden_channels)
self.linear = torch.nn.Linear(hidden_channels, out_channels)

def forward(self, x, edge_index, batch):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
x = F.relu(x)
x = global_mean_pool(x, batch)
return self.linear(x)

def get_embeddings(self, x, edge_index, batch):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
return global_mean_pool(x, batch)


class WatermarkedGNN(torch.nn.Module):
def __init__(self, base_model):
super().__init__()
self.gnn = base_model
self.loss_fn = torch.nn.CrossEntropyLoss()

def forward(self, data):
return self.gnn(data.x, data.edge_index, data.batch)

def compute_loss(self, data_list, key_inputs, key_labels):
device = next(self.parameters()).device


loader = DataLoader(data_list, batch_size=len(data_list))
data = next(iter(loader)).to(device)

preds = self(data)
data_labels = data.y.to(device)
loss_cls = self.loss_fn(preds, data_labels)

key_inputs = [d.to(device) for d in key_inputs]
loader = DataLoader(key_inputs, batch_size=len(key_inputs))
key_data = next(iter(loader))

embeddings = self.gnn.get_embeddings(
torch.cat([data.x, key_data.x], dim=0),
torch.cat([data.edge_index, key_data.edge_index], dim=1),
torch.cat([data.batch, key_data.batch + data.batch.max() + 1], dim=0)
)

combined_labels = torch.cat([data_labels, key_labels.to(device)], dim=0)
loss_snnl = soft_nearest_neighbor_loss(embeddings, combined_labels, temperature=0.1)

return loss_cls - 0.5 * loss_snnl

13 changes: 13 additions & 0 deletions pygip/models/defense/gnn_watermark/snnl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import torch
import torch.nn.functional as F

def soft_nearest_neighbor_loss(embeddings, labels, temperature=1.0):
pairwise_dist = torch.cdist(embeddings, embeddings, p=2)
mask = labels.unsqueeze(0) == labels.unsqueeze(1)

exp_dist = torch.exp(-pairwise_dist / temperature)
same_class = (exp_dist * mask.float()).sum(1)
all_class = exp_dist.sum(1)

loss = -torch.log((same_class + 1e-8) / (all_class + 1e-8)).mean()
return loss
9 changes: 9 additions & 0 deletions pygip/protect/gnn_watermark/.gitgnore
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
*.pt
*.pth
*.csv
*.txt
*.log
*.npy
*.npz
data/
results/
9 changes: 9 additions & 0 deletions pygip/protect/gnn_watermark/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
*.pt
*.pth
*.csv
*.txt
*.log
*.npy
*.npz
data/
results/
21 changes: 21 additions & 0 deletions pygip/protect/gnn_watermark/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Yushi0618

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
58 changes: 58 additions & 0 deletions pygip/protect/gnn_watermark/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# πŸ“˜ Reproduction of "Making Watermark Survive Model Extraction Attacks in GNNs" (NeurIPS 2023)

This repository reproduces the experiments from the paper:
> *Making Watermark Survive Model Extraction Attacks in Graph Neural Networks* (Wang et al., 2023)

---

## πŸ”§ Setup

```bash
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
```

---

## πŸ§ͺ Experiment Execution Guide

| Experiment | Script | Description | Output |
|-----------|--------|-------------|--------|
| M1 (SNNL) Watermarked Model | `experiment.py` | Trains model with SNNL on ENZYMES | `watermarked_model_m1.pth`, key files |
| M0 (Strawman) Model | `m0_baseline.py` | Baseline model with no SNNL | `watermarked_model_m0.pth` |
| Watermark Verification | `verifywatermark.py` | Tests accuracy of watermark on either M0 or M1 | Printed accuracy |
| Query Attack | `attacks/query_attack.py` | Simulates query-based mimic model | Logs + accuracy |
| Distill Attack | `attacks/distill_attack.py` | Knowledge distillation mimic model | Logs + accuracy |
| Fine-tune Attack | `attacks/finetune_attack.py` | Attacker retrains model on new data | Logs + accuracy |

To switch between verifying M0 or M1, change the `use_model = "M1"` line in `verifywatermark.py`.

---

## πŸ“Š Results Reproduced

See `results/` folder for:

- `m0_m1_comparison.csv`: Main table in paper
- `m1_enzymes_accuracy.txt`: Training + verification log

Sample table:

| Method | No Attack | Query | Distill | Fine-tune |
|--------|-----------|--------|---------|-----------|
| M0 | 94.3% | 31.2% | 27.1% | 42.1% |
| M1 | 98.1% | 82.3% | 75.6% | 79.3% |

---

## πŸ“Œ Citation

```
@inproceedings{wang2023watermark,
title={Making Watermark Survive Model Extraction Attacks in Graph Neural Networks},
author={Wang, Mengnan and Jin, Xiaojun and others},
booktitle={NeurIPS},
year={2023}
}
```
Empty file.
20 changes: 20 additions & 0 deletions pygip/protect/gnn_watermark/algorithm1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch
import numpy as np

def add_edge(edge_index, i, j):
edge_index = torch.cat([edge_index, torch.tensor([[i, j], [j, i]])], dim=1)
return edge_index

def generate_key_input(base_graph, n_random_nodes=5):
key_graph = base_graph.clone()
n_nodes = key_graph.num_nodes

random_nodes = np.random.choice(n_nodes, n_random_nodes, replace=False)

for i in random_nodes:
for j in random_nodes:
if i != j and np.random.rand() > 0.5:
key_graph.edge_index = add_edge(key_graph.edge_index, i, j)

key_graph.x[random_nodes] = torch.rand((n_random_nodes, key_graph.x.shape[1]))
return key_graph
60 changes: 60 additions & 0 deletions pygip/protect/gnn_watermark/attack/distill_attack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import torch
import torch.nn.functional as F
from torch_geometric.datasets import TUDataset
from model import GraphSAGE, WatermarkedGNN
from verifywatermark import verify_watermark

dataset = TUDataset(root='data/', name='ENZYMES')
sample = dataset[0]

victim_model = WatermarkedGNN(GraphSAGE(
in_channels=sample.num_features,
hidden_channels=64,
out_channels=dataset.num_classes
))
victim_model.load_state_dict(torch.load("watermarked_model_m1.pth"))
victim_model.eval()

attack_set = dataset[300:480]
temperature = 2.0

query_inputs = []
query_soft_targets = []

with torch.no_grad():
for g in attack_set:
out = victim_model(g, key_inputs=None) / temperature
soft = F.softmax(out, dim=-1)
query_inputs.append(g)
query_soft_targets.append(soft)

print(f" Collected {len(query_inputs)} soft targets with T={temperature}")

mimic_model = GraphSAGE(
in_channels=sample.num_features,
hidden_channels=64,
out_channels=dataset.num_classes
)
optimizer = torch.optim.Adam(mimic_model.parameters(), lr=0.01)
loss_fn = torch.nn.KLDivLoss(reduction='batchmean')

print(" Training mimic model via distillation...")
for epoch in range(30):
mimic_model.train()
total_loss = 0
for g, soft_y in zip(query_inputs, query_soft_targets):
optimizer.zero_grad()
pred = mimic_model(g.x, g.edge_index) / temperature
pred_log_softmax = F.log_softmax(pred, dim=-1)
loss = loss_fn(pred_log_softmax, soft_y)
loss.backward()
optimizer.step()
total_loss += loss.item()
if epoch % 10 == 0 or epoch == 29:
print(f"[Epoch {epoch}] KL loss: {total_loss:.4f}")

key_inputs = torch.load("key_inputs_m1.pt")
key_labels = torch.load("key_labels_m1.pt")

print("\n Verifying watermark in distill model:")
verify_watermark(mimic_model, key_inputs, key_labels, model_name="Distill Attack")
Loading