Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,11 @@ Icon

#virtual environments folder
.venv

/pygip.egg-info/
/*.egg-info
/__pycache__/
/venv/
/.venv/
*.pyc
.DS_Store
12 changes: 12 additions & 0 deletions examples/defense_cora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from datasets import Cora
from models.defense.RandomWM import RandomWM

def main():
    """Build a RandomWM defense for the Cora dataset and run ``defend()``."""
    cora = Cora()
    # 0.25 is the sampling ratio handed to the defense.
    defense = RandomWM(cora, 0.25)
    print("Initialized defense; starting defend()...")
    defense.defend()
    print("Defense finished.")


if __name__ == "__main__":
    main()
13 changes: 13 additions & 0 deletions examples/example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# examples/example.py
from datasets import Cora
from models.attack import ModelExtractionAttack6 as MEA

def main():
    """Run ModelExtractionAttack6 against Cora and print what ``attack()`` returns."""
    # NOTE(review): 0.25 is presumably the fraction of nodes the attacker
    # samples -- confirm against the attack class constructor.
    attack = MEA(Cora(), 0.25)
    print("Running Attack-6 on Cora...")
    print("Results:", attack.attack())


if __name__ == "__main__":
    main()
21 changes: 21 additions & 0 deletions examples/mea_cora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from datasets import Cora

# Resolve the attack class under whichever name the installed package exposes.
# Only ImportError is handled: any other exception raised while the package
# initializes is a real failure and should surface immediately instead of
# being masked by the next fallback. If every candidate fails, the last
# ImportError propagates, so MEA is always bound when execution continues.
try:
    from models.attack import ModelExtractionAttack0 as MEA
except ImportError:
    try:
        from models.attack import ModelExtractionAttack as MEA
    except ImportError:
        # Direct import from the submodule if not re-exported
        from models.attack.mea import ModelExtractionAttack0 as MEA

def main():
    """Instantiate the resolved ``MEA`` class on Cora and run its attack."""
    cora = Cora()
    # 0.25 is the sampling ratio handed to the attack.
    attack = MEA(cora, 0.25)
    print("Initialized MEA; starting attack...")
    attack.attack()
    print("Attack finished.")


if __name__ == "__main__":
    main()
12 changes: 12 additions & 0 deletions examples/mea_cora_attack6.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from datasets import Cora
from models.attack import ModelExtractionAttack6 as MEA

def main():
    """Run ModelExtractionAttack6 on Cora and report the returned results."""
    # NOTE(review): 0.25 is presumably the fraction of nodes the attacker
    # samples -- confirm against the attack class constructor.
    attack = MEA(Cora(), 0.25)
    print("Initialized MEA-6; starting attack...")
    outcome = attack.attack()
    print("Attack-6 finished. Results:", outcome)


if __name__ == "__main__":
    main()
4 changes: 3 additions & 1 deletion models/attack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
ModelExtractionAttack2,
ModelExtractionAttack3,
ModelExtractionAttack4,
ModelExtractionAttack5
ModelExtractionAttack5,
ModelExtractionAttack6,
)

__all__ = [
Expand All @@ -16,4 +17,5 @@
'ModelExtractionAttack3',
'ModelExtractionAttack4',
'ModelExtractionAttack5',
'ModelExtractionAttack6',
]
182 changes: 182 additions & 0 deletions models/attack/mea/MEA.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,3 +1111,185 @@ def attack(self):
print(f"Error type: {type(e)}")
torch.cuda.empty_cache()
raise



class ModelExtractionAttack6(ModelExtractionAttack):
    """Model extraction attack that synthesizes features for unqueried nodes.

    A random subset of nodes is treated as the attacker's queried nodes. Every
    other node within two hops of that subset receives a synthetic feature
    vector built from the features of nearby queried nodes: one-hop
    contributions are weighted by ``alpha`` and two-hop contributions by
    ``1 - alpha``. A surrogate GCN is then trained on the induced sub-graph
    using the labels predicted by the target model (``self.net1``).
    """

    def __init__(self, dataset, attack_node_fraction, model_path=None, alpha=0.8):
        """Initialize the attack.

        Args:
            dataset: Forwarded unchanged to the base class.
            attack_node_fraction: Forwarded unchanged to the base class.
            model_path: Forwarded unchanged to the base class.
            alpha: Weight of one-hop contributions when synthesizing features;
                two-hop contributions are weighted by ``1 - alpha``.
        """
        super().__init__(dataset, attack_node_fraction, model_path)
        self.alpha = alpha

    def get_nonzero_indices(self, matrix_row):
        """Return the indices of the non-zero entries of ``matrix_row``."""
        return np.where(matrix_row != 0)[0]

    def _neighbors(self, g_matrix, cache, node_index):
        """Return the neighbor list of ``node_index``, memoized in ``cache``."""
        found = cache.get(node_index)
        if found is None:
            found = self.get_nonzero_indices(g_matrix[node_index]).tolist()
            cache[node_index] = found
        return found

    def _second_order_degree(self, g_matrix, cache, degree_cache, node_index):
        """Count nodes two steps from ``node_index`` that are not direct neighbors."""
        count = degree_cache.get(node_index)
        if count is None:
            first_step = self._neighbors(g_matrix, cache, node_index)
            second_step = set()
            for neighbor in first_step:
                second_step.update(self._neighbors(g_matrix, cache, neighbor))
            second_step.difference_update(first_step)
            count = len(second_step)
            degree_cache[node_index] = count
        return count

    def _synthesize_features(self, features_query, node_index, queried,
                             g_matrix, cache, degree_cache):
        """Overwrite ``features_query[node_index]`` with a synthetic vector.

        The vector is a weighted sum of the true features of queried nodes that
        are one hop (weight ``alpha``) or two hops (weight ``1 - alpha``) away
        through a queried one-hop neighbor.
        """
        features_query[node_index] = 0

        # Queried direct neighbors of the synthetic node.
        one_step = [
            n for n in self._neighbors(g_matrix, cache, node_index) if n in queried
        ]
        num_one_step = len(one_step)
        for first_order in one_step:
            degree = len(self._neighbors(g_matrix, cache, first_order))
            features_query[node_index] += (
                self.features[first_order] * self.alpha /
                torch.sqrt(torch.tensor(num_one_step * degree, device=self.device))
            )

        # Queried nodes reachable through those neighbors, excluding the
        # one-hop neighbors themselves.
        two_step = set()
        for first_order in one_step:
            two_step.update(self._neighbors(g_matrix, cache, first_order))
        two_step.difference_update(one_step)
        two_step.intersection_update(queried)
        num_two_step = len(two_step)
        for second_order in sorted(two_step):
            second_degree = self._second_order_degree(
                g_matrix, cache, degree_cache, second_order)
            if second_degree > 0:
                features_query[node_index] += (
                    self.features[second_order] * (1 - self.alpha) /
                    torch.sqrt(
                        torch.tensor(num_two_step * second_degree, device=self.device))
                )

    def attack(self):
        """
        Main attack procedure.

        1. Samples a subset of nodes (`sub_graph_node_index`) for querying.
        2. Synthesizes features for neighboring nodes and their neighbors.
        3. Builds a sub-graph, trains a new GCN on it, and evaluates
           fidelity & accuracy w.r.t. the target model.

        Returns:
            dict with keys ``"fidelity"`` and ``"accuracy"``: the best value of
            each metric seen over the training epochs (tracked independently).
            The trained surrogate is stored in ``self.net2``.

        Raises:
            RuntimeError: re-raised after logging and clearing the CUDA cache.
        """
        try:
            torch.cuda.empty_cache()
            g = self.graph.clone().to(self.device)
            g_matrix = g.adjacency_matrix().to_dense().cpu().numpy()
            del g

            sub_graph_node_index = np.random.choice(
                self.num_nodes, self.attack_node_num, replace=False).tolist()
            queried = set(sub_graph_node_index)

            batch_size = 32
            features_query = self.features.clone()

            # Adjacency rows are looked up repeatedly; memoize them.
            neighbor_cache = {}
            second_degree_cache = {}

            # Collect everything within two hops of the queried nodes.
            syn_nodes = set()
            for node_index in sub_graph_node_index:
                one_step = self._neighbors(g_matrix, neighbor_cache, node_index)
                syn_nodes.update(one_step)
                for first_order in one_step:
                    syn_nodes.update(
                        self._neighbors(g_matrix, neighbor_cache, first_order))

            # Sorted so the sub-graph node order does not depend on set ordering.
            sub_graph_syn_node_index = sorted(syn_nodes - queried)
            total_sub_nodes = sorted(syn_nodes | queried)

            # Synthesize features, releasing cached GPU memory every
            # `batch_size` nodes.
            for position, node_index in enumerate(sub_graph_syn_node_index, start=1):
                self._synthesize_features(
                    features_query, node_index, queried,
                    g_matrix, neighbor_cache, second_degree_cache)
                if position % batch_size == 0:
                    torch.cuda.empty_cache()
            torch.cuda.empty_cache()

            # Update masks: queried nodes are train-only, every other node is
            # test-only.
            self.train_mask[:self.num_nodes] = 0
            self.test_mask[:self.num_nodes] = 1
            self.train_mask[sub_graph_node_index] = 1
            self.test_mask[sub_graph_node_index] = 0

            # Create subgraph adjacency matrix
            sub_adjacency = g_matrix[
                np.ix_(total_sub_nodes, total_sub_nodes)].astype(np.float64)

            del g_matrix

            sub_train_mask = self.train_mask[total_sub_nodes]
            sub_features = features_query[total_sub_nodes]

            # Get query labels
            self.net1.eval()
            with torch.no_grad():
                g = self.graph.to(self.device)
                logits_query = self.net1(g, features_query)
                _, labels_query = torch.max(logits_query, dim=1)
                sub_labels_query = labels_query[total_sub_nodes]
                del logits_query

            # Create DGL graph: drop existing self-loops, then add exactly one
            # self-loop per node.
            nx_sub_graph = nx.from_numpy_array(sub_adjacency)
            nx_sub_graph.remove_edges_from(nx.selfloop_edges(nx_sub_graph))
            nx_sub_graph.add_edges_from(zip(nx_sub_graph.nodes(), nx_sub_graph.nodes()))
            sub_g = DGLGraph(nx_sub_graph)
            sub_g = sub_g.to(self.device)

            degs = sub_g.in_degrees().float()
            norm = torch.pow(degs, -0.5)
            norm[torch.isinf(norm)] = 0
            norm = norm.to(self.device)
            sub_g.ndata['norm'] = norm.unsqueeze(1)

            # Train extraction model
            net = GCN(self.num_features, self.num_classes).to(self.device)
            optimizer = torch.optim.Adam(net.parameters(), lr=1e-2, weight_decay=5e-4)
            best_performance_metrics = GraphNeuralNetworkMetric()

            print("=========Model Extracting(Attack)==========================")
            for epoch in tqdm(range(200)):
                net.train()
                logits = net(sub_g, sub_features)
                logp = F.log_softmax(logits, dim=1)
                loss = F.nll_loss(logp[sub_train_mask], sub_labels_query[sub_train_mask])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Evaluate on the full graph with the original features.
                with torch.no_grad():
                    focus_gnn_metrics = GraphNeuralNetworkMetric(
                        0, 0, net, g, self.features, self.test_mask, self.labels, labels_query
                    )
                    focus_gnn_metrics.evaluate()

                    best_performance_metrics.fidelity = max(
                        best_performance_metrics.fidelity, focus_gnn_metrics.fidelity)
                    best_performance_metrics.accuracy = max(
                        best_performance_metrics.accuracy, focus_gnn_metrics.accuracy)

                if epoch % 10 == 0:
                    torch.cuda.empty_cache()

            print("========================Final results (Attack 6):=========================================")
            print(best_performance_metrics)
            results = {
                "fidelity": float(best_performance_metrics.fidelity),
                "accuracy": float(best_performance_metrics.accuracy),
            }
            self.net2 = net
            return results

        except RuntimeError as e:
            print(f"Attack 6 Runtime error: {e}")
            torch.cuda.empty_cache()
            raise

2 changes: 1 addition & 1 deletion reqs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ async-timeout==5.0.1
attrs==25.3.0
certifi==2024.7.4
charset-normalizer==3.3.2
dgl==2.2.0
dgl==2.2.1
filelock==3.15.4
frozenlist==1.5.0
fsspec==2024.6.1
Expand Down