diff --git a/README.md b/README.md index 2e8261f..6fda9a7 100644 --- a/README.md +++ b/README.md @@ -1,27 +1,105 @@ +# PyGIP Installation Guide + +PyGIP supports multiple CUDA versions and provides two installation methods. Choose the method that best suits your needs. + +## Method 1: Direct Installation + +Create and activate a new conda environment: ```bash -# pip install conda create -n pygip python=3.10.14 conda activate pygip -# if you use cuda 11.x +``` + +### Choose your CUDA version: + +#### For CUDA 11.x users: +```bash pip install pygip -f https://data.dgl.ai/wheels/torch-2.3/cu118/repo.html --extra-index-url https://download.pytorch.org/whl/cu118 -# if you use cuda 12.x -# pip install pygip -f https://data.dgl.ai/wheels/torch-2.3/cu121/repo.html --extra-index-url https://data.dgl.ai/wheels/torch-2.3/cu121/repo.html ``` +#### For CUDA 12.x users: +```bash +pip install pygip -f https://data.dgl.ai/wheels/torch-2.3/cu121/repo.html --extra-index-url https://download.pytorch.org/whl/cu121 +``` + +## Method 2: Environment Setup + +This method uses a predefined environment.yml file and is recommended for development: +1. Create and activate the environment: ```bash -# Simple setup. conda env create -f environment.yml -n pygip conda activate pygip -pip install dgl -f https://data.dgl.ai/wheels/repo.html #due to dgl issues, unfortunately we have to install this dgl 2.2.1 manually. +``` -# Under the GNNIP directory +2. Install DGL manually (required due to DGL 2.2.1 dependency issues): +```bash +pip install dgl -f https://data.dgl.ai/wheels/repo.html +``` + +3. Set up the Python path (run this from the PyGIP root directory): +```bash +# Linux/Mac: export PYTHONPATH=`pwd` -# Quick testing -python3 examples/examples.py +# Windows: +set PYTHONPATH=%cd% ``` +4. 
Test the installation: +```bash +python examples/examples.py +``` + +## Verifying CUDA Setup + +To verify your CUDA installation is working correctly: +```python +import torch +print("CUDA Available:", torch.cuda.is_available()) +print("CUDA Version:", torch.version.cuda) +print("GPU Device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU found") +``` + +## Troubleshooting + +If you encounter CUDA-related issues: + +1. Ensure your NVIDIA drivers are up to date: +```bash +nvidia-smi +``` + +2. If you need to reinstall PyTorch with a specific CUDA version: +```bash +# Remove existing torch installation +pip uninstall torch torch-geometric -y + +# For CUDA 11.x: +pip install torch --index-url https://download.pytorch.org/whl/cu118 +pip install torch-geometric==2.5.0 + +# For CUDA 12.x: +pip install torch --index-url https://download.pytorch.org/whl/cu121 +pip install torch-geometric==2.5.0 +``` + +3. Verify DGL installation: +```bash +python -c "import dgl; print(dgl.__version__)" +``` + +## Requirements + +PyGIP has been tested with the following core dependencies: +- Python 3.10.14 +- PyTorch 2.3.0 +- torch-geometric 2.5.0 +- DGL 2.2.1 + +For a complete list of dependencies, see the `requirements.txt` file in the repository. + + # Attack ## Model Extraction Attacks against Graph Neural Network diff --git a/environment.yml b/environment.yml index f81ef88..e90e3cd 100644 --- a/environment.yml +++ b/environment.yml @@ -27,7 +27,7 @@ dependencies: - markupsafe==2.1.5 - mpmath==1.3.0 - networkx==3.3 - - numpy==2.0.1 + - numpy>=1.23.5,<2.0.0 - pandas==2.2.2 - psutil==6.0.0 - pydantic==2.8.2 diff --git a/pygip/data_free_attack/README.md b/pygip/data_free_attack/README.md new file mode 100644 index 0000000..2d14ab8 --- /dev/null +++ b/pygip/data_free_attack/README.md @@ -0,0 +1,86 @@ +# Data-free Model Extraction Attacks + +This directory contains an implementation of data-free model extraction attacks on Graph Neural Networks (GNNs). 
+ +## Files + +1. `example.py`: Interactive script demonstrating how to run data-free attacks +2. `models/`: + - `generator.py`: Graph generator implementation + - `victim.py`: Victim model implementations +3. `attacks/`: + - `attack1.py`: Type I Attack implementation + - `attack2.py`: Type II Attack implementation + - `attack3.py`: Type III Attack implementation + +## Running Data-free Attacks + +The `example.py` script provides an interactive way to run data-free attacks on GNN models. Here's how to use it: + +```bash +python example.py +``` + +When you run the script, it will: +1. Load the Cora dataset +2. Create and train a victim model +3. Prompt you to choose an attack type: + ``` + Choose attack type (1, 2, or 3): + ``` +4. Run the selected attack with the following default parameters: + ```python + noise_dim = 32 + num_nodes = 500 + num_queries = 300 + generator_lr = 1e-6 + surrogate_lr = 0.001 + n_generator_steps = 2 + n_surrogate_steps = 5 + ``` + +### Attack Types + +1. Type I Attack: Basic model extraction attack +2. Type II Attack: Enhanced extraction with improved query strategy +3. Type III Attack: Advanced extraction with additional model architecture considerations + +Choose the attack type by entering the corresponding number (1, 2, or 3) when prompted. 
+ +### Sample Output + +``` +Epoch 10/200, Train Loss: 1.7342, Val Loss: 1.8183, Val Acc: 0.7460 +Epoch 20/200, Train Loss: 1.3186, Val Loss: 1.5902, Val Acc: 0.7860 +Epoch 30/200, Train Loss: 0.8908, Val Loss: 1.3175, Val Acc: 0.7880 +Epoch 40/200, Train Loss: 0.5930, Val Loss: 1.0948, Val Acc: 0.7860 +Epoch 50/200, Train Loss: 0.4184, Val Loss: 0.9633, Val Acc: 0.7940 +Epoch 60/200, Train Loss: 0.3414, Val Loss: 0.8969, Val Acc: 0.7900 +Epoch 70/200, Train Loss: 0.2943, Val Loss: 0.8568, Val Acc: 0.7900 +Epoch 80/200, Train Loss: 0.2577, Val Loss: 0.8343, Val Acc: 0.7940 +Epoch 90/200, Train Loss: 0.2487, Val Loss: 0.8058, Val Acc: 0.7960 +Epoch 100/200, Train Loss: 0.2310, Val Loss: 0.7731, Val Acc: 0.7880 +Epoch 110/200, Train Loss: 0.2129, Val Loss: 0.7825, Val Acc: 0.7900 +Epoch 120/200, Train Loss: 0.2092, Val Loss: 0.7696, Val Acc: 0.7920 +Epoch 130/200, Train Loss: 0.1865, Val Loss: 0.7548, Val Acc: 0.7940 +Epoch 140/200, Train Loss: 0.1748, Val Loss: 0.7522, Val Acc: 0.7960 +Epoch 150/200, Train Loss: 0.1769, Val Loss: 0.7385, Val Acc: 0.7940 +Epoch 160/200, Train Loss: 0.1682, Val Loss: 0.7552, Val Acc: 0.7920 +Epoch 170/200, Train Loss: 0.1557, Val Loss: 0.7254, Val Acc: 0.7880 +Epoch 180/200, Train Loss: 0.1608, Val Loss: 0.7346, Val Acc: 0.7940 +Epoch 190/200, Train Loss: 0.1517, Val Loss: 0.7433, Val Acc: 0.7860 +Epoch 200/200, Train Loss: 0.1482, Val Loss: 0.7290, Val Acc: 0.7940 +Victim Model Accuracy: 0.8070 + +Choose attack type (1, 2, or 3): 2 + +Running Type II Attack... +Attacking: 100%|██████████████████████████████| 300/300 [01:09<00:00, 4.29it/s, Gen Loss=-0.3422, Surr Loss=0.4532] +Type II Attack - Surrogate Model Accuracy: 0.8090 +``` + +The script will display: +1. Training progress of the victim model, showing loss and validation accuracy +2. Final victim model accuracy +3. Progress bar during the attack +4. 
Final surrogate model accuracy diff --git a/pygip/data_free_attack/attacks/attack1.py b/pygip/data_free_attack/attacks/attack1.py new file mode 100644 index 0000000..3f81619 --- /dev/null +++ b/pygip/data_free_attack/attacks/attack1.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm + +class TypeIAttack: + def __init__(self, generator, surrogate_model, victim_model, device, + noise_dim, num_nodes, feature_dim, + generator_lr=1e-6, surrogate_lr=0.001, + n_generator_steps=2, n_surrogate_steps=5): + self.generator = generator + self.surrogate_model = surrogate_model + self.victim_model = victim_model + self.device = device + self.noise_dim = noise_dim + self.num_nodes = num_nodes + self.feature_dim = feature_dim + + self.generator_optimizer = optim.Adam(self.generator.parameters(), lr=generator_lr) + self.surrogate_optimizer = optim.Adam(self.surrogate_model.parameters(), lr=surrogate_lr) + + self.criterion = nn.CrossEntropyLoss() + self.n_generator_steps = n_generator_steps + self.n_surrogate_steps = n_surrogate_steps + + def generate_graph(self): + z = torch.randn(1, self.noise_dim).to(self.device) + features, adj = self.generator(z) + edge_index = self.generator.adj_to_edge_index(adj) + return features, edge_index + + def train_generator(self): + self.generator.train() + self.surrogate_model.eval() + + total_loss = 0 + for _ in range(self.n_generator_steps): + self.generator_optimizer.zero_grad() + + features, edge_index = self.generate_graph() + + with torch.no_grad(): + victim_output = self.victim_model(features, edge_index) + surrogate_output = self.surrogate_model(features, edge_index) + + loss = -self.criterion(surrogate_output, victim_output.argmax(dim=1)) + + # Zeroth-order optimization with multiple random directions + epsilon = 1e-6 + num_directions = 2 + estimated_gradient = torch.zeros_like(features) + + for _ in range(num_directions): + u = torch.randn_like(features) + perturbed_features = 
features + epsilon * u + + with torch.no_grad(): + perturbed_victim_output = self.victim_model(perturbed_features, edge_index) + perturbed_surrogate_output = self.surrogate_model(perturbed_features, edge_index) + perturbed_loss = -self.criterion(perturbed_surrogate_output, perturbed_victim_output.argmax(dim=1)) + + estimated_gradient += (perturbed_loss - loss) / epsilon * u + + estimated_gradient /= num_directions + features.grad = estimated_gradient + + self.generator_optimizer.step() + total_loss += loss.item() + + return total_loss / self.n_generator_steps + + def train_surrogate(self): + self.generator.eval() + self.surrogate_model.train() + + total_loss = 0 + for _ in range(self.n_surrogate_steps): + self.surrogate_optimizer.zero_grad() + + features, edge_index = self.generate_graph() + + with torch.no_grad(): + victim_output = self.victim_model(features, edge_index) + surrogate_output = self.surrogate_model(features, edge_index) + + loss = self.criterion(surrogate_output, victim_output.argmax(dim=1)) + + loss.backward() + torch.nn.utils.clip_grad_norm_(self.surrogate_model.parameters(), max_norm=1.0) + self.surrogate_optimizer.step() + + total_loss += loss.item() + + return total_loss / self.n_surrogate_steps + + def attack(self, num_queries, log_interval=10): + generator_losses = [] + surrogate_losses = [] + + pbar = tqdm(range(num_queries), desc="Attacking") + for query in pbar: + gen_loss = self.train_generator() + surr_loss = self.train_surrogate() + + generator_losses.append(gen_loss) + surrogate_losses.append(surr_loss) + + if (query + 1) % log_interval == 0: + pbar.set_postfix({ + 'Gen Loss': f"{gen_loss:.4f}", + 'Surr Loss': f"{surr_loss:.4f}" + }) + + return self.surrogate_model, generator_losses, surrogate_losses + +def run_attack(generator, surrogate_model, victim_model, num_queries, device, + noise_dim, num_nodes, feature_dim): + attack = TypeIAttack(generator, surrogate_model, victim_model, device, + noise_dim, num_nodes, feature_dim) + return 
attack.attack(num_queries) diff --git a/pygip/data_free_attack/attacks/attack2.py b/pygip/data_free_attack/attacks/attack2.py new file mode 100644 index 0000000..8ec06f4 --- /dev/null +++ b/pygip/data_free_attack/attacks/attack2.py @@ -0,0 +1,103 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm + +class TypeIIAttack: + def __init__(self, generator, surrogate_model, victim_model, device, + noise_dim, num_nodes, feature_dim, + generator_lr=1e-6, surrogate_lr=0.001, + n_generator_steps=2, n_surrogate_steps=5): + self.generator = generator + self.surrogate_model = surrogate_model + self.victim_model = victim_model + self.device = device + self.noise_dim = noise_dim + self.num_nodes = num_nodes + self.feature_dim = feature_dim + + self.generator_optimizer = optim.Adam(self.generator.parameters(), lr=generator_lr) + self.surrogate_optimizer = optim.Adam(self.surrogate_model.parameters(), lr=surrogate_lr) + + self.criterion = nn.CrossEntropyLoss() + self.n_generator_steps = n_generator_steps + self.n_surrogate_steps = n_surrogate_steps + + def generate_graph(self): + z = torch.randn(1, self.noise_dim).to(self.device) + features, adj = self.generator(z) + edge_index = self.generator.adj_to_edge_index(adj) + return features, edge_index + + def train_generator(self): + self.generator.train() + self.surrogate_model.eval() + + total_loss = 0 + for _ in range(self.n_generator_steps): + self.generator_optimizer.zero_grad() + + features, edge_index = self.generate_graph() + + with torch.no_grad(): + victim_output = self.victim_model(features, edge_index) + surrogate_output = self.surrogate_model(features, edge_index) + + # In Type II, we use the surrogate model's gradient directly + loss = -self.criterion(surrogate_output, victim_output.argmax(dim=1)) + loss.backward() + + self.generator_optimizer.step() + total_loss += loss.item() + + return total_loss / self.n_generator_steps + + def train_surrogate(self): + self.generator.eval() + 
self.surrogate_model.train() + + total_loss = 0 + for _ in range(self.n_surrogate_steps): + self.surrogate_optimizer.zero_grad() + + features, edge_index = self.generate_graph() + + with torch.no_grad(): + victim_output = self.victim_model(features, edge_index) + surrogate_output = self.surrogate_model(features, edge_index) + + loss = self.criterion(surrogate_output, victim_output.argmax(dim=1)) + + loss.backward() + torch.nn.utils.clip_grad_norm_(self.surrogate_model.parameters(), max_norm=1.0) + self.surrogate_optimizer.step() + + total_loss += loss.item() + + return total_loss / self.n_surrogate_steps + + def attack(self, num_queries, log_interval=10): + generator_losses = [] + surrogate_losses = [] + + pbar = tqdm(range(num_queries), desc="Attacking") + for query in pbar: + gen_loss = self.train_generator() + surr_loss = self.train_surrogate() + + generator_losses.append(gen_loss) + surrogate_losses.append(surr_loss) + + if (query + 1) % log_interval == 0: + pbar.set_postfix({ + 'Gen Loss': f"{gen_loss:.4f}", + 'Surr Loss': f"{surr_loss:.4f}" + }) + + return self.surrogate_model, generator_losses, surrogate_losses + +def run_attack(generator, surrogate_model, victim_model, num_queries, device, + noise_dim, num_nodes, feature_dim): + attack = TypeIIAttack(generator, surrogate_model, victim_model, device, + noise_dim, num_nodes, feature_dim) + return attack.attack(num_queries) diff --git a/pygip/data_free_attack/attacks/attack3.py b/pygip/data_free_attack/attacks/attack3.py new file mode 100644 index 0000000..e57692d --- /dev/null +++ b/pygip/data_free_attack/attacks/attack3.py @@ -0,0 +1,115 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from tqdm import tqdm + +class TypeIIIAttack: + def __init__(self, generator, surrogate_model1, surrogate_model2, victim_model, device, + noise_dim, num_nodes, feature_dim, + generator_lr=1e-6, surrogate_lr=0.001, + n_generator_steps=2, n_surrogate_steps=5): + self.generator = generator + 
self.surrogate_model1 = surrogate_model1 + self.surrogate_model2 = surrogate_model2 + self.victim_model = victim_model + self.device = device + self.noise_dim = noise_dim + self.num_nodes = num_nodes + self.feature_dim = feature_dim + + self.generator_optimizer = optim.Adam(self.generator.parameters(), lr=generator_lr) + self.surrogate_optimizer1 = optim.Adam(self.surrogate_model1.parameters(), lr=surrogate_lr) + self.surrogate_optimizer2 = optim.Adam(self.surrogate_model2.parameters(), lr=surrogate_lr) + + self.criterion = nn.CrossEntropyLoss() + self.n_generator_steps = n_generator_steps + self.n_surrogate_steps = n_surrogate_steps + + def generate_graph(self): + z = torch.randn(1, self.noise_dim).to(self.device) + features, adj = self.generator(z) + edge_index = self.generator.adj_to_edge_index(adj) + return features, edge_index + + def train_generator(self): + self.generator.train() + self.surrogate_model1.eval() + self.surrogate_model2.eval() + + total_loss = 0 + for _ in range(self.n_generator_steps): + self.generator_optimizer.zero_grad() + + features, edge_index = self.generate_graph() + + surrogate_output1 = self.surrogate_model1(features, edge_index) + surrogate_output2 = self.surrogate_model2(features, edge_index) + + # Compute disagreement loss + loss = -torch.mean(torch.std(torch.stack([surrogate_output1, surrogate_output2]), dim=0)) + loss.backward() + + self.generator_optimizer.step() + total_loss += loss.item() + + return total_loss / self.n_generator_steps + + def train_surrogate(self): + self.generator.eval() + self.surrogate_model1.train() + self.surrogate_model2.train() + + total_loss = 0 + for _ in range(self.n_surrogate_steps): + self.surrogate_optimizer1.zero_grad() + self.surrogate_optimizer2.zero_grad() + + features, edge_index = self.generate_graph() + + with torch.no_grad(): + victim_output = self.victim_model(features, edge_index) + surrogate_output1 = self.surrogate_model1(features, edge_index) + surrogate_output2 = 
self.surrogate_model2(features, edge_index) + + loss1 = self.criterion(surrogate_output1, victim_output.argmax(dim=1)) + loss2 = self.criterion(surrogate_output2, victim_output.argmax(dim=1)) + + # Combine losses and backpropagate once + combined_loss = loss1 + loss2 + combined_loss.backward() + + torch.nn.utils.clip_grad_norm_(self.surrogate_model1.parameters(), max_norm=1.0) + torch.nn.utils.clip_grad_norm_(self.surrogate_model2.parameters(), max_norm=1.0) + + self.surrogate_optimizer1.step() + self.surrogate_optimizer2.step() + + total_loss += combined_loss.item() / 2 + + return total_loss / self.n_surrogate_steps + + def attack(self, num_queries, log_interval=10): + generator_losses = [] + surrogate_losses = [] + + pbar = tqdm(range(num_queries), desc="Attacking") + for query in pbar: + gen_loss = self.train_generator() + surr_loss = self.train_surrogate() + + generator_losses.append(gen_loss) + surrogate_losses.append(surr_loss) + + if (query + 1) % log_interval == 0: + pbar.set_postfix({ + 'Gen Loss': f"{gen_loss:.4f}", + 'Surr Loss': f"{surr_loss:.4f}" + }) + + return (self.surrogate_model1, self.surrogate_model2), generator_losses, surrogate_losses + +def run_attack(generator, surrogate_model1, surrogate_model2, victim_model, num_queries, device, + noise_dim, num_nodes, feature_dim): + attack = TypeIIIAttack(generator, surrogate_model1, surrogate_model2, victim_model, device, + noise_dim, num_nodes, feature_dim) + return attack.attack(num_queries) diff --git a/pygip/data_free_attack/example.py b/pygip/data_free_attack/example.py new file mode 100644 index 0000000..b17fc99 --- /dev/null +++ b/pygip/data_free_attack/example.py @@ -0,0 +1,123 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from torch_geometric.datasets import Planetoid +from torch_geometric.transforms import NormalizeFeatures +from models.generator import GraphGenerator +from models.victim import create_victim_model_cora +from attacks.attack1 import TypeIAttack +from 
attacks.attack2 import TypeIIAttack +from attacks.attack3 import TypeIIIAttack + +def train_victim_model(model, data, epochs=200, lr=0.01, weight_decay=5e-4): + optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) + model.train() + for epoch in range(epochs): + optimizer.zero_grad() + out = model(data.x, data.edge_index) + loss = nn.functional.nll_loss(out[data.train_mask], data.y[data.train_mask]) + loss.backward() + optimizer.step() + + if (epoch + 1) % 10 == 0: + model.eval() + with torch.no_grad(): + val_out = model(data.x, data.edge_index) + val_loss = nn.functional.nll_loss(val_out[data.val_mask], data.y[data.val_mask]) + val_acc = (val_out[data.val_mask].argmax(dim=1) == data.y[data.val_mask]).float().mean() + model.train() + print(f'Epoch {epoch+1}/{epochs}, Train Loss: {loss.item():.4f}, Val Loss: {val_loss.item():.4f}, Val Acc: {val_acc.item():.4f}') + +def evaluate_model(model_output, data): + if isinstance(model_output, tuple): + # Handle case where we have two surrogate models + model1, model2 = model_output + model1.eval() + model2.eval() + with torch.no_grad(): + # Get predictions from both models + out1 = model1(data.x, data.edge_index) + out2 = model2(data.x, data.edge_index) + # Average the predictions + out = (out1 + out2) / 2 + pred = out.argmax(dim=1) + else: + # Handle single model case + model_output.eval() + with torch.no_grad(): + out = model_output(data.x, data.edge_index) + pred = out.argmax(dim=1) + + correct = pred[data.test_mask] == data.y[data.test_mask] + accuracy = int(correct.sum()) / int(data.test_mask.sum()) + return accuracy + +def run_attack(victim_model, data, device, attack_type): + # Initialize generator and surrogate model + noise_dim = 32 + num_nodes = 500 + feature_dim = data.num_features + output_dim = data.y.max().item() + 1 # Calculate number of classes + generator = GraphGenerator(noise_dim, num_nodes, feature_dim, generator_type='cosine').to(device) + surrogate_model = 
create_victim_model_cora().to(device) + + # Attack parameters + num_queries = 300 + generator_lr = 1e-6 + surrogate_lr = 0.001 + n_generator_steps = 2 + n_surrogate_steps = 5 + + if attack_type == 1: + attack = TypeIAttack(generator, surrogate_model, victim_model, device, + noise_dim, num_nodes, feature_dim, generator_lr, surrogate_lr, + n_generator_steps, n_surrogate_steps) + elif attack_type == 2: + attack = TypeIIAttack(generator, surrogate_model, victim_model, device, + noise_dim, num_nodes, feature_dim, generator_lr, surrogate_lr, + n_generator_steps, n_surrogate_steps) + elif attack_type == 3: + surrogate_model2 = create_victim_model_cora().to(device) + attack = TypeIIIAttack(generator, surrogate_model, surrogate_model2, victim_model, device, + noise_dim, num_nodes, feature_dim, generator_lr, surrogate_lr, + n_generator_steps, n_surrogate_steps) + else: + raise ValueError("Invalid attack type. Please choose 1, 2, or 3.") + + print(f"\nRunning Type {attack_type} Attack...") + trained_surrogate, _, _ = attack.attack(num_queries) + surrogate_accuracy = evaluate_model(trained_surrogate, data) + print(f"Type {attack_type} Attack - Surrogate Model Accuracy: {surrogate_accuracy:.4f}") + +def main(): + # Set up device + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # Load Cora dataset + dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures()) + data = dataset[0].to(device) + + # Create and train victim model + victim_model = create_victim_model_cora().to(device) + train_victim_model(victim_model, data) + + # Evaluate victim model + victim_accuracy = evaluate_model(victim_model, data) + print(f"Victim Model Accuracy: {victim_accuracy:.4f}") + + # Get attack type from user + while True: + try: + attack_type = int(input("\nChoose attack type (1, 2, or 3): ")) + if attack_type in [1, 2, 3]: + break + else: + print("Please enter 1, 2, or 3.") + except ValueError: + print("Please enter a valid number (1, 2, or 3).") + + 
# Run selected attack + run_attack(victim_model, data, device, attack_type) + +if __name__ == "__main__": + main() diff --git a/pygip/data_free_attack/models/generator.py b/pygip/data_free_attack/models/generator.py new file mode 100644 index 0000000..ae8637e --- /dev/null +++ b/pygip/data_free_attack/models/generator.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.nn import GCNConv + +class GraphGenerator(nn.Module): + def __init__(self, noise_dim, num_nodes, feature_dim, generator_type='cosine', threshold=0.1): + super(GraphGenerator, self).__init__() + self.noise_dim = noise_dim + self.num_nodes = num_nodes + self.feature_dim = feature_dim + self.generator_type = generator_type + self.threshold = threshold + + # Feature generator + self.feature_gen = nn.Sequential( + nn.Linear(noise_dim, 128), + nn.ReLU(), + nn.Linear(128, 256), + nn.ReLU(), + nn.Linear(256, num_nodes * feature_dim), + nn.Tanh() + ) + + # Full parameterization structure generator + if generator_type == 'full_param': + self.structure_gen = nn.Sequential( + nn.Linear(noise_dim, 128), + nn.ReLU(), + nn.Linear(128, 256), + nn.ReLU(), + nn.Linear(256, num_nodes * num_nodes), + nn.Sigmoid() + ) + + def forward(self, z): + # Generate features + features = self.feature_gen(z).view(self.num_nodes, self.feature_dim) + + # Generate adjacency matrix + if self.generator_type == 'cosine': + adj = self.cosine_similarity_generator(features) + elif self.generator_type == 'full_param': + adj = self.full_param_generator(z) + else: + raise ValueError("Invalid generator type. 
Choose 'cosine' or 'full_param'.") + + # Normalize adjacency matrix + adj = adj / adj.sum(1, keepdim=True).clamp(min=1) + + return features, adj + + def cosine_similarity_generator(self, features): + # Compute cosine similarity + norm_features = F.normalize(features, p=2, dim=1) + adj = torch.mm(norm_features, norm_features.t()) + + # Apply threshold + adj = (adj > self.threshold).float() + + # Remove self-loops + adj = adj * (1 - torch.eye(self.num_nodes, device=adj.device)) + + return adj + + def full_param_generator(self, z): + adj = self.structure_gen(z).view(self.num_nodes, self.num_nodes) + + # Make symmetric + adj = (adj + adj.t()) / 2 + + # Remove self-loops + adj = adj * (1 - torch.eye(self.num_nodes, device=adj.device)) + + return adj + + def adj_to_edge_index(self, adj): + return adj.nonzero().t() + + def self_supervised_training(self, x, adj, model): + # Implement self-supervised denoising task + self.train() + + # Add noise to features + noise = torch.randn_like(x) * 0.1 + noisy_x = x + noise + + # Use the model to denoise + edge_index = self.adj_to_edge_index(adj) + denoised_x = model(noisy_x, edge_index) + + # Compute reconstruction loss + loss = F.mse_loss(denoised_x, x) + + return loss + +class DenoisingModel(nn.Module): + def __init__(self, input_dim, hidden_dim): + super(DenoisingModel, self).__init__() + self.conv1 = GCNConv(input_dim, hidden_dim) + self.conv2 = GCNConv(hidden_dim, input_dim) + + def forward(self, x, edge_index): + x = F.relu(self.conv1(x, edge_index)) + x = self.conv2(x, edge_index) + return x diff --git a/pygip/data_free_attack/models/surrogate.py b/pygip/data_free_attack/models/surrogate.py new file mode 100644 index 0000000..7d7e4c1 --- /dev/null +++ b/pygip/data_free_attack/models/surrogate.py @@ -0,0 +1,43 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.nn import GCNConv + +class SurrogateModel(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2, 
dropout_rate=0.5): + super(SurrogateModel, self).__init__() + self.convs = nn.ModuleList() + self.convs.append(GCNConv(input_dim, hidden_dim)) + + for _ in range(num_layers - 2): + self.convs.append(GCNConv(hidden_dim, hidden_dim)) + + self.convs.append(GCNConv(hidden_dim, output_dim)) + self.dropout_rate = dropout_rate + + def forward(self, x, edge_index): + for i, conv in enumerate(self.convs[:-1]): + x = conv(x, edge_index) + x = F.relu(x) + x = F.dropout(x, p=self.dropout_rate, training=self.training) + + x = self.convs[-1](x, edge_index) + return F.softmax(x, dim=1) + + def train_step(self, generator, victim_model, optimizer, criterion, device): + self.train() + optimizer.zero_grad() + + z = torch.randn(1, generator.noise_dim).to(device) + features, adj = generator(z) + edge_index = generator.adj_to_edge_index(adj) + + with torch.no_grad(): + victim_output = victim_model(features, edge_index) + surrogate_output = self(features, edge_index) + + loss = criterion(surrogate_output, victim_output.argmax(dim=1)) + loss.backward() + optimizer.step() + + return loss.item() diff --git a/pygip/data_free_attack/models/victim.py b/pygip/data_free_attack/models/victim.py new file mode 100644 index 0000000..bc2aa7b --- /dev/null +++ b/pygip/data_free_attack/models/victim.py @@ -0,0 +1,49 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.nn import GCNConv + +class VictimModel(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2): + super(VictimModel, self).__init__() + self.convs = nn.ModuleList() + self.convs.append(GCNConv(input_dim, hidden_dim)) + + for _ in range(num_layers - 2): + self.convs.append(GCNConv(hidden_dim, hidden_dim)) + + self.convs.append(GCNConv(hidden_dim, output_dim)) + + def forward(self, x, edge_index): + for i, conv in enumerate(self.convs[:-1]): + x = conv(x, edge_index) + x = F.relu(x) + x = F.dropout(x, p=0.25, training=self.training) # Paper: p=0.5 + + x = self.convs[-1](x, 
edge_index) + return F.log_softmax(x, dim=1) + +def create_victim_model_cora(): + input_dim = 1433 + hidden_dim = 64 # Paper: 128 + output_dim = 7 + return VictimModel(input_dim, hidden_dim, output_dim) + +def create_victim_model_computers(): + input_dim = 767 + hidden_dim = 64 # Paper: 128 + output_dim = 10 + return VictimModel(input_dim, hidden_dim, output_dim) + +def create_victim_model_pubmed(): + input_dim = 500 + hidden_dim = 64 # Paper: 128 + output_dim = 3 + return VictimModel(input_dim, hidden_dim, output_dim) + +def create_victim_model_ogb_arxiv(): + input_dim = 128 + hidden_dim = 128 # Paper: 256 + output_dim = 40 + num_layers = 2 # Paper: 3 + return VictimModel(input_dim, hidden_dim, output_dim, num_layers) diff --git a/requirements.txt b/requirements.txt index 93cb5aa..45b5466 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,45 +1,45 @@ +# System packages (commented out as they're typically managed by the system) # bzip2==1.0.8 # ca-certificates==2024.7.2 # libffi==3.4.4 # ncurses==6.4 # openssl==3.0.14 -# pip==24.0 -# python==3.10.14 # readline==8.2 -# setuptools==69.5.1 # sqlite==3.45.3 # tk==8.6.14 -# wheel==0.43.0 # xz==5.4.6 # zlib==1.2.13 +numpy>=1.23.5,<2.0.0 torch==2.3.0 +pandas==2.2.2 +scipy==1.14.0 +torch-geometric==2.5.0 +networkx==3.3 +ogb==1.3.6 +dgl==2.2.1 +tqdm==4.66.4 +pyyaml==6.0.1 +requests==2.32.3 +fsspec==2024.6.1 +psutil==6.0.0 +torchdata==0.7.1 +python-dateutil==2.9.0.post0 +pytz==2024.1 +tzdata==2024.1 +typing-extensions==4.12.2 annotated-types==0.7.0 +pydantic==2.8.2 +pydantic-core==2.20.1 certifi==2024.7.4 charset-normalizer==3.3.2 filelock==3.15.4 -fsspec==2024.6.1 idna==3.7 jinja2==3.1.4 markupsafe==2.1.5 mpmath==1.3.0 -networkx==3.3 -numpy==2.0.1 -pandas==2.2.2 -psutil==6.0.0 -pydantic==2.8.2 -pydantic-core==2.20.1 -python-dateutil==2.9.0.post0 -pytz==2024.1 -pyyaml==6.0.1 -requests==2.32.3 -scipy==1.14.0 six==1.16.0 sympy==1.13.1 -torchdata==0.7.1 -tqdm==4.66.4 -typing-extensions==4.12.2 -tzdata==2024.1 
urllib3==2.2.2 -dgl==2.2.1 -torch_geometric -packaging +scikit-learn==1.3.0 +matplotlib==3.7.5 +packaging==23.2