diff --git a/experimental_examples/README.md b/experimental_examples/README.md
new file mode 100644
index 0000000..4e0b07a
--- /dev/null
+++ b/experimental_examples/README.md
@@ -0,0 +1,91 @@
+
+# 🔍 GNNFingers Experimentals
+
+This folder contains experimental scripts for the **GNNFingers** attack and defense pipeline,
+integrated with [PyGIP](https://github.com/yushundong/PyGIP).
+
+The scripts reproduce and extend fingerprint-based ownership verification of GNN models.
+
+
+
+## ▶️ Usage
+
+All experiments are launched via the CLI:
+
+```bash
+python experimental_examples/cli.py --dataset Cora --joint_steps 50
+```
+
+### Common options
+
+- `--dataset {Cora,Citeseer,Pubmed}`
+  Dataset to use.
+
+- `--joint_steps INT`
+  Number of training steps for the joint optimization of fingerprints and the univerifier.
+
+- `--num_graphs INT`
+  Number of fingerprint probe graphs.
+
+- `--num_nodes INT`
+  Number of nodes per probe graph.
+
+- `--edge_density FLOAT`
+  Edge density of the fingerprint graphs (default 0.05).
+
+- `--proj_every INT`
+  How often (in steps) fingerprint edge weights are projected back to a sparse topology.
+
+- `--node_sample INT`
+  Number of nodes sampled per probe graph when building signatures (0 = use all nodes).
+
+- `--device {cpu,cuda}`
+  Device for training (defaults to `cuda` if available).
+
+- `--mode {attack,defense}`
+  Run the attack pipeline (default) or the defense pipeline.
+
+- `--clean`
+  Remove old `.pt` and `.json` artifacts before running.
+
+---
+
+## 🧪 Examples
+
+### Quick test (small run)
+```bash
+python experimental_examples/cli.py --dataset Cora --joint_steps 10 --num_graphs 8 --num_nodes 16 --clean
+```
+
+### Full attack run
+```bash
+python experimental_examples/cli.py --dataset Cora --joint_steps 300 --num_graphs 64 --num_nodes 32 --edge_density 0.05
+```
+
+### Defense run
+```bash
+python experimental_examples/cli.py --dataset Cora --mode defense
+```
+
+---
+
+## 📦 Outputs
+
+Running the pipeline produces:
+
+- **Model checkpoints (`*.pt`)**
+  - `target_main.pt`, `ft_last.pt`, `reinit_last.pt`, etc.
+- **Fingerprint artifacts**
+  - `fingerprints.pt`, `univerifier.pt`
+- **Verification metrics**
+  - `verification_metrics.json` (contains ROC_AUC, ARUC, robustness, etc.)
+
+---
+
+## 📝 Notes
+
+- The implementation follows the guidelines in `IMPLEMENTATION.md`.
+- The `attack()` and `defense()` functions are the public entrypoints; helpers are defined internally.
+- Use the `--clean` flag to avoid accumulating stale artifacts across runs.
+
+---
diff --git a/experimental_examples/attacker.py b/experimental_examples/attacker.py
new file mode 100644
index 0000000..d598105
--- /dev/null
+++ b/experimental_examples/attacker.py
@@ -0,0 +1,434 @@
+# Attack class - model training, fingerprint learning, evaluation, and defense hooks
+
+import os, json, random
+from dataclasses import dataclass
+from typing import List, Tuple, Optional
+
+import torch
+import torch.nn.functional as F
+from torch_geometric.nn import GCNConv, SAGEConv
+
+from models import SmallGCN, SmallSAGE
+from fingerprints import LearnableFingerprint, Univerifier
+
+
+@dataclass
+class FingerprintSpecLocal:
+    num_graphs: int = 64
+    num_nodes: int = 32
+    edge_density: float = 0.05
+    proj_every: int = 25
+    update_feat: bool = True
+    update_adj: bool = True
+    node_sample: int = 0
+
+
+def make_deterministic(seed: int = 123):
+    # Set seeds for reproducibility.
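+    # Covers the Python and PyTorch (CPU/CUDA) RNGs plus the cuDNN flags. Fully
+    # deterministic CUDA execution may additionally require
+    # torch.use_deterministic_algorithms(True), which these scripts do not set.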
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    random.seed(seed)
+    torch.manual_seed(seed)
+    try:
+        torch.cuda.manual_seed_all(seed)
+    except Exception:
+        pass
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+
+class GNNFingersAttack(object):
+    supported_api_types = {'pyg'}
+
+    def __init__(self, dataset, attack_node_fraction: float = 0.3, model_path: Optional[str] = None,
+                 fp_cfg: Optional[FingerprintSpecLocal] = None, joint_steps: int = 300,
+                 device: Optional[torch.device] = None):
+        self.device = torch.device(device) if device else (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
+        print(f"[GNNFingersAttack] using device: {self.device}")
+
+        self.dataset = dataset
+        self.graph_data = dataset.graph_data.to(self.device)
+
+        self.num_features = dataset.num_features
+        self.num_classes = dataset.num_classes
+
+        self.attack_node_fraction = attack_node_fraction
+        self.model_path = model_path
+
+        # Avoid a shared default instance: build a fresh spec when none is given.
+        self.fp_cfg = fp_cfg if fp_cfg is not None else FingerprintSpecLocal()
+        self.joint_steps = joint_steps
+
+        # Feature bounds of the host dataset, used to clamp learned fingerprint features.
+        self.feat_min = float(self.graph_data.x.min().item())
+        self.feat_max = float(self.graph_data.x.max().item())
+
+    def build_target(self, kind='gcn'):
+        if kind == 'sage':
+            return SmallSAGE(self.num_features, 64, self.num_classes).to(self.device)
+        return SmallGCN(self.num_features, 64, self.num_classes).to(self.device)
+
+    def run_train_epoch(self, m, feat, ei, mask, labels, opt, edge_weight=None):
+        m.train()
+        opt.zero_grad()
+        out = m(feat, ei, edge_weight=edge_weight)
+        loss = F.cross_entropy(out[mask], labels[mask])
+        loss.backward()
+        opt.step()
+        return loss.item()
+
+    @torch.no_grad()
+    def eval_split(self, m, feat, ei, labels, mask, edge_weight=None):
+        m.eval()
+        logits = m(feat, ei, edge_weight=edge_weight)
+        predict = logits.argmax(dim=1)
+        accuracy = (predict[mask] == labels[mask]).float().mean().item()
+        return accuracy, logits
+
+    # Target model utilities
+    def _train_target_model(self, arch='gcn', epochs=200):
+        make_deterministic(42)
+        target_model = self.build_target(kind=arch)
+        opt = torch.optim.Adam(target_model.parameters(), lr=0.01, weight_decay=5e-4)
+
+        best_val, best_state = 0.0, None
+        for ep in range(1, epochs + 1):
+            loss = self.run_train_epoch(target_model, self.graph_data.x, self.graph_data.edge_index, self.graph_data.train_mask, self.graph_data.y, opt)
+            va, _ = self.eval_split(target_model, self.graph_data.x, self.graph_data.edge_index, self.graph_data.y, self.graph_data.val_mask)
+            te, _ = self.eval_split(target_model, self.graph_data.x, self.graph_data.edge_index, self.graph_data.y, self.graph_data.test_mask)
+            if va > best_val:
+                best_val = va
+                best_state = {k: v.detach().cpu().clone() for k, v in target_model.state_dict().items()}
+            if ep in (1, 50, 100, 150, 200):
+                print(f"[target] ep={ep:3d} | loss={loss:.4f} | val={va:.4f} | test={te:.4f}")
+
+        if best_state:
+            target_model.load_state_dict({k: v.to(self.device) for k, v in best_state.items()})
+        return target_model
+
+    def _load_model(self, path):
+        m = self.build_target(kind=os.environ.get('TARGET_ARCH', 'gcn'))
+        m.load_state_dict(torch.load(path, map_location=self.device))
+        m.to(self.device).eval()
+        return m
+
+    # Attack model training (fingerprints + univerifier)
+    def _train_attack_model(self, target_model_path: Optional[str] = None, joint_steps: Optional[int] = None):
+        if joint_steps is None:
+            joint_steps = self.joint_steps
+
+        suspects: List[Tuple[str, str, int]] = []
+
+        if target_model_path:
+            target_model = self._load_model(target_model_path)
+        else:
+            target_model = self._train_target_model()
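+
+        # Persist the target and register it as the first positive suspect.
+        # Positives (label 1) are models derived from the target; negatives
+        # (label 0) are independently trained models of the same architecture.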
+        torch.save(target_model.state_dict(), './target_main.pt')
+        suspects.append(('target', './target_main.pt', 1))
+
+        def copy_model_like(m):
+            new_m = self.build_target(kind='gcn' if isinstance(m, SmallGCN) else 'sage')
+            new_m.load_state_dict(m.state_dict())
+            return new_m
+
+        def fine_tune_last_layer(m, steps=10, lr=0.01):
+            # Freeze everything, then unfreeze and fine-tune only the last conv layer.
+            for p in m.parameters():
+                p.requires_grad = False
+            last = None
+            for module in m.modules():
+                if isinstance(module, (GCNConv, SAGEConv)):
+                    last = module
+            for p in last.parameters():
+                p.requires_grad = True
+            opt = torch.optim.Adam(filter(lambda p: p.requires_grad, m.parameters()), lr=lr, weight_decay=5e-4)
+            for _ in range(steps):
+                self.run_train_epoch(m, self.graph_data.x, self.graph_data.edge_index, self.graph_data.train_mask, self.graph_data.y, opt)
+            return m
+
+        def partial_reinit_and_train(m, steps=10, lr=0.01):
+            # Re-initialize the last conv layer, then briefly retrain the whole model.
+            last = None
+            for module in m.modules():
+                if isinstance(module, (GCNConv, SAGEConv)):
+                    last = module
+            if hasattr(last, 'reset_parameters'):
+                last.reset_parameters()
+            opt = torch.optim.Adam(m.parameters(), lr=lr, weight_decay=5e-4)
+            for _ in range(steps):
+                self.run_train_epoch(m, self.graph_data.x, self.graph_data.edge_index, self.graph_data.train_mask, self.graph_data.y, opt)
+            return m
+
+        def fine_tune_all(m, steps=10, lr=0.005):
+            opt = torch.optim.Adam(m.parameters(), lr=lr, weight_decay=5e-4)
+            for _ in range(steps):
+                self.run_train_epoch(m, self.graph_data.x, self.graph_data.edge_index, self.graph_data.train_mask, self.graph_data.y, opt)
+            return m
+
+        make_deterministic(11)
+        m1 = copy_model_like(target_model)
+        torch.save(fine_tune_last_layer(m1).state_dict(), './ft_last.pt')
+        suspects.append(('ft_last', './ft_last.pt', 1))
+
+        make_deterministic(12)
+        m2 = copy_model_like(target_model)
+        torch.save(partial_reinit_and_train(m2).state_dict(), './reinit_last.pt')
+        suspects.append(('reinit_last', './reinit_last.pt', 1))
+
+        make_deterministic(13)
+        m3 = copy_model_like(target_model)
+        torch.save(fine_tune_all(m3).state_dict(), './ft_all.pt')
+        suspects.append(('ft_all', './ft_all.pt', 1))
+
+        for seed in (100, 101, 102):
+            make_deterministic(seed)
+            mn = self.build_target(kind=os.environ.get('NEG_ARCH', 'gcn'))
+            opt = torch.optim.Adam(mn.parameters(), lr=0.01, weight_decay=5e-4)
+            for _ in range(100):
+                self.run_train_epoch(mn, self.graph_data.x, self.graph_data.edge_index, self.graph_data.train_mask, self.graph_data.y, opt)
+            path = f'./neg_{seed}.pt'
+            torch.save(mn.state_dict(), path)
+            suspects.append((f'neg_{seed}', path, 0))
+
+        print('[info] Built suspects:', [s[0] for s in suspects])
+
+        model_entries = [(nm, self._load_model(pth), lbl) for (nm, pth, lbl) in suspects]
+
+        fp_pool: List[LearnableFingerprint] = [
+            LearnableFingerprint(self.fp_cfg.num_nodes, self.num_features, self.fp_cfg.edge_density, device=self.device).to(self.device)
+            for _ in range(self.fp_cfg.num_graphs)
+        ]
+
+        dummy_sig = self.get_signature_from_model(model_entries[0][1], fp_pool, m_nodes=self.fp_cfg.node_sample)
+        uv = Univerifier(in_dim=dummy_sig.numel()).to(self.device)
+
+        fp_params = []
+        for fp in fp_pool:
+            if self.fp_cfg.update_feat:
+                fp_params.append(fp.feat)
+            if self.fp_cfg.update_adj:
+                fp_params.append(fp.adj_param)
+        opt_fp = torch.optim.Adam(fp_params, lr=0.05)
+        opt_uv = torch.optim.Adam(uv.parameters(), lr=1e-3, weight_decay=1e-4)
+
+        print(f"[info] Joint steps: {joint_steps} | proj_every={self.fp_cfg.proj_every} | update_feat={self.fp_cfg.update_feat} | update_adj={self.fp_cfg.update_adj}")
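+
+        # Joint optimization: each step probes every suspect with all fingerprint
+        # graphs, scores the concatenated class-probability signatures with the
+        # univerifier, and backpropagates a single cross-entropy loss into both
+        # the univerifier weights and the fingerprint graphs themselves.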
+        for step in range(1, joint_steps + 1):
+            uv.train()
+            batch_inputs, batch_labels = [], []
+            for (nm, mdl, lbl) in model_entries:
+                sig_pieces = []
+                for fp in fp_pool:
+                    out = fp.forward(mdl)
+                    if self.fp_cfg.node_sample and 0 < self.fp_cfg.node_sample < fp.num_nodes:
+                        idx = torch.randperm(fp.num_nodes, device=out.device)[:self.fp_cfg.node_sample]
+                        probs = out[idx].softmax(dim=-1).mean(dim=0)
+                    else:
+                        probs = out.softmax(dim=-1).mean(dim=0)
+                    sig_pieces.append(probs)
+                sig_all = torch.cat(sig_pieces, dim=0)
+                batch_inputs.append(sig_all.unsqueeze(0))
+                batch_labels.append(torch.tensor([lbl], device=self.device, dtype=torch.long))
+
+            Xb = torch.cat(batch_inputs, dim=0)
+            yb = torch.cat(batch_labels, dim=0)
+            logits = uv(Xb.float())
+            loss = F.cross_entropy(logits, yb)
+
+            opt_uv.zero_grad()
+            opt_fp.zero_grad()
+            loss.backward()
+            opt_uv.step()
+            opt_fp.step()
+
+            # Project after the update so fingerprint features stay within the
+            # host dataset's feature range.
+            with torch.no_grad():
+                for fp in fp_pool:
+                    if self.fp_cfg.update_feat:
+                        fp.feat.clamp_(self.feat_min, self.feat_max)
+
+            if self.fp_cfg.update_adj and (step % self.fp_cfg.proj_every == 0 or step == joint_steps):
+                for fp in fp_pool:
+                    fp.harden_topk(self.fp_cfg.edge_density)
+
+            if step % 25 == 0 or step == 1 or step == joint_steps:
+                with torch.no_grad():
+                    probs = logits.softmax(dim=1)[:, 1]
+                    avg_pos = probs[yb == 1].mean().item() if (yb == 1).any() else float('nan')
+                    avg_neg = probs[yb == 0].mean().item() if (yb == 0).any() else float('nan')
+                    print(f"[joint] step={step:3d} | loss={loss.item():.4f} | avg_pos={avg_pos:.3f} | avg_neg={avg_neg:.3f}")
+
+        torch.save(uv.state_dict(), './univerifier.pt')
+        torch.save({f'fp_{i}': (fp.feat.detach().cpu(), fp.adj_param.detach().cpu()) for i, fp in enumerate(fp_pool)}, './fingerprints.pt')
+
+        metrics = self.evaluate_curves(uv, model_entries, fp_pool)
+        with open('./verification_metrics.json', 'w') as f:
+            json.dump(metrics, f, indent=2)
+        print('Saved verification_metrics.json with labels and probs included.')
+
+        return metrics
+
+    # Public attack entrypoint
+    def attack(self, *args, **kwargs):
+        return self._train_attack_model(*args, **kwargs)
+
+
+    # Defense interface + helpers
+    def defense(self, method: str = 'default'):
+        print(f"[defense] running defense with method={method}")
+
+        # Train or load the victim (target) model
+        target = self._train_target_model()
+
+        # Optionally train a surrogate attack model (for adversarial defenses that need it)
+        surrogate = self._train_surrogate_model()  # returns None if not used
+
+        # Train the defense model (method-dependent)
+        defense_model = self._train_defense_model(method=method, target_model=target, surrogate_model=surrogate)
+
+        # Test the defense: run the fingerprint-based verifier against the defended model
+        metrics = None
+        try:
+            # Reuse the attack pipeline's evaluation stage with the defended
+            # model as one of the suspects.
+            torch.save(defense_model.state_dict(), './defended_model.pt')
+
+            # To keep this quick, load previously saved models from disk
+            # (target and negatives, if present) and append the defended model.
+            saved = []
+            if os.path.exists('./target_main.pt'):
+                saved.append(('target', './target_main.pt', 1))
+            for seed in (100, 101, 102):
+                p = f'./neg_{seed}.pt'
+                if os.path.exists(p):
+                    saved.append((f'neg_{seed}', p, 0))
+            saved.append(('defended', './defended_model.pt', 1))
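+            # The defended model is labelled positive (1): it derives from the
+            # target, so the verifier is expected to flag it; the quick AUC
+            # below measures how well that detection survives the defense.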
+
+            model_entries = [(nm, self._load_model(pth), lbl) for (nm, pth, lbl) in saved]
+
+            # build fingerprints (reuse fingerprint config)
+            fp_pool: List[LearnableFingerprint] = [
+                LearnableFingerprint(self.fp_cfg.num_nodes, self.num_features, self.fp_cfg.edge_density, device=self.device).to(self.device)
+                for _ in range(self.fp_cfg.num_graphs)
+            ]
+
+            dummy_sig = self.get_signature_from_model(model_entries[0][1], fp_pool, m_nodes=self.fp_cfg.node_sample)
+            uv = Univerifier(in_dim=dummy_sig.numel()).to(self.device)
+
+            # Reuse the fingerprints/univerifier trained by a prior attack run when
+            # available (this assumes the same fingerprint configuration); with the
+            # freshly initialized ones above, this evaluation is only a smoke test.
+            if os.path.exists('./fingerprints.pt') and os.path.exists('./univerifier.pt'):
+                blobs = torch.load('./fingerprints.pt', map_location=self.device)
+                for i, fp in enumerate(fp_pool):
+                    feat, adj = blobs[f'fp_{i}']
+                    fp.feat.data.copy_(feat.to(self.device))
+                    fp.adj_param.data.copy_(adj.to(self.device))
+                uv.load_state_dict(torch.load('./univerifier.pt', map_location=self.device))
+            uv.eval()
+
+            X = self.collect_signatures_all(model_entries, fp_pool, m_nodes=self.fp_cfg.node_sample).to(self.device).float()
+            with torch.no_grad():
+                logits = uv(X)
+                prob_pos = logits.softmax(dim=1)[:, 1]
+            # compute a simple ROC-AUC for this quick eval
+            from sklearn.metrics import roc_auc_score
+            labels = [lbl for (_, _, lbl) in model_entries]
+            auc_val = roc_auc_score(labels, prob_pos.detach().cpu().numpy())
+            metrics = {'quick_ROC_AUC': float(auc_val)}
+        except Exception as e:
+            print('[defense] evaluation failed:', e)
+
+        return {
+            'defense_model': defense_model,
+            'surrogate': surrogate,
+            'metrics': metrics
+        }
+
+    def _train_defense_model(self, method: str = 'default', target_model=None, surrogate_model=None):
+        # Trains a defense model. This is a stub that demonstrates the expected interface.
+        print(f"[_train_defense_model] training defense using method={method}")
+
+        # simple baseline
+        if target_model is None:
+            target_model = self._train_target_model()
+
+        def partial_reinit_and_train(m, steps=10, lr=0.01):
+            last = None
+            for module in m.modules():
+                if isinstance(module, (GCNConv, SAGEConv)):
+                    last = module
+            if hasattr(last, 'reset_parameters'):
+                last.reset_parameters()
+            opt = torch.optim.Adam(m.parameters(), lr=lr, weight_decay=5e-4)
+            for _ in range(steps):
+                self.run_train_epoch(m, self.graph_data.x, self.graph_data.edge_index, self.graph_data.train_mask, self.graph_data.y, opt)
+            return m
+
+        # Start from a copy of the target (not a fresh model) so the "defended"
+        # model is actually a perturbed derivative of the victim.
+        defended = self.build_target()
+        defended.load_state_dict(target_model.state_dict())
+        defended = partial_reinit_and_train(defended, steps=10)
+        return defended
+
+    def _train_surrogate_model(self):
+        # Trains a surrogate (attack) model. Returns a model instance, or None if not needed.
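+        # The surrogate mimics an adversary that retrains on the same split;
+        # adaptive defenses could use it to anticipate what an extractor learns.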
+        print('[_train_surrogate_model] training surrogate model (simple retrain)')
+        make_deterministic(21)
+        mn = self.build_target()
+        opt = torch.optim.Adam(mn.parameters(), lr=0.01, weight_decay=5e-4)
+        for _ in range(50):
+            self.run_train_epoch(mn, self.graph_data.x, self.graph_data.edge_index, self.graph_data.train_mask, self.graph_data.y, opt)
+        return mn
+
+    # Signature utilities
+    @torch.no_grad()
+    def get_signature_from_model(self, m, fps: List[LearnableFingerprint], m_nodes: int = 0):
+        pieces = []
+        for fp in fps:
+            out = fp.forward(m)
+            if m_nodes and 0 < m_nodes < fp.num_nodes:
+                idx = torch.randperm(fp.num_nodes, device=out.device)[:m_nodes]
+                probs = out[idx].softmax(dim=-1).mean(dim=0)
+            else:
+                probs = out.softmax(dim=-1).mean(dim=0)
+            pieces.append(probs)
+        return torch.cat(pieces, dim=0)
+
+    @torch.no_grad()
+    def collect_signatures_all(self, models, fps, m_nodes=0):
+        bag = []
+        for (_, mdl, _lbl) in models:
+            bag.append(self.get_signature_from_model(mdl, fps, m_nodes=m_nodes).unsqueeze(0))
+        X = torch.cat(bag, dim=0)
+        return X
+
+    def evaluate_curves(self, uv, models, fps, thresholds=None):
+        uv.eval()
+        if thresholds is None:
+            thresholds = torch.linspace(0.0, 1.0, steps=101)
+        labels = torch.tensor([lbl for (_, _, lbl) in models], device=self.device)
+        X = self.collect_signatures_all(models, fps, m_nodes=self.fp_cfg.node_sample).to(self.device).float()
+        with torch.no_grad():
+            logits = uv(X)
+            prob_pos = logits.softmax(dim=1)[:, 1]
+
+        labels_list = labels.detach().cpu().numpy().tolist()
+        probs_list = prob_pos.detach().cpu().numpy().tolist()
+
+        pos_mask = labels == 1
+        neg_mask = labels == 0
+
+        rob_list, uniq_list, acc_list = [], [], []
+        for t in thresholds:
+            pred_pos = (prob_pos >= t).long()
+            tp = ((pred_pos == 1) & pos_mask).sum().item()
+            tn = ((pred_pos == 0) & neg_mask).sum().item()
+            p_total = pos_mask.sum().item()
+            n_total = neg_mask.sum().item()
+            robustness = tp / max(1, p_total)
+            uniqueness = tn / max(1, n_total)
+            mean_acc = (tp + tn) / max(1, p_total + n_total)
+            rob_list.append(robustness)
+            uniq_list.append(uniqueness)
+            acc_list.append(mean_acc)
+
+        import numpy as np
+        # ARUC: area under min(robustness, uniqueness) across thresholds.
+        # (np.trapezoid requires NumPy >= 2.0; use np.trapz on older versions.)
+        inter = np.minimum(np.array(rob_list), np.array(uniq_list))
+        aruc = float(np.trapezoid(inter, x=np.linspace(0, 1, len(inter))))
+        from sklearn.metrics import roc_auc_score
+        auc_val = float(roc_auc_score(labels_list, probs_list))
+
+        return {
+            'thresholds': thresholds.tolist(),
+            'robustness': rob_list,
+            'uniqueness': uniq_list,
+            'mean_accuracy': acc_list,
+            'ARUC': aruc,
+            'ROC_AUC': auc_val,
+            'labels': labels_list,
+            'probs': probs_list
+        }
diff --git a/experimental_examples/base_attack.py b/experimental_examples/base_attack.py
new file mode 100644
index 0000000..97f4ee2
--- /dev/null
+++ b/experimental_examples/base_attack.py
@@ -0,0 +1,34 @@
+# Base attack class, as per the implementation guide
+import torch
+from typing import Optional
+
+class BaseAttack(object):
+    supported_api_types = set()
+    supported_datasets = set()
+
+    def __init__(self, dataset, attack_node_fraction: Optional[float] = None,
+                 model_path: Optional[str] = None, device: Optional[torch.device] = None):
+        self.device = torch.device(device) if device else (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
+        print(f"[BaseAttack] Using device: {self.device}")
+
+        self.dataset = dataset
+        self.graph_dataset = dataset.graph_dataset
+        self.graph_data = dataset.graph_data.to(self.device)
+
+        self.num_nodes = dataset.num_nodes
+        self.num_features = dataset.num_features
+        self.num_classes = dataset.num_classes
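+
+        # Attack configuration: the fraction of nodes available to the attacker
+        # and an optional path to a pre-trained target checkpoint.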
+        self.attack_node_fraction = attack_node_fraction
+        self.model_path = model_path
+
+        self._check_dataset_compatibility()
+
+    def _check_dataset_compatibility(self):
+        if self.supported_api_types and (self.dataset.api_type not in self.supported_api_types):
+            raise RuntimeError('Dataset api type not supported')
+        if self.supported_datasets and (self.dataset.dataset_name not in self.supported_datasets):
+            print('[BaseAttack] Warning: dataset name not listed')
+
+    def attack(self):
+        raise NotImplementedError
\ No newline at end of file
diff --git a/experimental_examples/cli.py b/experimental_examples/cli.py
new file mode 100644
index 0000000..01988cc
--- /dev/null
+++ b/experimental_examples/cli.py
@@ -0,0 +1,47 @@
+# CLI entrypoint for the attack and defense pipelines
+if __name__ == '__main__':
+    import argparse
+    import glob
+    import os
+
+    import torch
+
+    from dataset import Dataset
+    from attacker import GNNFingersAttack, FingerprintSpecLocal
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--dataset', default='Cora')
+    parser.add_argument('--joint_steps', type=int, default=300)
+    parser.add_argument('--num_graphs', type=int, default=64)
+    parser.add_argument('--num_nodes', type=int, default=32)
+    parser.add_argument('--edge_density', type=float, default=0.05)
+    parser.add_argument('--proj_every', type=int, default=25)
+    parser.add_argument('--node_sample', type=int, default=0)
+    parser.add_argument('--device', default=None)
+    parser.add_argument('--mode', choices=['attack', 'defense'], default='attack', help='Run attack pipeline or defense pipeline')
+    parser.add_argument('--clean', action='store_true', help='Remove old .pt and metrics files before running')
+    args = parser.parse_args()
+
+    if args.clean:
+        print('[clean] removing old .pt and metrics files...')
+        for f in glob.glob('*.pt') + glob.glob('*.json'):
+            try:
+                os.remove(f)
+                print(' removed', f)
+            except Exception as e:
+                print(' could not remove', f, e)
+
+    ds = Dataset(api_type='pyg', path='./data', name=args.dataset)
+    ds.to(torch.device(args.device) if args.device else (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')))
+
+    fp_cfg = FingerprintSpecLocal(num_graphs=args.num_graphs, num_nodes=args.num_nodes, edge_density=args.edge_density, proj_every=args.proj_every, node_sample=args.node_sample)
+    attack = GNNFingersAttack(ds, attack_node_fraction=0.3, fp_cfg=fp_cfg, joint_steps=args.joint_steps, device=args.device)
+
+    if args.mode == 'attack':
+        metrics = attack.attack()
+        print('\nSummary:')
+        print('ROC_AUC:', metrics['ROC_AUC'])
+        print('ARUC:', metrics['ARUC'])
+        print('Saved artifacts: univerifier.pt, fingerprints.pt, verification_metrics.json')
+    else:  # defense
+        res = attack.defense(method='default')
+        print('\nDefense result:')
+        print('quick metrics:', res.get('metrics'))
+        print('Saved defended model:', './defended_model.pt' if res.get('defense_model') is not None else 'none')
diff --git a/experimental_examples/data/Planetoid/Cora/processed/data.pt b/experimental_examples/data/Planetoid/Cora/processed/data.pt
new file mode 100644
index 0000000..5ded299
Binary files /dev/null and b/experimental_examples/data/Planetoid/Cora/processed/data.pt differ
diff --git a/experimental_examples/data/Planetoid/Cora/processed/pre_filter.pt b/experimental_examples/data/Planetoid/Cora/processed/pre_filter.pt
new file mode 100644
index 0000000..965c404
Binary files /dev/null and b/experimental_examples/data/Planetoid/Cora/processed/pre_filter.pt differ
diff --git a/experimental_examples/data/Planetoid/Cora/processed/pre_transform.pt
b/experimental_examples/data/Planetoid/Cora/processed/pre_transform.pt new file mode 100644 index 0000000..965c404 Binary files /dev/null and b/experimental_examples/data/Planetoid/Cora/processed/pre_transform.pt differ diff --git a/experimental_examples/data/Planetoid/Cora/raw/ind.cora.allx b/experimental_examples/data/Planetoid/Cora/raw/ind.cora.allx new file mode 100644 index 0000000..44d53b1 Binary files /dev/null and b/experimental_examples/data/Planetoid/Cora/raw/ind.cora.allx differ diff --git a/experimental_examples/data/Planetoid/Cora/raw/ind.cora.ally b/experimental_examples/data/Planetoid/Cora/raw/ind.cora.ally new file mode 100644 index 0000000..04fbd0b Binary files /dev/null and b/experimental_examples/data/Planetoid/Cora/raw/ind.cora.ally differ diff --git a/experimental_examples/data/Planetoid/Cora/raw/ind.cora.graph b/experimental_examples/data/Planetoid/Cora/raw/ind.cora.graph new file mode 100644 index 0000000..4d3bf85 Binary files /dev/null and b/experimental_examples/data/Planetoid/Cora/raw/ind.cora.graph differ diff --git a/experimental_examples/data/Planetoid/Cora/raw/ind.cora.test.index b/experimental_examples/data/Planetoid/Cora/raw/ind.cora.test.index new file mode 100644 index 0000000..ded8092 --- /dev/null +++ b/experimental_examples/data/Planetoid/Cora/raw/ind.cora.test.index @@ -0,0 +1,1000 @@ +2692 +2532 +2050 +1715 +2362 +2609 +2622 +1975 +2081 +1767 +2263 +1725 +2588 +2259 +2357 +1998 +2574 +2179 +2291 +2382 +1812 +1751 +2422 +1937 +2631 +2510 +2378 +2589 +2345 +1943 +1850 +2298 +1825 +2035 +2507 +2313 +1906 +1797 +2023 +2159 +2495 +1886 +2122 +2369 +2461 +1925 +2565 +1858 +2234 +2000 +1846 +2318 +1723 +2559 +2258 +1763 +1991 +1922 +2003 +2662 +2250 +2064 +2529 +1888 +2499 +2454 +2320 +2287 +2203 +2018 +2002 +2632 +2554 +2314 +2537 +1760 +2088 +2086 +2218 +2605 +1953 +2403 +1920 +2015 +2335 +2535 +1837 +2009 +1905 +2636 +1942 +2193 +2576 +2373 +1873 +2463 +2509 +1954 +2656 +2455 +2494 +2295 +2114 +2561 +2176 +2275 +2635 +2442 +2704 +2127 +2085 +2214 +2487 +1739 +2543 +1783 +2485 +2262 +2472 +2326 +1738 +2170 +2100 +2384 +2152 +2647 +2693 +2376 +1775 +1726 +2476 +2195 +1773 +1793 +2194 +2581 +1854 +2524 +1945 +1781 +1987 +2599 +1744 +2225 +2300 +1928 +2042 +2202 +1958 +1816 +1916 +2679 +2190 +1733 +2034 +2643 +2177 +1883 +1917 +1996 +2491 +2268 +2231 +2471 +1919 +1909 +2012 +2522 +1865 +2466 +2469 +2087 +2584 +2563 +1924 +2143 +1736 +1966 +2533 +2490 +2630 +1973 +2568 +1978 +2664 +2633 +2312 +2178 +1754 +2307 +2480 +1960 +1742 +1962 +2160 +2070 +2553 +2433 +1768 +2659 +2379 +2271 +1776 +2153 +1877 +2027 +2028 +2155 +2196 +2483 +2026 +2158 +2407 +1821 +2131 +2676 +2277 +2489 +2424 +1963 +1808 +1859 +2597 +2548 +2368 +1817 +2405 +2413 +2603 +2350 +2118 +2329 +1969 +2577 +2475 +2467 +2425 +1769 +2092 +2044 +2586 +2608 +1983 +2109 +2649 +1964 +2144 +1902 +2411 +2508 +2360 +1721 +2005 +2014 +2308 +2646 +1949 +1830 +2212 +2596 +1832 +1735 +1866 +2695 +1941 +2546 +2498 +2686 +2665 +1784 +2613 +1970 +2021 +2211 +2516 +2185 +2479 +2699 +2150 +1990 +2063 +2075 +1979 +2094 +1787 +2571 +2690 +1926 +2341 +2566 +1957 +1709 +1955 +2570 +2387 +1811 +2025 +2447 +2696 +2052 +2366 +1857 +2273 +2245 +2672 +2133 +2421 +1929 +2125 +2319 +2641 +2167 +2418 +1765 +1761 +1828 +2188 +1972 +1997 +2419 +2289 +2296 +2587 +2051 +2440 +2053 +2191 +1923 +2164 +1861 +2339 +2333 +2523 +2670 +2121 +1921 +1724 +2253 +2374 +1940 +2545 +2301 +2244 +2156 +1849 +2551 +2011 +2279 +2572 +1757 +2400 +2569 +2072 +2526 +2173 +2069 +2036 +1819 +1734 +1880 +2137 +2408 +2226 +2604 +1771 +2698 +2187 +2060 
+1756 +2201 +2066 +2439 +1844 +1772 +2383 +2398 +1708 +1992 +1959 +1794 +2426 +2702 +2444 +1944 +1829 +2660 +2497 +2607 +2343 +1730 +2624 +1790 +1935 +1967 +2401 +2255 +2355 +2348 +1931 +2183 +2161 +2701 +1948 +2501 +2192 +2404 +2209 +2331 +1810 +2363 +2334 +1887 +2393 +2557 +1719 +1732 +1986 +2037 +2056 +1867 +2126 +1932 +2117 +1807 +1801 +1743 +2041 +1843 +2388 +2221 +1833 +2677 +1778 +2661 +2306 +2394 +2106 +2430 +2371 +2606 +2353 +2269 +2317 +2645 +2372 +2550 +2043 +1968 +2165 +2310 +1985 +2446 +1982 +2377 +2207 +1818 +1913 +1766 +1722 +1894 +2020 +1881 +2621 +2409 +2261 +2458 +2096 +1712 +2594 +2293 +2048 +2359 +1839 +2392 +2254 +1911 +2101 +2367 +1889 +1753 +2555 +2246 +2264 +2010 +2336 +2651 +2017 +2140 +1842 +2019 +1890 +2525 +2134 +2492 +2652 +2040 +2145 +2575 +2166 +1999 +2434 +1711 +2276 +2450 +2389 +2669 +2595 +1814 +2039 +2502 +1896 +2168 +2344 +2637 +2031 +1977 +2380 +1936 +2047 +2460 +2102 +1745 +2650 +2046 +2514 +1980 +2352 +2113 +1713 +2058 +2558 +1718 +1864 +1876 +2338 +1879 +1891 +2186 +2451 +2181 +2638 +2644 +2103 +2591 +2266 +2468 +1869 +2582 +2674 +2361 +2462 +1748 +2215 +2615 +2236 +2248 +2493 +2342 +2449 +2274 +1824 +1852 +1870 +2441 +2356 +1835 +2694 +2602 +2685 +1893 +2544 +2536 +1994 +1853 +1838 +1786 +1930 +2539 +1892 +2265 +2618 +2486 +2583 +2061 +1796 +1806 +2084 +1933 +2095 +2136 +2078 +1884 +2438 +2286 +2138 +1750 +2184 +1799 +2278 +2410 +2642 +2435 +1956 +2399 +1774 +2129 +1898 +1823 +1938 +2299 +1862 +2420 +2673 +1984 +2204 +1717 +2074 +2213 +2436 +2297 +2592 +2667 +2703 +2511 +1779 +1782 +2625 +2365 +2315 +2381 +1788 +1714 +2302 +1927 +2325 +2506 +2169 +2328 +2629 +2128 +2655 +2282 +2073 +2395 +2247 +2521 +2260 +1868 +1988 +2324 +2705 +2541 +1731 +2681 +2707 +2465 +1785 +2149 +2045 +2505 +2611 +2217 +2180 +1904 +2453 +2484 +1871 +2309 +2349 +2482 +2004 +1965 +2406 +2162 +1805 +2654 +2007 +1947 +1981 +2112 +2141 +1720 +1758 +2080 +2330 +2030 +2432 +2089 +2547 +1820 +1815 +2675 +1840 +2658 +2370 +2251 +1908 +2029 +2068 +2513 +2549 +2267 +2580 +2327 +2351 +2111 +2022 +2321 +2614 +2252 +2104 +1822 +2552 +2243 +1798 +2396 +2663 +2564 +2148 +2562 +2684 +2001 +2151 +2706 +2240 +2474 +2303 +2634 +2680 +2055 +2090 +2503 +2347 +2402 +2238 +1950 +2054 +2016 +1872 +2233 +1710 +2032 +2540 +2628 +1795 +2616 +1903 +2531 +2567 +1946 +1897 +2222 +2227 +2627 +1856 +2464 +2241 +2481 +2130 +2311 +2083 +2223 +2284 +2235 +2097 +1752 +2515 +2527 +2385 +2189 +2283 +2182 +2079 +2375 +2174 +2437 +1993 +2517 +2443 +2224 +2648 +2171 +2290 +2542 +2038 +1855 +1831 +1759 +1848 +2445 +1827 +2429 +2205 +2598 +2657 +1728 +2065 +1918 +2427 +2573 +2620 +2292 +1777 +2008 +1875 +2288 +2256 +2033 +2470 +2585 +2610 +2082 +2230 +1915 +1847 +2337 +2512 +2386 +2006 +2653 +2346 +1951 +2110 +2639 +2520 +1939 +2683 +2139 +2220 +1910 +2237 +1900 +1836 +2197 +1716 +1860 +2077 +2519 +2538 +2323 +1914 +1971 +1845 +2132 +1802 +1907 +2640 +2496 +2281 +2198 +2416 +2285 +1755 +2431 +2071 +2249 +2123 +1727 +2459 +2304 +2199 +1791 +1809 +1780 +2210 +2417 +1874 +1878 +2116 +1961 +1863 +2579 +2477 +2228 +2332 +2578 +2457 +2024 +1934 +2316 +1841 +1764 +1737 +2322 +2239 +2294 +1729 +2488 +1974 +2473 +2098 +2612 +1834 +2340 +2423 +2175 +2280 +2617 +2208 +2560 +1741 +2600 +2059 +1747 +2242 +2700 +2232 +2057 +2147 +2682 +1792 +1826 +2120 +1895 +2364 +2163 +1851 +2391 +2414 +2452 +1803 +1989 +2623 +2200 +2528 +2415 +1804 +2146 +2619 +2687 +1762 +2172 +2270 +2678 +2593 +2448 +1882 +2257 +2500 +1899 +2478 +2412 +2107 +1746 +2428 +2115 +1800 +1901 +2397 +2530 +1912 +2108 +2206 +2091 +1740 +2219 +1976 +2099 +2142 +2671 
+2668
+2216
+2272
+2229
+2666
+2456
+2534
+2697
+2688
+2062
+2691
+2689
+2154
+2590
+2626
+2390
+1813
+2067
+1952
+2518
+2358
+1789
+2076
+2049
+2119
+2013
+2124
+2556
+2105
+2093
+1885
+2305
+2354
+2135
+2601
+1770
+1995
+2504
+1749
+2157
diff --git a/experimental_examples/data/Planetoid/Cora/raw/ind.cora.tx b/experimental_examples/data/Planetoid/Cora/raw/ind.cora.tx
new file mode 100644
index 0000000..6e856d7
Binary files /dev/null and b/experimental_examples/data/Planetoid/Cora/raw/ind.cora.tx differ
diff --git a/experimental_examples/data/Planetoid/Cora/raw/ind.cora.ty b/experimental_examples/data/Planetoid/Cora/raw/ind.cora.ty
new file mode 100644
index 0000000..da1734a
Binary files /dev/null and b/experimental_examples/data/Planetoid/Cora/raw/ind.cora.ty differ
diff --git a/experimental_examples/data/Planetoid/Cora/raw/ind.cora.x b/experimental_examples/data/Planetoid/Cora/raw/ind.cora.x
new file mode 100644
index 0000000..c4a91d0
Binary files /dev/null and b/experimental_examples/data/Planetoid/Cora/raw/ind.cora.x differ
diff --git a/experimental_examples/data/Planetoid/Cora/raw/ind.cora.y b/experimental_examples/data/Planetoid/Cora/raw/ind.cora.y
new file mode 100644
index 0000000..58e30ef
Binary files /dev/null and b/experimental_examples/data/Planetoid/Cora/raw/ind.cora.y differ
diff --git a/experimental_examples/dataset.py b/experimental_examples/dataset.py
new file mode 100644
index 0000000..26b9a0c
--- /dev/null
+++ b/experimental_examples/dataset.py
@@ -0,0 +1,39 @@
+# Dataset class - loads a Planetoid dataset via PyG
+
+import os
+from typing import Optional
+
+import torch
+from torch_geometric.datasets import Planetoid
+
+class Dataset(object):
+    def __init__(self, api_type: str = 'pyg', path: str = './data', name: str = 'Cora'):
+        assert api_type in {'dgl', 'pyg'}, 'API type must be dgl or pyg'
+        self.api_type = api_type
+        self.path = path
+        self.dataset_name = name
+
+        self.graph_dataset = None
+        self.graph_data = None
+
+        self.num_nodes = 0
+        self.num_features = 0
+        self.num_classes = 0
+
+        self._load()
+
+    def _load(self):
+        if self.api_type == 'pyg':
+            ds = Planetoid(root=os.path.join(self.path, 'Planetoid'), name=self.dataset_name)
+            self.graph_dataset = ds
+            self.graph_data = ds[0]
+            self.num_nodes = ds[0].num_nodes
+            self.num_features = ds.num_node_features
+            self.num_classes = ds.num_classes
+        else:
+            raise NotImplementedError('dgl api not implemented')
+
+    def to(self, device: Optional[torch.device] = None):
+        dev = device if device is not None else (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
+        self.graph_data = self.graph_data.to(dev)
+        return self
\ No newline at end of file
diff --git a/experimental_examples/fingerprints.py b/experimental_examples/fingerprints.py
new file mode 100644
index 0000000..9f90941
--- /dev/null
+++ b/experimental_examples/fingerprints.py
@@ -0,0 +1,64 @@
+# Fingerprints - learnable probe graphs and the Univerifier classifier
+
+import torch
+import torch.nn as nn
+
+class LearnableFingerprint(nn.Module):
+    def __init__(self, num_nodes, feat_dim, density=0.05, device='cpu'):
+        super().__init__()
+        self.num_nodes = num_nodes
+        self.feat = nn.Parameter(torch.randn(num_nodes, feat_dim, device=device) * 0.1)
+        # Edge logits start strongly "off" (sigmoid(-3) is about 0.05), so only
+        # the randomly seeded entries below act as initial edges.
+        self.adj_param = nn.Parameter(torch.randn(num_nodes, num_nodes, device=device) * 0.1 - 3.0)
+
+        with torch.no_grad():
+            mask = torch.rand(num_nodes, num_nodes, device=device) < density
+            mask.fill_diagonal_(False)
+            self.adj_param[mask] = 3.0
+            ap = (self.adj_param + self.adj_param.t()) / 2.0
+            self.adj_param.copy_(ap)
+
+        src, dst = torch.where(~torch.eye(num_nodes, dtype=torch.bool, device=device))
+        self.register_buffer('edge_index_all', torch.stack([src, dst], dim=0))
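+
+    # The probe graph is kept dense: every off-diagonal node pair is a candidate
+    # edge whose weight sigmoid(adj_param) is learned end-to-end; harden_topk()
+    # periodically projects the weights back to a sparse 0/1 topology.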
+    def current_edge_weight(self):
+        w = torch.sigmoid(self.adj_param)
+        ew = w[self.edge_index_all[0], self.edge_index_all[1]]
+        return ew
+
+    def harden_topk(self, edge_density: float):
+        # Keep only the top-k upper-triangular weights (k set by edge_density),
+        # binarize them, and write the result back as symmetric logits.
+        with torch.no_grad():
+            w = torch.sigmoid(self.adj_param)
+            iu = torch.triu_indices(self.num_nodes, self.num_nodes, offset=1, device=w.device)
+            wu = w[iu[0], iu[1]]
+            k = max(1, int(edge_density * wu.numel()))
+            _, idx = torch.topk(wu, k)
+            keep = torch.zeros_like(wu, dtype=torch.bool)
+            keep[idx] = True
+            bin_u = torch.zeros_like(wu)
+            bin_u[keep] = 1.0
+            w_new = torch.zeros_like(w)
+            w_new[iu[0], iu[1]] = bin_u
+            w_new = w_new + w_new.t()
+            eps = 1e-3
+            w_new.clamp_(0, 1)
+            p_new = torch.log((w_new + eps) / (1 - w_new + eps))
+            self.adj_param.copy_((p_new + p_new.t()) / 2.0)
+
+    def forward(self, model):
+        ew = self.current_edge_weight()
+        logits = model(self.feat, self.edge_index_all, edge_weight=ew)
+        return logits
+
+class Univerifier(nn.Module):
+    def __init__(self, in_dim, hidden=(128, 64, 32), p=0.1):
+        super().__init__()
+        dims = [in_dim, *hidden, 2]
+        layers = []
+        for i in range(len(dims) - 2):
+            layers += [nn.Linear(dims[i], dims[i + 1]), nn.LeakyReLU(), nn.Dropout(p)]
+        layers += [nn.Linear(dims[-2], dims[-1])]
+        self.net = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.net(x)
diff --git a/experimental_examples/models.py b/experimental_examples/models.py
new file mode 100644
index 0000000..a9a4b36
--- /dev/null
+++ b/experimental_examples/models.py
@@ -0,0 +1,34 @@
+# Models - GCN and SAGE backbones
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_geometric.nn import GCNConv, SAGEConv
+
+class SmallGCN(nn.Module):
+    def __init__(self, in_ch, hid_ch, out_ch, dropout=0.5):
+        super().__init__()
+        self.conv1 = GCNConv(in_ch, hid_ch, normalize=True)
+        self.conv2 = GCNConv(hid_ch, out_ch, normalize=True)
+        self.dropout = dropout
+
+    def forward(self, x, edge_index, edge_weight=None):
+        x = self.conv1(x, edge_index, edge_weight=edge_weight)
+        x = F.relu(x)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = self.conv2(x, edge_index, edge_weight=edge_weight)
+        return x
+
+class SmallSAGE(nn.Module):
+    def __init__(self, in_ch, hid_ch, out_ch, dropout=0.5):
+        super().__init__()
+        self.conv1 = SAGEConv(in_ch, hid_ch)
+        self.conv2 = SAGEConv(hid_ch, out_ch)
+        self.dropout = dropout
+
+    def forward(self, x, edge_index, edge_weight=None):
+        # edge_weight is accepted for interface parity with SmallGCN but is
+        # ignored: SAGEConv does not take per-edge weights.
+        x = self.conv1(x, edge_index)
+        x = F.relu(x)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = self.conv2(x, edge_index)
+        return x
\ No newline at end of file
diff --git a/experimental_examples/output.py b/experimental_examples/output.py
new file mode 100644
index 0000000..f6c3c6e
--- /dev/null
+++ b/experimental_examples/output.py
@@ -0,0 +1,4 @@
+import json
+with open("verification_metrics.json") as f:
+    metrics = json.load(f)
+print(metrics["ROC_AUC"], metrics["ARUC"])
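diff --git a/experimental_examples/verify_suspect.py b/experimental_examples/verify_suspect.py
new file mode 100644
--- /dev/null
+++ b/experimental_examples/verify_suspect.py
@@ -0,0 +1,32 @@
+# verify_suspect.py - hypothetical helper, not referenced by the other scripts:
+# a minimal sketch of how the artifacts saved by attack() could be reused to
+# score a single suspect checkpoint. It assumes fingerprints.pt/univerifier.pt
+# come from a prior run with the same fingerprint configuration, and that the
+# suspect checkpoint (here target_main.pt) is a SmallGCN.
+import torch
+
+from dataset import Dataset
+from fingerprints import LearnableFingerprint, Univerifier
+from models import SmallGCN
+
+ds = Dataset(api_type='pyg', path='./data', name='Cora')
+suspect = SmallGCN(ds.num_features, 64, ds.num_classes)
+suspect.load_state_dict(torch.load('./target_main.pt', map_location='cpu'))
+suspect.eval()
+
+blobs = torch.load('./fingerprints.pt', map_location='cpu')
+with torch.no_grad():
+    pieces = []
+    for i in range(len(blobs)):
+        feat, adj = blobs[f'fp_{i}']
+        fp = LearnableFingerprint(feat.shape[0], feat.shape[1])
+        fp.feat.data.copy_(feat)
+        fp.adj_param.data.copy_(adj)
+        # mean class-probability signature of the suspect on this probe graph
+        pieces.append(fp(suspect).softmax(dim=-1).mean(dim=0))
+    sig = torch.cat(pieces).unsqueeze(0).float()
+
+    uv = Univerifier(in_dim=sig.numel())
+    uv.load_state_dict(torch.load('./univerifier.pt', map_location='cpu'))
+    uv.eval()
+    print('P(positive) =', uv(sig).softmax(dim=1)[0, 1].item())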