Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,21 @@ dist/

#virtual environments folder
.venv

# Local datasets / artifacts (do not commit)
data/
*.pt
*.pth
*.zip

# accidental local files
2.2.0

# Local datasets / artifacts (do not commit)
data/
*.pt
*.pth
*.zip

# accidental local files
2.2.0
18 changes: 18 additions & 0 deletions examples/egsteal_attack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from pygip.datasets import TUGraph
from pygip.models.attack import EGStealAttack

dataset = TUGraph(name="NCI109", api_type="pyg")

config = {
"gnn_backbone": "GIN",
"gnn_layers": 3,
"hidden_dim": 128,
"epochs": 5, # set to 200 later
"batch_size": 64,
"explanation_mode": "CAM",
"align_weight": 1.0,
}

attack = EGStealAttack(dataset, config=config)
print(attack.attack())

1 change: 1 addition & 0 deletions pygip/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Photo,
CoauthorCS,
CoauthorPhysics,
TUGraph,
)

__all__ = [
Expand Down
86 changes: 74 additions & 12 deletions pygip/datasets/datasets.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,44 @@
import dgl
# --- Optional DGL imports (only required when api_type == "dgl") ---
try:
import dgl # type: ignore
from dgl import DGLGraph # type: ignore

# DGL datasets (graph classification / node classification)
from dgl.data import ( # type: ignore
AmazonCoBuyComputerDataset, # Amazon-Computer
AmazonCoBuyPhotoDataset, # Amazon-Photo
CoauthorCSDataset, # Coauthor-CS
CoauthorPhysicsDataset, # Coauthor-Physics
CoraGraphDataset,
CiteseerGraphDataset,
PubmedGraphDataset,
)
except ImportError:
dgl = None
DGLGraph = None

AmazonCoBuyComputerDataset = None
AmazonCoBuyPhotoDataset = None
CoauthorCSDataset = None
CoauthorPhysicsDataset = None
CoraGraphDataset = None
CiteseerGraphDataset = None
PubmedGraphDataset = None
# ---------------------------------------------------------------

try:
import dgl # optional: only needed when api_type == "dgl"
except ImportError:
dgl = None
import numpy as np
import torch
from dgl import DGLGraph
from dgl.data import AmazonCoBuyComputerDataset # Amazon-Computer
from dgl.data import AmazonCoBuyPhotoDataset # Amazon-Photo
from dgl.data import CoauthorCSDataset, CoauthorPhysicsDataset
from dgl.data import FakeNewsDataset
from dgl.data import FlickrDataset
from dgl.data import GINDataset
from dgl.data import MUTAGDataset
from dgl.data import RedditDataset
from dgl.data import YelpDataset
from dgl.data import citation_graph # Cora, CiteSeer, PubMed
try:
import dgl # optional
from dgl import DGLGraph # type: ignore
except ImportError:
dgl = None
DGLGraph = None # type: ignore

from sklearn.model_selection import StratifiedShuffleSplit
from torch_geometric.data import Data as PyGData
from torch_geometric.datasets import Amazon # Amazon Computers, Photo
Expand Down Expand Up @@ -62,6 +89,13 @@ class Dataset(object):
def __init__(self, api_type='dgl', path='./data'):
assert api_type in {'dgl', 'pyg'}, 'API type must be dgl or pyg'
self.api_type = api_type
if self.api_type == "dgl" and dgl is None:
raise ImportError(
"DGL is not installed, but api_type='dgl' was requested. "
"Install DGL (or run on a platform that supports DGL wheels) "
"or use api_type='pyg'."
)

self.path = path
self.dataset_name = self.get_name()

Expand Down Expand Up @@ -260,6 +294,34 @@ def __repr__(self):
f"#Nodes={self.num_nodes}, #Features={self.num_features}, "
f"#Classes={self.num_classes})")

class TUGraph(Dataset):
"""
PyG wrapper for TU graph classification datasets (e.g., NCI109, AIDS, Mutagenicity).
This is graph-level classification, so we set graph_dataset and do not use graph_data.
"""
def __init__(self, name: str, api_type: str = "pyg", path: str = "./data"):
self.name = name
super().__init__(api_type=api_type, path=path)

def get_name(self):
return self.name

def load_dgl_data(self):
raise ImportError("TUGraph only supports api_type='pyg' (DGL not required).")

def load_pyg_data(self):
# torch_geometric.datasets.TUDataset is already imported at top
self.graph_dataset = TUDataset(root=self.path, name=self.name)
self.graph_data = None # graph classification datasets are list-like

def _load_meta_data(self):
# Override because base _load_meta_data assumes a single PyGData in self.graph_data
# For TU datasets, metadata comes from the dataset object.
self.num_nodes = 0 # varies per graph
self.num_features = self.graph_dataset.num_features
self.num_classes = self.graph_dataset.num_classes



class Cora(Dataset):
def __init__(self, api_type='dgl', path='./data'):
Expand Down
60 changes: 32 additions & 28 deletions pygip/models/attack/__init__.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,35 @@
from .AdvMEA import AdvMEA
from .CEGA import CEGA
from .DataFreeMEA import (
DFEATypeI,
DFEATypeII,
DFEATypeIII
)
from .mea.MEA import (
ModelExtractionAttack0,
ModelExtractionAttack1,
ModelExtractionAttack2,
ModelExtractionAttack3,
ModelExtractionAttack4,
ModelExtractionAttack5
)
from .Realistic import RealisticAttack
"""
Attack module exports.

Some attacks depend on DGL. DGL wheels may be unavailable on some platforms (e.g., macOS).
We import DGL-dependent attacks conditionally so PyG-only workflows still work.
"""

from .egsteal import EGStealAttack

# Optional: if you KNOW any of these are PyG-only, you can import them here.
# For now, we keep everything else behind the DGL gate to avoid import-time crashes.


try:
import dgl # noqa: F401

# Import ALL attacks that require DGL here:
from .AdvMEA import AdvMEA
from .CEGA import CEGA
from .DataFreeMEA import DataFreeMEA
from .Realistic import Realistic

except ImportError:
AdvMEA = None
CEGA = None
DataFreeMEA = None
Realistic = None

__all__ = [
'AdvMEA',
'CEGA',
'RealisticAttack',
'DFEATypeI',
'DFEATypeII',
'DFEATypeIII',
'ModelExtractionAttack0',
'ModelExtractionAttack1',
'ModelExtractionAttack2',
'ModelExtractionAttack3',
'ModelExtractionAttack4',
'ModelExtractionAttack5',
"EGStealAttack",
"AdvMEA",
"CEGA",
"DataFreeMEA",
"Realistic",
]
4 changes: 4 additions & 0 deletions pygip/models/attack/egsteal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .attack import EGStealAttack

__all__ = ["EGStealAttack"]

Loading