Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions examples/attack/EGSteal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from pygip.datasets import *
from pygip.models.attack.EGSteal.EGSteal import EGSteal
from pygip.utils.hardware import set_device

set_device("cuda:0") # cpu, cuda:0


def egsteal():
dataset = MUTAGGraphClassification(api_type='pyg')
egsteal = EGSteal(dataset,query_shadow_ratio=0.3)
egsteal.attack()

if __name__ == '__main__':
egsteal()
4 changes: 4 additions & 0 deletions pygip/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
Photo,
CoauthorCS,
CoauthorPhysics,
MUTAG,
MUTAGGraphClassification
)

__all__ = [
Expand All @@ -18,4 +20,6 @@
'Photo',
'CoauthorCS',
'CoauthorPhysics',
'MUTAG',
'MUTAGGraphClassification'
]
31 changes: 31 additions & 0 deletions pygip/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from torch_geometric.datasets import Reddit
from torch_geometric.datasets import TUDataset # ENZYMES

## Added for EGSteal
from torch_geometric.transforms import Constant

def dgl_to_tg(dgl_graph):
edge_index = torch.stack(dgl_graph.edges())
Expand Down Expand Up @@ -565,3 +567,32 @@ def load_dgl_data(self):
dataset = YelpDataset(raw_dir=self.path)
self.graph_dataset = dataset
self.graph_data = dataset[0]

class MUTAGGraphClassification(Dataset):
def __init__(self, api_type='pyg', path='./data'):
super().__init__(api_type, path)

def _load_meta_data(self):
if self.api_type == 'pyg':
ds = self.graph_dataset
self.num_features = int(ds.num_node_features)
self.num_classes = int(ds.num_classes)
self.num_graphs = int(len(ds))
self.num_edge_features = int(ds.num_edge_features) if ds.num_edge_features is not None else 0
else:
super()._load_meta_data()

def load_pyg_data(self):
self.dataset_name = 'Mutagenicity'
temp_dataset = TUDataset(root=self.path,name='Mutagenicity')
data_transform = None
if temp_dataset.num_node_features == 0:
print("\nNo node features found. Adding constant node features (all ones).")
data_transform = Constant(value=1, cat=False)
self.graph_dataset = TUDataset(root=self.path,name='Mutagenicity',transform=data_transform)
num_graphs = len(self.graph_dataset)
print(f"\nTotal number of graphs in {self.dataset_name}: {num_graphs}")
print(f"Node features dimension: {self.graph_dataset.num_node_features}")
print(f"Edge features dimension: {self.graph_dataset.num_edge_features if hasattr(self.graph_dataset, 'num_edge_features') else 'N/A'}")
print(f"Number of classes: {self.graph_dataset.num_classes if hasattr(self.graph_dataset, 'num_classes') else 'N/A'}")
return self.graph_dataset
Loading