diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..78d7ba3 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,91 @@ +#Fingerprinting Graph Neural Networks + +Steps + +1. Create virtual env. Activate it. +2. Install requirements + ```bash + pip install -r requirements.txt + ``` +3. Create folders + + ```bash + mkdir -p data models fingerprints plots + ``` + +### GNN task types: Node Classification, Link Prediction, Graph Classification + +For Node Classification (NC): \ +  Folder name: node_class/ \ +  Filename Suffix: \*\_nc.\* + +For Link Prediction (LP): \ +  Folder name: link_pred/ \ +  Filename Suffix: \*\_lp.\* + +For Graph Classification (GC): \ +  Folder name: graph_class/ \ +  Filename Suffix: \*\_gc.\* + + Example: \ +  `bash + python node_class/train_nc.py ` \ +  `bash + python link_pred/train_lp.py ` \ +  `bash + python graph_class/train_gc.py ` + +### For node classification task on Cora dataset (GCN arch) + +```bash +python node_class/train_nc.py +``` + +```bash +python node_class/fine_tune_pirate_nc.py +``` + +```bash +python node_class/distill_students_nc.py +``` + +```bash +python node_class/train_unrelated_nc.py +``` + +```bash +python node_class/fingerprint_generator_nc.py +``` + +```bash +python node_class/generate_univerifier_dataset_nc.py +``` + +```bash +python train_univerifier.py --dataset fingerprints/univerifier_dataset_nc.pt --fingerprints_path fingerprints/fingerprints_nc.pt --out fingerprints/univerifier_nc.pt +``` + +```bash +python node_class/eval_verifier_nc.py +``` + +Follow similar approach as Node Classification for Link Prediction on Citeseer dataset (GCN arch) and Graph Classification on ENZYMES dataset (Graphsage arch). 
+ +Change argument paths for LP and GC for training univerifier + +```bash +python train_univerifier.py --dataset fingerprints/univerifier_dataset_lp.pt --fingerprints_path fingerprints/fingerprints_lp.pt --out fingerprints/univerifier_lp.pt +``` + +```bash +python train_univerifier.py --dataset fingerprints/univerifier_dataset_gc.pt --fingerprints_path fingerprints/fingerprints_gc.pt --out fingerprints/univerifier_gc.pt +``` + +To evaluate suspect GNNs for NC tasks + ```bash + python node_class/make_suspect_neg_nc.py + ``` + ```bash + python node_class/score_suspect_nc.py --suspect_pt models/suspects/neg_nc_seed9999.pt --suspect_meta models/suspects/neg_nc_seed9999.json + ``` + diff --git a/examples/graph_class/distill_students_gc.py b/examples/graph_class/distill_students_gc.py new file mode 100644 index 0000000..b6632ce --- /dev/null +++ b/examples/graph_class/distill_students_gc.py @@ -0,0 +1,110 @@ +# Positive (pirated) models for GRAPH CLASSIFICATION on ENZYMES via DISTILLATION. +# Teacher: trained GC model loaded from target_model_gc.pt +# Students: GraphSAGE via get_model + +import argparse, json, random, torch +from pathlib import Path + +import torch.nn.functional as F +from torch_geometric.datasets import TUDataset +from torch_geometric.loader import DataLoader +from torch_geometric.transforms import NormalizeFeatures + +from gsage_gc import get_model + + +def set_seed(s: int): + random.seed(s); torch.manual_seed(s); torch.cuda.manual_seed_all(s) + + +def kd_loss(student_logits, teacher_logits): + return F.mse_loss(student_logits, teacher_logits) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--meta_path', default='models/target_meta_gc.json') + ap.add_argument('--target_path', default='models/target_model_gc.pt') + ap.add_argument('--archs', default='gsage') + ap.add_argument('--epochs', type=int, default=10) + ap.add_argument('--lr', type=float, default=0.01) + ap.add_argument('--wd', type=float, default=5e-4) + 
ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--count_per_arch', type=int, default=100) + ap.add_argument('--out_dir', type=str, default='models/positives') + ap.add_argument('--batch_size', type=int, default=64) + ap.add_argument('--student_hidden', type=int, default=64) + ap.add_argument('--student_layers', type=int, default=3) + ap.add_argument('--student_dropout', type=float, default=0.5) + args = ap.parse_args() + + set_seed(args.seed) + Path(args.out_dir).mkdir(parents=True, exist_ok=True) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + dataset = TUDataset(root='data/ENZYMES', name='ENZYMES', + use_node_attr=True, transform=NormalizeFeatures()) + in_dim = dataset.num_features + num_classes = dataset.num_classes + loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) + + # Teacher GC model + with open(args.meta_path, 'r') as f: + meta = json.load(f) + arch_t = meta.get('arch', 'gsage') + hidden_t = meta.get('hidden', 64) + layers_t = meta.get('layers', 3) + drop_t = meta.get('dropout', 0.5) + + teacher = get_model(arch_t, in_dim, hidden_t, num_classes, + num_layers=layers_t, dropout=drop_t, pool="mean").to(device) + teacher.load_state_dict(torch.load(args.target_path, map_location='cpu')) + teacher.eval() + + archs = [a.strip() for a in args.archs.split(',') if a.strip()] + saved = [] + + for arch in archs: + for i in range(args.count_per_arch): + # fresh student + student = get_model(arch, in_dim, args.student_hidden, num_classes, + num_layers=args.student_layers, + dropout=args.student_dropout, pool="mean").to(device) + opt = torch.optim.Adam(student.parameters(), lr=args.lr, weight_decay=args.wd) + + for _ in range(args.epochs): + student.train() + for batch in loader: + batch = batch.to(device) + with torch.no_grad(): + t_logits = teacher(batch.x, batch.edge_index, batch=batch.batch) # [B, C] + s_logits = student(batch.x, batch.edge_index, batch=batch.batch) # [B, C] + loss = 
kd_loss(s_logits, t_logits) + opt.zero_grad(); loss.backward(); opt.step() + + # save student + out_pt = f'{args.out_dir}/distill_gc_{arch}_{i:03d}.pt' + torch.save(student.state_dict(), out_pt) + with open(out_pt.replace('.pt', '.json'), 'w') as f: + json.dump({ + "task": "graph_classification", + "dataset": "ENZYMES", + "arch": arch, + "hidden": args.student_hidden, + "layers": args.student_layers, + "dropout": args.student_dropout, + "pos_kind": "distill", + "teacher_arch": arch_t, + "teacher_hidden": hidden_t, + "teacher_layers": layers_t, + "teacher_dropout": drop_t + }, f, indent=2) + + saved.append(out_pt) + print(f"[distill-gc] saved {out_pt}") + + print(f"Saved {len(saved)} distilled GC positives.") + + +if __name__ == '__main__': + main() diff --git a/examples/graph_class/eval_verifier_gc.py b/examples/graph_class/eval_verifier_gc.py new file mode 100644 index 0000000..1632771 --- /dev/null +++ b/examples/graph_class/eval_verifier_gc.py @@ -0,0 +1,268 @@ +""" +Evaluate a trained Univerifier on GRAPH CLASSIFICATION (ENZYMES) positives ({target ∪ F+}) +and negatives (F−) using saved GC fingerprints. Produces Robustness/Uniqueness, ARUC, Mean Test Accuracy, KL Divergence. 
+""" + +import argparse, glob, json, os, torch +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path + +from torch_geometric.datasets import TUDataset +from torch_geometric.transforms import NormalizeFeatures +from torch_geometric.utils import dense_to_sparse, to_undirected +import torch.nn.functional as F + +from gsage_gc import get_model # GraphSAGE GC with pooling + +import torch.nn as nn + + +class FPVerifier(nn.Module): + def __init__(self, in_dim: int): + super().__init__() + self.net = nn.Sequential( + nn.Linear(in_dim, 128), + nn.LeakyReLU(), + nn.Linear(128, 64), + nn.LeakyReLU(), + nn.Linear(64, 32), + nn.LeakyReLU(), + nn.Linear(32, 1), + nn.Sigmoid(), + ) + + def forward(self, x): + return self.net(x) + + +def list_paths_from_globs(globs_str): + globs = [g.strip() for g in globs_str.split(",") if g.strip()] + paths = [] + for g in globs: + paths.extend(glob.glob(g)) + return sorted(paths) + + +def load_model_from_pt(pt_path, in_dim, num_classes): + meta_path = pt_path.replace(".pt", ".json") + j = json.load(open(meta_path, "r")) + m = get_model( + j.get("arch", "gsage"), + in_dim, + j.get("hidden", 64), + num_classes, + num_layers=j.get("layers", 3), + dropout=j.get("dropout", 0.5), + pool="mean", + ) + m.load_state_dict(torch.load(pt_path, map_location="cpu")) + m.eval() + return m + + +# GC fingerprint forward: model -> graph logits +@torch.no_grad() +def forward_on_fp(model, fp): + X = fp["X"] + A = fp["A"] + n = X.size(0) + + A_bin = (A > 0.5).float() + A_sym = torch.maximum(A_bin, A_bin.t()) + edge_index = dense_to_sparse(A_sym)[0] + if edge_index.numel() == 0: + idx = torch.arange(n, dtype=torch.long) + edge_index = torch.stack([idx, (idx + 1) % n], dim=0) + edge_index = to_undirected(edge_index) + + batch = X.new_zeros(n, dtype=torch.long) + logits = model(X, edge_index, batch=batch) + return logits.squeeze(0) + + +@torch.no_grad() +def concat_for_model(model, fps): + parts = [forward_on_fp(model, fp) for fp in fps] + 
return torch.cat(parts, dim=0) + +def softmax_logits(x): + return F.softmax(x, dim=-1) + +def sym_kl(p, q, eps=1e-8): + p = p.clamp(min=eps); q = q.clamp(min=eps) + kl1 = (p * (p.log() - q.log())).sum(dim=-1) + kl2 = (q * (q.log() - p.log())).sum(dim=-1) + return 0.5 * (kl1 + kl2) + +@torch.no_grad() +def model_gc_kl_to_target(suspect, target, fps): + """ + Average symmetric KL over fingerprints (graph-level). + """ + vals = [] + for fp in fps: + t = softmax_logits(forward_on_fp(target, fp)).unsqueeze(0) # [1,C] + s = softmax_logits(forward_on_fp(suspect, fp)).unsqueeze(0) # [1,C] + d = sym_kl(s, t) # [1] + vals.append(float(d.item())) + return float(np.mean(vals)) + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--fingerprints_path', type=str, default='fingerprints/fingerprints_gc.pt') + ap.add_argument('--verifier_path', type=str, default='fingerprints/univerifier_gc.pt') + ap.add_argument('--target_path', type=str, default='models/target_model_gc.pt') + ap.add_argument('--target_meta', type=str, default='models/target_meta_gc.json') + ap.add_argument('--positives_glob', type=str, + default='models/positives/gc_ftpr_*.pt,models/positives/distill_gc_*.pt') + ap.add_argument('--negatives_glob', type=str, default='models/negatives/negative_gc_*.pt') + ap.add_argument('--out_plot', type=str, default='plots/enzymes_gc_aruc.png') + ap.add_argument('--out_plot_kl', type=str, default='plots/enzymes_gc_kl.png') + + ap.add_argument('--save_csv', type=str, default='', + help='Optional: path to save thresholds/robustness/uniqueness CSV') + args = ap.parse_args() + + # Dataset dims + ds = TUDataset(root="data/ENZYMES", name="ENZYMES", + use_node_attr=True, transform=NormalizeFeatures()) + in_dim = ds.num_features + num_classes = ds.num_classes + + # Load fingerprints (list of tiny graph specs) + pack = torch.load(args.fingerprints_path, map_location="cpu") + fps = pack["fingerprints"] + ver_in_dim_saved = int(pack.get("ver_in_dim", 0)) + + # Load models 
(target + positives + negatives) + tmeta = json.load(open(args.target_meta, "r")) + target = get_model( + tmeta.get("arch", "gsage"), in_dim, tmeta.get("hidden", 64), num_classes, + num_layers=tmeta.get("layers", 3), dropout=tmeta.get("dropout", 0.5), pool="mean" + ) + target.load_state_dict(torch.load(args.target_path, map_location="cpu")) + target.eval() + + pos_paths = list_paths_from_globs(args.positives_glob) + neg_paths = sorted(glob.glob(args.negatives_glob)) + + models_pos = [target] + [load_model_from_pt(p, in_dim, num_classes) for p in pos_paths] + models_neg = [load_model_from_pt(n, in_dim, num_classes) for n in neg_paths] + + # Infer verifier input dim from a probe concat + z0 = concat_for_model(models_pos[0], fps) + D = z0.numel() + if ver_in_dim_saved and ver_in_dim_saved != D: + raise RuntimeError(f"Verifier input mismatch: D={D} vs ver_in_dim_saved={ver_in_dim_saved}") + + # Load verifier + V = FPVerifier(D) + ver_path = Path(args.verifier_path) + if ver_path.exists(): + V.load_state_dict(torch.load(str(ver_path), map_location='cpu')) + print(f"Loaded verifier from {ver_path}") + elif "verifier" in pack: + V.load_state_dict(pack["verifier"]) + print("Loaded verifier from fingerprints pack.") + else: + raise FileNotFoundError( + f"No verifier found at {args.verifier_path} and no 'verifier' key in {args.fingerprints_path}" + ) + V.eval() + + # Collect scores + with torch.no_grad(): + pos_scores = [] + for m in models_pos: + z = concat_for_model(m, fps).unsqueeze(0) + pos_scores.append(float(V(z))) + neg_scores = [] + for m in models_neg: + z = concat_for_model(m, fps).unsqueeze(0) + neg_scores.append(float(V(z))) + + pos_scores = np.array(pos_scores) + neg_scores = np.array(neg_scores) + + ts = np.linspace(0.0, 1.0, 201) + robustness = np.array([(pos_scores >= t).mean() for t in ts]) # TPR on positives + uniqueness = np.array([(neg_scores < t).mean() for t in ts]) # TNR on negatives + overlap = np.minimum(robustness, uniqueness) + # Accuracy at each 
threshold + Npos, Nneg = len(pos_scores), len(neg_scores) + acc_curve = np.array([((pos_scores >= t).sum() + (neg_scores < t).sum()) / (Npos + Nneg) + for t in ts]) + mean_test_acc = float(acc_curve.mean()) + + aruc = np.trapz(overlap, ts) + + # Best threshold (maximize min(robustness, uniqueness)) + idx_best = int(np.argmax(overlap)) + t_best = float(ts[idx_best]) + rob_best = float(robustness[idx_best]) + uniq_best = float(uniqueness[idx_best]) + acc_best = 0.5 * (rob_best + uniq_best) + + print(f"Mean Test Accuracy (avg over thresholds) = {mean_test_acc:.4f}") + print(f"Models: +{len(models_pos)} | -{len(models_neg)} | D={D}") + print(f"ARUC = {aruc:.4f}") + print(f"Best threshold = {t_best:.3f} | Robustness={rob_best:.3f} | Uniqueness={uniq_best:.3f} | Acc={acc_best:.3f}") + + if args.save_csv: + import csv + Path(os.path.dirname(args.save_csv)).mkdir(parents=True, exist_ok=True) + with open(args.save_csv, 'w', newline='') as f: + w = csv.writer(f) + w.writerow(['threshold', 'robustness', 'uniqueness', 'min_curve', 'accuracy']) + for t, r, u, s, a in zip(ts, robustness, uniqueness, overlap, acc_curve): + w.writerow([f"{t:.5f}", f"{r:.6f}", f"{u:.6f}", f"{s:.6f}", f"{a:.6f}"]) + print(f"Saved CSV to {args.save_csv}") + + # ARUC Plot + os.makedirs(os.path.dirname(args.out_plot), exist_ok=True) + fig, ax = plt.subplots(figsize=(7.5, 4.8), dpi=160) + ax.set_title(f"ENZYMES graph-classification • ARUC={aruc:.3f}", fontsize=14) + ax.grid(True, which='both', linestyle=':', linewidth=0.8, alpha=0.6) + ax.plot(ts, robustness, color="#ff0000", linewidth=2.0, label="Robustness (TPR)") + ax.plot(ts, uniqueness, color="#0000ff", linestyle="--", linewidth=2.0, label="Uniqueness (TNR)") + overlap = np.minimum(robustness, uniqueness) + ax.fill_between(ts, overlap, color="#bbbbbb", alpha=0.25, label="Overlap (ARUC region)") + + # best-threshold vertical line + # ax.axvline(t_best, color="0.4", linewidth=2.0, alpha=0.6) + + ax.set_xlabel("Threshold (τ)", fontsize=12) + 
ax.set_ylabel("Score", fontsize=12) + ax.set_xlim(0.0, 1.0) + ax.set_ylim(0.0, 1.0) + ax.tick_params(labelsize=11) + + leg = ax.legend(loc="lower left", frameon=True, framealpha=0.85, + facecolor="white", edgecolor="0.8") + + plt.tight_layout() + plt.savefig(args.out_plot, bbox_inches="tight") + print(f"Saved plot to {args.out_plot}") + + # KL divergence Plot + pos_divs = [model_gc_kl_to_target(m, target, fps) for m in models_pos[1:]] # exclude target itself + neg_divs = [model_gc_kl_to_target(m, target, fps) for m in models_neg] + pos_divs = np.array(pos_divs); neg_divs = np.array(neg_divs) + print(f"[KL][GC] F+ mean±std = {pos_divs.mean():.4f}±{pos_divs.std():.4f} | " + f"F- mean±std = {neg_divs.mean():.4f}±{neg_divs.std():.4f}") + + os.makedirs(os.path.dirname(args.out_plot_kl), exist_ok=True) + plt.figure(figsize=(4.8, 3.2), dpi=160) + bins = 30 + plt.hist(pos_divs, bins=bins, density=True, alpha=0.35, color="r", label="Surrogate GNN") + plt.hist(neg_divs, bins=bins, density=True, alpha=0.35, color="b", label="Irrelevant GNN") + plt.title("Graph Classification") + plt.xlabel("KL Divergence"); plt.ylabel("Density") + plt.legend() + plt.tight_layout() + plt.savefig(args.out_plot_kl, bbox_inches="tight") + print(f"Saved KL plot to {args.out_plot_kl}") + +if __name__ == "__main__": + main() diff --git a/examples/graph_class/fine_tune_pirate_gc.py b/examples/graph_class/fine_tune_pirate_gc.py new file mode 100644 index 0000000..6fef77c --- /dev/null +++ b/examples/graph_class/fine_tune_pirate_gc.py @@ -0,0 +1,203 @@ +# Create positive (pirated) GC models on ENZYMES by fine-tuning / partial-retraining +# a trained target GraphSAGE GC model. 
+ +import argparse, json, random, copy +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.datasets import TUDataset +from torch_geometric.loader import DataLoader +from torch_geometric.transforms import NormalizeFeatures + +from gsage_gc import get_model + + +def set_seed(seed: int): + random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed) + + +def split_indices(n, val_ratio=0.1, test_ratio=0.1, seed=0): + g = torch.Generator().manual_seed(seed) + perm = torch.randperm(n, generator=g) + n_val = int(round(val_ratio * n)) + n_test = int(round(test_ratio * n)) + n_train = n - n_val - n_test + idx_tr = perm[:n_train].tolist() + idx_va = perm[n_train:n_train+n_val].tolist() + idx_te = perm[n_train+n_val:].tolist() + return idx_tr, idx_va, idx_te + + +def train_one_epoch(model, loader, optimizer, device): + model.train() + total_loss, total_graphs = 0.0, 0 + for batch in loader: + batch = batch.to(device) + optimizer.zero_grad() + out = model(batch.x, batch.edge_index, batch=batch.batch) + loss = F.cross_entropy(out, batch.y) + loss.backward(); optimizer.step() + total_loss += float(loss.item()) * batch.num_graphs + total_graphs += batch.num_graphs + return total_loss / max(1, total_graphs) + + +@torch.no_grad() +def evaluate(model, loader, device): + model.eval() + total, correct, total_loss = 0, 0, 0.0 + for batch in loader: + batch = batch.to(device) + out = model(batch.x, batch.edge_index, batch=batch.batch) + loss = F.cross_entropy(out, batch.y) + pred = out.argmax(dim=-1) + correct += int((pred == batch.y).sum()) + total += batch.num_graphs + total_loss += float(loss.item()) * batch.num_graphs + acc = correct / max(1, total) + return acc, (total_loss / max(1, total)) + + +def reinit_classifier(model: nn.Module): + if not hasattr(model, "cls"): + return + m = model.cls + if hasattr(m, "reset_parameters"): + try: + m.reset_parameters(); return + except Exception: + pass + for 
mod in m.modules(): + if isinstance(mod, nn.Linear): + nn.init.xavier_uniform_(mod.weight) + if mod.bias is not None: + nn.init.zeros_(mod.bias) + + +def reinit_all(model: nn.Module): + for mod in model.modules(): + if hasattr(mod, "reset_parameters"): + try: + mod.reset_parameters() + except Exception: + pass + + +def freeze_all(model: nn.Module): + for p in model.parameters(): + p.requires_grad = False + + +def unfreeze_classifier(model: nn.Module): + if hasattr(model, "cls"): + for p in model.cls.parameters(): + p.requires_grad = True + + +def unfreeze_all(model: nn.Module): + for p in model.parameters(): + p.requires_grad = True + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--target_path', type=str, default='models/target_model_gc.pt') + ap.add_argument('--meta_path', type=str, default='models/target_meta_gc.json') + ap.add_argument('--epochs', type=int, default=10) # paper uses ~10 for FT/PR + ap.add_argument('--lr', type=float, default=0.01) + ap.add_argument('--wd', type=float, default=5e-4) + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--num_variants', type=int, default=100) # round-robin across 4 kinds + ap.add_argument('--batch_size', type=int, default=64) + ap.add_argument('--val_ratio', type=float, default=0.1) + ap.add_argument('--test_ratio', type=float, default=0.1) + ap.add_argument('--out_dir', type=str, default='models/positives') + args = ap.parse_args() + + set_seed(args.seed) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + dataset = TUDataset(root='data/ENZYMES', name='ENZYMES', + use_node_attr=True, transform=NormalizeFeatures()) + in_dim = dataset.num_features + num_classes = dataset.num_classes + n = len(dataset) + idx_tr, idx_va, idx_te = split_indices(n, args.val_ratio, args.test_ratio, seed=args.seed) + train_loader = DataLoader(dataset[idx_tr], batch_size=args.batch_size, shuffle=True) + val_loader = DataLoader(dataset[idx_va], batch_size=args.batch_size, 
shuffle=False) + test_loader = DataLoader(dataset[idx_te], batch_size=args.batch_size, shuffle=False) + + with open(args.meta_path, 'r') as f: + meta = json.load(f) + arch = meta.get("arch", "gsage") + hidden = meta.get("hidden", 64) + layers = meta.get("layers", 3) + dropout= meta.get("dropout", 0.5) + + target = get_model(arch, in_dim, hidden, num_classes, + num_layers=layers, dropout=dropout, pool="mean").to(device) + target.load_state_dict(torch.load(args.target_path, map_location='cpu')) + target.eval() + + kinds = ["ft_last", "ft_all", "pr_last", "pr_all"] + Path(args.out_dir).mkdir(parents=True, exist_ok=True) + + saved = [] + for i in range(args.num_variants): + kind = kinds[i % 4] + + model = get_model(arch, in_dim, hidden, num_classes, + num_layers=layers, dropout=dropout, pool="mean").to(device) + model.load_state_dict(copy.deepcopy(target.state_dict())) + + if kind == "pr_last": + reinit_classifier(model) + elif kind == "pr_all": + reinit_all(model) + + if kind in ("ft_last", "pr_last"): + freeze_all(model); unfreeze_classifier(model) + else: + unfreeze_all(model) + + optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), + lr=args.lr, weight_decay=args.wd) + + best_val, best_state = -1.0, None + for _ in range(args.epochs): + _ = train_one_epoch(model, train_loader, optimizer, device) + val_acc, _ = evaluate(model, val_loader, device) + if val_acc > best_val: + best_val = val_acc + best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()} + + if best_state is not None: + model.load_state_dict(best_state) + + test_acc, _ = evaluate(model, test_loader, device) + + out_path = f"{args.out_dir}/gc_ftpr_{i:03d}.pt" + meta_out = { + "task": "graph_classification", + "dataset": "ENZYMES", + "arch": arch, + "hidden": hidden, + "layers": layers, + "dropout": dropout, + "pos_kind": kind, + "val_acc": float(best_val), + "test_acc": float(test_acc), + } + torch.save(model.state_dict(), out_path) + with 
open(out_path.replace('.pt', '.json'), 'w') as f: + json.dump(meta_out, f, indent=2) + saved.append(out_path) + print(f"[{kind}] saved {out_path} val_acc={best_val:.4f} test_acc={test_acc:.4f}") + + print(f"Total GC FT/PR positives saved: {len(saved)}") + + +if __name__ == '__main__': + main() diff --git a/examples/graph_class/fingerprint_generator_gc.py b/examples/graph_class/fingerprint_generator_gc.py new file mode 100644 index 0000000..6a193e7 --- /dev/null +++ b/examples/graph_class/fingerprint_generator_gc.py @@ -0,0 +1,246 @@ +# Fingerprint generation for GRAPH CLASSIFICATION on ENZYMES. + +import argparse, glob, json, random, time, torch +import torch.nn as nn +import torch.nn.functional as F +from pathlib import Path +from typing import List +from torch_geometric.datasets import TUDataset +from torch_geometric.transforms import NormalizeFeatures +from torch_geometric.utils import dense_to_sparse, to_undirected + +from gsage_gc import get_model + + +def set_seed(s: int): + random.seed(s); torch.manual_seed(s) + + +def load_meta(path): + with open(path, 'r') as f: + return json.load(f) + + +def list_paths_from_globs(globs: List[str]) -> List[str]: + out = [] + for g in globs: + out.extend(glob.glob(g)) + return sorted(out) + + +class FPVerifier(nn.Module): + def __init__(self, in_dim: int): + super().__init__() + self.net = nn.Sequential( + nn.Linear(in_dim, 128), + nn.LeakyReLU(), + nn.Linear(128, 64), + nn.LeakyReLU(), + nn.Linear(64, 32), + nn.LeakyReLU(), + nn.Linear(32, 1), + nn.Sigmoid() + ) + + def forward(self, x): + return self.net(x) + + +def load_model_from_pt(pt_path: str, in_dim: int, num_classes: int): + meta = json.load(open(pt_path.replace('.pt', '.json'), 'r')) + m = get_model(meta["arch"], in_dim, meta["hidden"], num_classes, + num_layers=meta.get("layers", 3), dropout=meta.get("dropout", 0.5), pool="mean") + m.load_state_dict(torch.load(pt_path, map_location='cpu')) + m.eval() + return m, meta + + +@torch.no_grad() +def 
forward_on_fp(model, fp): + X = fp["X"] + A = fp["A"] + n = X.size(0) + + # binarize & symmetrize adjacency -> edge_index + A_bin = (A > 0.5).float() + A_sym = torch.maximum(A_bin, A_bin.t()) + edge_index = dense_to_sparse(A_sym)[0] + if edge_index.numel() == 0: + idx = torch.arange(n, dtype=torch.long) + edge_index = torch.stack([idx, (idx + 1) % n], dim=0) + edge_index = to_undirected(edge_index) + + # single-graph batch vector of zeros + batch = X.new_zeros(n, dtype=torch.long) + logits = model(X, edge_index, batch=batch) + return logits.squeeze(0) + + +def concat_for_model(model, fingerprints): + vecs = [forward_on_fp(model, fp) for fp in fingerprints] + return torch.cat(vecs, dim=-1) + + +def compute_loss(models_pos, models_neg, fingerprints, V): + z_pos = [concat_for_model(m, fingerprints) for m in models_pos] + z_neg = [concat_for_model(m, fingerprints) for m in models_neg] + if not z_pos or not z_neg: + raise RuntimeError("Need both positive and negative models.") + Zp = torch.stack(z_pos) + Zn = torch.stack(z_neg) + + yp = V(Zp).clamp(1e-6, 1-1e-6) + yn = V(Zn).clamp(1e-6, 1-1e-6) + L = torch.log(yp).mean() + torch.log(1 - yn).mean() + return L, Zp, Zn + + +def feature_ascent_step(models_pos, models_neg, fingerprints, V, alpha=0.01): + # ascent on X only + for fp in fingerprints: + fp["X"].requires_grad_(True) + fp["A"].requires_grad_(False) + + L, _, _ = compute_loss(models_pos, models_neg, fingerprints, V) + grads = torch.autograd.grad( + L, [fp["X"] for fp in fingerprints], + retain_graph=False, create_graph=False, allow_unused=True + ) + with torch.no_grad(): + for fp, g in zip(fingerprints, grads): + if g is None: + g = torch.zeros_like(fp["X"]) + fp["X"].add_(alpha * g) + fp["X"].clamp_(-5.0, 5.0) + + +def edge_flip_candidates(A: torch.Tensor, budget: int): + n = A.size(0) + tri_i, tri_j = torch.triu_indices(n, n, offset=1) + scores = torch.abs(0.5 - A[tri_i, tri_j]) + order = torch.argsort(scores) + picks = order[:min(budget, len(order))] + return 
tri_i[picks], tri_j[picks] + + +def edge_flip_step(models_pos, models_neg, fingerprints, V, flip_k=8): + for fp in fingerprints: + A = fp["A"] + i_idx, j_idx = edge_flip_candidates(A, flip_k * 4) + + with torch.no_grad(): + base_L, _, _ = compute_loss(models_pos, models_neg, fingerprints, V) + + gains = [] + for i, j in zip(i_idx.tolist(), j_idx.tolist()): + with torch.no_grad(): + old = float(A[i, j]) + new = 1.0 - old + A[i, j] = new; A[j, i] = new + L_try, _, _ = compute_loss(models_pos, models_neg, fingerprints, V) + gains.append((float(L_try - base_L), i, j, old)) + A[i, j] = old; A[j, i] = old + + gains.sort(key=lambda x: x[0], reverse=True) + with torch.no_grad(): + for g, i, j, old in gains[:flip_k]: + A[i, j] = 1.0 - old; A[j, i] = 1.0 - old + A.clamp_(0.0, 1.0) + + +def train_verifier_step(models_pos, models_neg, fingerprints, V, opt): + L, Zp, Zn = compute_loss(models_pos, models_neg, fingerprints, V) + loss = -L # maximize L + opt.zero_grad(); loss.backward(); opt.step() + with torch.no_grad(): + yp = (V(Zp) >= 0.5).float().mean().item() + yn = (V(Zn) < 0.5).float().mean().item() + acc = 0.5 * (yp + yn) + return float(L.item()), acc + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--target_path', default='models/target_model_gc.pt') + ap.add_argument('--target_meta', default='models/target_meta_gc.json') + ap.add_argument('--positives_glob', default='models/positives/gc_ftpr_*.pt,models/positives/distill_gc_*.pt') + ap.add_argument('--negatives_glob', default='models/negatives/negative_gc_*.pt') + ap.add_argument('--out', default='fingerprints/fingerprints_gc.pt') + + ap.add_argument('--P', type=int, default=64) # number of fingerprints (graphs) + ap.add_argument('--n', type=int, default=32) # nodes per fingerprint graph + ap.add_argument('--iters', type=int, default=1000) # alternations + ap.add_argument('--e1', type=int, default=1) # fingerprint updates per alternation + ap.add_argument('--e2', type=int, default=1) # verifier 
updates per alternation + ap.add_argument('--alpha_x', type=float, default=0.01) + ap.add_argument('--flip_k', type=int, default=8) # edge flips per fp per step + ap.add_argument('--verifier_lr', type=float, default=1e-3) + ap.add_argument('--seed', type=int, default=0) + args = ap.parse_args() + + t0 = time.time() + set_seed(args.seed) + Path('fingerprints').mkdir(parents=True, exist_ok=True) + + # Dataset dims for model reconstruction + ds = TUDataset(root='data/ENZYMES', name='ENZYMES', + use_node_attr=True, transform=NormalizeFeatures()) + in_dim = ds.num_features + num_classes = ds.num_classes + + meta_t = load_meta(args.target_meta) + target = get_model(meta_t.get("arch", "gsage"), in_dim, meta_t.get("hidden", 64), num_classes, + num_layers=meta_t.get("layers", 3), dropout=meta_t.get("dropout", 0.5), pool="mean") + target.load_state_dict(torch.load(args.target_path, map_location='cpu')) + target.eval() + + pos_paths = list_paths_from_globs([g.strip() for g in args.positives_glob.split(',') if g.strip()]) + neg_paths = sorted(glob.glob(args.negatives_glob)) + + models_pos = [target] + [load_model_from_pt(p, in_dim, num_classes)[0] for p in pos_paths] + models_neg = [load_model_from_pt(npath, in_dim, num_classes)[0] for npath in neg_paths] + + print(f"[loaded] positives={len(models_pos)} (incl. 
target) | negatives={len(models_neg)}") + + # Initialize fingerprints: small random X, A near 0.5, symmetric + fingerprints = [] + for _ in range(args.P): + X = torch.randn(args.n, in_dim) * 0.1 + A = torch.rand(args.n, args.n) * 0.2 + 0.4 + A = torch.triu(A, diagonal=1); A = A + A.t() + torch.diagonal(A).zero_() + fingerprints.append({"X": X, "A": A}) + + # Univerifier + ver_in_dim = args.P * num_classes + V = FPVerifier(ver_in_dim) + optV = torch.optim.Adam(V.parameters(), lr=args.verifier_lr) + + flag = 0 + for it in range(1, args.iters + 1): + if flag == 0: + for _ in range(args.e1): + feature_ascent_step(models_pos, models_neg, fingerprints, V, alpha=args.alpha_x) + edge_flip_step(models_pos, models_neg, fingerprints, V, flip_k=args.flip_k) + flag = 1 + else: + diag_acc = None + for _ in range(args.e2): + Lval, acc = train_verifier_step(models_pos, models_neg, fingerprints, V, optV) + diag_acc = acc + flag = 0 + if it % 10 == 0 and 'diag_acc' in locals() and diag_acc is not None: + print(f"[Iter {it}] verifier acc={diag_acc:.3f}") + + # Save + clean_fps = [{"X": fp["X"].detach().clone(), "A": fp["A"].detach().clone()} for fp in fingerprints] + torch.save( + {"fingerprints": clean_fps, "verifier": V.state_dict(), "ver_in_dim": ver_in_dim}, + args.out + ) + print(f"Saved {args.out}") + print("Time taken (min): ", (time.time() - t0) / 60.0) + + +if __name__ == '__main__': + main() diff --git a/examples/graph_class/generate_univerifier_dataset_gc.py b/examples/graph_class/generate_univerifier_dataset_gc.py new file mode 100644 index 0000000..69d9e2e --- /dev/null +++ b/examples/graph_class/generate_univerifier_dataset_gc.py @@ -0,0 +1,114 @@ +""" +Build a Univerifier dataset from saved GC fingerprints on ENZYMES. +Label 1 for positives ({target ∪ F+}) and 0 for negatives (F−). 
+""" + +import argparse, glob, json, torch +from pathlib import Path +from torch_geometric.datasets import TUDataset +from torch_geometric.utils import dense_to_sparse, to_undirected +from torch_geometric.transforms import NormalizeFeatures + +from gsage_gc import get_model + + +def list_paths_from_globs(globs_str): + globs = [g.strip() for g in globs_str.split(",") if g.strip()] + paths = [] + for g in globs: + paths.extend(glob.glob(g)) + return sorted(paths) + + +def load_model_from_pt(pt_path, in_dim, num_classes): + meta_path = pt_path.replace(".pt", ".json") + j = json.load(open(meta_path, "r")) + m = get_model( + j["arch"], in_dim, j["hidden"], num_classes, + num_layers=j.get("layers", 3), dropout=j.get("dropout", 0.5), pool="mean" + ) + m.load_state_dict(torch.load(pt_path, map_location="cpu")) + m.eval() + return m + + +@torch.no_grad() +def forward_on_fp(model, fp): + X = fp["X"] + A = fp["A"] + n = X.size(0) + + A_bin = (A > 0.5).float() + A_sym = torch.maximum(A_bin, A_bin.t()) + edge_index = dense_to_sparse(A_sym)[0] + if edge_index.numel() == 0: + idx = torch.arange(n, dtype=torch.long) + edge_index = torch.stack([idx, (idx + 1) % n], dim=0) + edge_index = to_undirected(edge_index) + + batch = X.new_zeros(n, dtype=torch.long) + logits = model(X, edge_index, batch=batch) + return logits.squeeze(0) + + +@torch.no_grad() +def concat_for_model(model, fps): + parts = [forward_on_fp(model, fp) for fp in fps] + return torch.cat(parts, dim=0) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--fingerprints_path", type=str, default="fingerprints/fingerprints_gc.pt") + ap.add_argument("--target_path", type=str, default="models/target_model_gc.pt") + ap.add_argument("--target_meta", type=str, default="models/target_meta_gc.json") + ap.add_argument("--positives_glob", type=str, + default="models/positives/gc_ftpr_*.pt,models/positives/distill_gc_*.pt") + ap.add_argument("--negatives_glob", type=str, default="models/negatives/negative_gc_*.pt") 
+ ap.add_argument("--out", type=str, default="fingerprints/univerifier_dataset_gc.pt") + args = ap.parse_args() + + ds = TUDataset(root="data/ENZYMES", name="ENZYMES", use_node_attr=True, transform=NormalizeFeatures()) + in_dim = ds.num_features + num_classes = ds.num_classes + + pack = torch.load(args.fingerprints_path, map_location="cpu") + fps = pack["fingerprints"] + ver_in_dim_saved = pack.get("ver_in_dim", None) + + tmeta = json.load(open(args.target_meta, "r")) + target = get_model( + tmeta.get("arch", "gsage"), in_dim, tmeta.get("hidden", 64), num_classes, + num_layers=tmeta.get("layers", 3), dropout=tmeta.get("dropout", 0.5), pool="mean" + ) + target.load_state_dict(torch.load(args.target_path, map_location="cpu")) + target.eval() + + pos_paths = list_paths_from_globs(args.positives_glob) + neg_paths = sorted(glob.glob(args.negatives_glob)) + + models = [target] + [load_model_from_pt(p, in_dim, num_classes) for p in pos_paths] + \ + [load_model_from_pt(n, in_dim, num_classes) for n in neg_paths] + labels = [1.0] * (1 + len(pos_paths)) + [0.0] * len(neg_paths) + + # Build feature matrix X and labels y + with torch.no_grad(): + z0 = concat_for_model(models[0], fps) + D = z0.numel() + if ver_in_dim_saved is not None and D != int(ver_in_dim_saved): + raise RuntimeError( + f"Verifier input mismatch: dataset dim {D} vs saved ver_in_dim {ver_in_dim_saved}" + ) + + X_rows = [z0] + [concat_for_model(m, fps) for m in models[1:]] + X = torch.stack(X_rows, dim=0).float() + y = torch.tensor(labels, dtype=torch.float32) + + Path(Path(args.out).parent).mkdir(parents=True, exist_ok=True) + torch.save({"X": X, "y": y}, args.out) + print(f"Saved {args.out} with {X.shape[0]} rows; dim={X.shape[1]}") + print(f"Positives: {int(sum(labels))} | Negatives: {len(labels) - int(sum(labels))}") + + +if __name__ == "__main__": + main() diff --git a/examples/graph_class/gsage_gc.py b/examples/graph_class/gsage_gc.py new file mode 100644 index 0000000..c9bb957 --- /dev/null +++ 
b/examples/graph_class/gsage_gc.py @@ -0,0 +1,59 @@ +# Graph classification (GC) model for ENZYMES using GraphSAGE. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.nn import SAGEConv, global_mean_pool + + +class GraphSAGE_GC(nn.Module): + def __init__(self, in_dim: int, hidden: int, num_classes: int, + num_layers: int = 3, dropout: float = 0.5, pool: str = "mean"): + super().__init__() + assert num_layers >= 1 + self.dropout = dropout + self.pool = pool + + convs = [SAGEConv(in_dim, hidden)] + for _ in range(num_layers - 1): + convs.append(SAGEConv(hidden, hidden)) + self.convs = nn.ModuleList(convs) + + self.cls = nn.Linear(hidden, num_classes) + self.reset_parameters() + + def reset_parameters(self): + for m in self.convs: + if hasattr(m, "reset_parameters"): + m.reset_parameters() + nn.init.xavier_uniform_(self.cls.weight) + if self.cls.bias is not None: + nn.init.zeros_(self.cls.bias) + + def _pool(self, x, batch): + if self.pool == "mean": + return global_mean_pool(x, batch) + return global_mean_pool(x, batch) # extend to "sum"/"max" if needed + + def forward(self, x, edge_index, batch=None): + if batch is None: + batch = x.new_zeros(x.size(0), dtype=torch.long) + + for i, conv in enumerate(self.convs): + x = conv(x, edge_index) + if i != len(self.convs) - 1: + x = F.relu(x) + x = F.dropout(x, p=self.dropout, training=self.training) + + g = self._pool(x, batch) + out = self.cls(g) + return out + + +def get_model(arch: str, in_dim: int, hidden: int, num_classes: int, + num_layers: int = 3, dropout: float = 0.5, pool: str = "mean"): + a = arch.lower().strip() + if a in ("graphsage", "sage", "gsage"): + return GraphSAGE_GC(in_dim, hidden, num_classes, + num_layers=num_layers, dropout=dropout, pool=pool) + raise ValueError(f"Unsupported arch for graph classification: {arch}") diff --git a/examples/graph_class/train_gc.py b/examples/graph_class/train_gc.py new file mode 100644 index 0000000..24d995c --- /dev/null +++ 
b/examples/graph_class/train_gc.py @@ -0,0 +1,171 @@ +# Graph classification on ENZYMES using GraphSAGE. + +import argparse +import json +import os +import random +from torch_geometric.transforms import NormalizeFeatures + +import torch +import torch.nn.functional as F +from torch_geometric.datasets import TUDataset +from torch_geometric.loader import DataLoader + +from gsage_gc import get_model + + +def set_seed(seed: int): + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +from collections import defaultdict + +def split_indices_stratified(y, val_ratio=0.1, test_ratio=0.1, seed=0): + g = torch.Generator().manual_seed(seed) + by_cls = defaultdict(list) + for i, yi in enumerate(y.tolist()): + by_cls[int(yi)].append(i) + tr, va, te = [], [], [] + for cls, idxs in by_cls.items(): + idxs = torch.tensor(idxs)[torch.randperm(len(idxs), generator=g)].tolist() + n = len(idxs) + n_val = int(round(val_ratio * n)) + n_test = int(round(test_ratio * n)) + n_train = n - n_val - n_test + tr += idxs[:n_train] + va += idxs[n_train:n_train+n_val] + te += idxs[n_train+n_val:] + return tr, va, te + + +def train_one_epoch(model, loader, optimizer, device): + model.train() + total_loss = 0.0 + total_graphs = 0 + for batch in loader: + batch = batch.to(device) + optimizer.zero_grad() + out = model(batch.x, batch.edge_index, batch=batch.batch) + loss = F.cross_entropy(out, batch.y) + loss.backward() + optimizer.step() + total_loss += float(loss.item()) * batch.num_graphs + total_graphs += batch.num_graphs + return total_loss / max(1, total_graphs) + + +@torch.no_grad() +def evaluate(model, loader, device): + model.eval() + correct = 0 + total = 0 + total_loss = 0.0 + for batch in loader: + batch = batch.to(device) + out = model(batch.x, batch.edge_index, batch=batch.batch) + loss = F.cross_entropy(out, batch.y) + pred = out.argmax(dim=-1) + correct += int((pred == batch.y).sum()) + total += batch.num_graphs + total_loss += float(loss.item()) * 
batch.num_graphs + acc = correct / max(1, total) + avg_loss = total_loss / max(1, total) + return acc, avg_loss + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--arch", default="gsage", choices=["gsage", "graphsage", "sage"]) + ap.add_argument("--hidden", type=int, default=64) + ap.add_argument("--layers", type=int, default=3) + ap.add_argument("--dropout", type=float, default=0.5) + ap.add_argument("--lr", type=float, default=1e-3) + ap.add_argument("--epochs", type=int, default=200) + ap.add_argument("--weight_decay", type=float, default=5e-4) + ap.add_argument("--batch_size", type=int, default=64) + ap.add_argument("--val_ratio", type=float, default=0.1) + ap.add_argument("--test_ratio", type=float, default=0.1) + ap.add_argument("--seed", type=int, default=0) + ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") + args = ap.parse_args() + + set_seed(args.seed) + device = torch.device(args.device) + + # --- Dataset: ENZYMES (graph classification) --- + dataset = TUDataset(root="data/ENZYMES", name="ENZYMES", use_node_attr=True, transform=NormalizeFeatures()) + num_graphs = len(dataset) + in_dim = dataset.num_features + num_classes = dataset.num_classes + + # split indices + y_all = torch.tensor([data.y.item() for data in dataset]) + tr_idx, va_idx, te_idx = split_indices_stratified(y_all, args.val_ratio, args.test_ratio, args.seed) + train_set = dataset[tr_idx] + val_set = dataset[va_idx] + test_set = dataset[te_idx] + + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True) + val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False) + test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False) + + # --- Model --- + model = get_model( + args.arch, + in_dim=in_dim, + hidden=args.hidden, + num_classes=num_classes, + num_layers=args.layers, + dropout=args.dropout, + pool="mean", + ).to(device) + + opt = torch.optim.Adam(model.parameters(), lr=args.lr, 
weight_decay=args.weight_decay) + + os.makedirs("models", exist_ok=True) + best_val_acc = 0.0 + best_state = None + + for epoch in range(1, args.epochs + 1): + train_loss = train_one_epoch(model, train_loader, opt, device) + val_acc, val_loss = evaluate(model, val_loader, device) + + if val_acc > best_val_acc: + best_val_acc = val_acc + best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()} + + if epoch % 10 == 0 or epoch == args.epochs: + print( + f"Epoch {epoch:03d} | train loss {train_loss:.4f} | " + f"val acc {val_acc:.4f} | val loss {val_loss:.4f}" + ) + + if best_state is not None: + model.load_state_dict(best_state) + + test_acc, test_loss = evaluate(model, test_loader, device) + print(f"Best Val Acc: {best_val_acc:.4f} | Test Acc: {test_acc:.4f} | Test Loss: {test_loss:.4f}") + + # Save target GC model + meta (GC-specific filenames) + torch.save(model.state_dict(), "models/target_model_gc.pt") + with open("models/target_meta_gc.json", "w") as f: + json.dump( + { + "task": "graph_classification", + "dataset": "ENZYMES", + "arch": args.arch, + "hidden": args.hidden, + "layers": args.layers, + "dropout": args.dropout, + "batch_size": args.batch_size, + "metrics": {"val_acc": float(best_val_acc), "test_acc": float(test_acc)}, + }, + f, + indent=2, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/graph_class/train_unrelated_gc.py b/examples/graph_class/train_unrelated_gc.py new file mode 100644 index 0000000..d55b9d7 --- /dev/null +++ b/examples/graph_class/train_unrelated_gc.py @@ -0,0 +1,154 @@ +# Train NEGATIVE (unrelated) GRAPH-CLASSIFICATION models on ENZYMES from scratch. 
+ +import argparse +import json +import os +import random +from pathlib import Path + +import torch +import torch.nn.functional as F +from torch_geometric.datasets import TUDataset +from torch_geometric.loader import DataLoader +from torch_geometric.transforms import NormalizeFeatures + +from gsage_gc import get_model + + +def set_seed(seed: int): + random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed) + + +def split_indices(n, val_ratio=0.1, test_ratio=0.1, seed=0): + g = torch.Generator().manual_seed(seed) + perm = torch.randperm(n, generator=g) + n_val = int(round(val_ratio * n)) + n_test = int(round(test_ratio * n)) + n_train = n - n_val - n_test + idx_tr = perm[:n_train].tolist() + idx_va = perm[n_train:n_train + n_val].tolist() + idx_te = perm[n_train + n_val:].tolist() + return idx_tr, idx_va, idx_te + + +def train_one_epoch(model, loader, optimizer, device): + model.train() + total_loss, total_graphs = 0.0, 0 + for batch in loader: + batch = batch.to(device) + optimizer.zero_grad() + out = model(batch.x, batch.edge_index, batch=batch.batch) + loss = F.cross_entropy(out, batch.y) + loss.backward() + optimizer.step() + total_loss += float(loss.item()) * batch.num_graphs + total_graphs += batch.num_graphs + return total_loss / max(1, total_graphs) + + +@torch.no_grad() +def evaluate(model, loader, device): + model.eval() + total, correct, total_loss = 0, 0, 0.0 + for batch in loader: + batch = batch.to(device) + out = model(batch.x, batch.edge_index, batch=batch.batch) + loss = F.cross_entropy(out, batch.y) + pred = out.argmax(dim=-1) + correct += int((pred == batch.y).sum()) + total += batch.num_graphs + total_loss += float(loss.item()) * batch.num_graphs + acc = correct / max(1, total) + return acc, (total_loss / max(1, total)) + + +def main(): + ap = argparse.ArgumentParser(description="Train unrelated GC (negative) models on ENZYMES") + ap.add_argument('--count', type=int, default=150) + ap.add_argument('--archs', type=str, 
default='gsage') + ap.add_argument('--epochs', type=int, default=200) + ap.add_argument('--lr', type=float, default=1e-3) + ap.add_argument('--wd', type=float, default=5e-4) + ap.add_argument('--hidden', type=int, default=64) + ap.add_argument('--layers', type=int, default=3) + ap.add_argument('--dropout', type=float, default=0.5) + ap.add_argument('--batch_size', type=int, default=64) + ap.add_argument('--val_ratio', type=float, default=0.1) + ap.add_argument('--test_ratio', type=float, default=0.1) + ap.add_argument('--seed', type=int, default=123) + ap.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu') + ap.add_argument('--start_index', type=int, default=50) + + args = ap.parse_args() + + device = torch.device(args.device) + Path("models/negatives").mkdir(parents=True, exist_ok=True) + + dataset_full = TUDataset(root='data/ENZYMES', name='ENZYMES', + use_node_attr=True, transform=NormalizeFeatures()) + in_dim = dataset_full.num_features + num_classes = dataset_full.num_classes + + arch_list = [a.strip() for a in args.archs.split(',') if a.strip()] + saved = [] + + for i in range(args.count): + idx = args.start_index + i + seed_i = args.seed + idx + set_seed(seed_i) + + n_graphs = len(dataset_full) + tr_idx, va_idx, te_idx = split_indices(n_graphs, args.val_ratio, args.test_ratio, seed=seed_i) + train_set = dataset_full[tr_idx] + val_set = dataset_full[va_idx] + test_set = dataset_full[te_idx] + + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True) + val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False) + test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False) + + arch = arch_list[idx % len(arch_list)] + model = get_model(arch, in_dim, args.hidden, num_classes, + num_layers=args.layers, dropout=args.dropout, pool="mean").to(device) + opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd) + + best_val, best_state = -1.0, None + for ep in 
range(1, args.epochs + 1): + _ = train_one_epoch(model, train_loader, opt, device) + val_acc, _ = evaluate(model, val_loader, device) + if val_acc > best_val: + best_val = val_acc + best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()} + + if ep % 20 == 0 or ep == args.epochs: + print(f"[neg {idx:03d} | {arch}] epoch {ep:03d} | val acc {val_acc:.4f}") + + if best_state is not None: + model.load_state_dict(best_state) + + test_acc, test_loss = evaluate(model, test_loader, device) + + out_path = f"models/negatives/negative_gc_{idx:03d}.pt" + torch.save(model.state_dict(), out_path) + meta = { + "task": "graph_classification", + "dataset": "ENZYMES", + "arch": arch, + "hidden": args.hidden, + "layers": args.layers, + "dropout": args.dropout, + "seed": seed_i, + "val_acc": float(best_val), + "test_acc": float(test_acc), + "test_loss": float(test_loss), + } + with open(out_path.replace('.pt', '.json'), 'w') as f: + json.dump(meta, f, indent=2) + + saved.append(out_path) + print(f"Saved NEGATIVE {idx:03d} arch={arch} best_val_acc={best_val:.4f} " + f"test_acc={test_acc:.4f} -> {out_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/link_pred/distill_students_lp.py b/examples/link_pred/distill_students_lp.py new file mode 100644 index 0000000..caabc91 --- /dev/null +++ b/examples/link_pred/distill_students_lp.py @@ -0,0 +1,146 @@ +# Distill LINK PREDICTION students on CiteSeer from a trained LP teacher +# Teacher/Student: encoder (GCN/SAGE/GAT) + dot-product decoder + +import argparse, json, random, torch +from pathlib import Path +import torch.nn.functional as F +from torch_geometric.datasets import Planetoid +from torch_geometric.utils import subgraph, negative_sampling +from gcn_lp import get_encoder, DotProductDecoder + + +def set_seed(s: int): + random.seed(s); torch.manual_seed(s); torch.cuda.manual_seed_all(s) + + +def sample_node_subset(num_nodes: int, low: float = 0.5, high: float = 0.8): + k = max(2, 
int(random.uniform(low, high) * num_nodes)) + idx = torch.randperm(num_nodes)[:k] + return idx.sort().values + + +@torch.no_grad() +def teacher_edge_logits(teacher_enc, teacher_dec, x, edge_index, pos_edge, neg_edge, device): + teacher_enc.eval() + z_t = teacher_enc(x.to(device), edge_index.to(device)) + t_pos = teacher_dec(z_t, pos_edge.to(device)) + t_neg = teacher_dec(z_t, neg_edge.to(device)) + return t_pos.detach(), t_neg.detach() + + +def kd_loss(student_logits, teacher_logits, kind: str = "mse"): + if kind == "mse": + return F.mse_loss(student_logits, teacher_logits) + elif kind == "bce_soft": + with torch.no_grad(): + soft = torch.sigmoid(teacher_logits) + return F.binary_cross_entropy_with_logits(student_logits, soft) + else: + raise ValueError(f"Unknown distill loss kind: {kind}") + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--meta_path', default='models/target_meta_lp.json') + ap.add_argument('--target_path', default='models/target_model_lp.pt') + ap.add_argument('--archs', default='gat,sage') + ap.add_argument('--epochs', type=int, default=10) + ap.add_argument('--lr', type=float, default=0.01) + ap.add_argument('--wd', type=float, default=5e-4) + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--count_per_arch', type=int, default=50) + ap.add_argument('--out_dir', type=str, default='models/positives') + ap.add_argument('--student_hidden', type=int, default=64) + ap.add_argument('--student_layers', type=int, default=3) + ap.add_argument('--distill_loss', choices=['mse', 'bce_soft'], default='mse') + ap.add_argument('--sub_low', type=float, default=0.5) # subgraph ratio lower bound + ap.add_argument('--sub_high', type=float, default=0.8) # subgraph ratio upper bound + args = ap.parse_args() + + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + set_seed(args.seed) + Path(args.out_dir).mkdir(parents=True, exist_ok=True) + + with open(args.meta_path, 'r') as f: + meta = json.load(f) + arch_t = 
meta.get('arch', 'gcn') + hidden_t = meta.get('hidden', 64) + layers_t = meta.get('layers', 3) + + dataset = Planetoid(root='data', name='CiteSeer') + data = dataset[0] + + teacher_enc = get_encoder(arch_t, dataset.num_node_features, hidden_t, + num_layers=layers_t, dropout=0.5) + teacher_enc.load_state_dict(torch.load(args.target_path, map_location='cpu')) + teacher_enc.to(device).eval() + t_dec = DotProductDecoder().to(device) + + archs = [a.strip() for a in args.archs.split(',') if a.strip()] + saved = [] + + for arch in archs: + for i in range(args.count_per_arch): + student = get_encoder(arch, dataset.num_node_features, args.student_hidden, + num_layers=args.student_layers, dropout=0.5).to(device) + s_dec = DotProductDecoder().to(device) + opt = torch.optim.Adam(student.parameters(), lr=args.lr, weight_decay=args.wd) + + for _ in range(args.epochs): + student.train(); opt.zero_grad() + + # sample a subgraph (50–80% nodes by default) + idx = sample_node_subset(data.num_nodes, args.sub_low, args.sub_high) + e_idx, _ = subgraph(idx, data.edge_index, relabel_nodes=True) + if e_idx.numel() == 0 or e_idx.size(1) == 0: + continue + + x_sub = data.x[idx] + + # positives = subgraph edges; negatives = sampled non-edges + pos_edge = e_idx + neg_edge = negative_sampling( + edge_index=pos_edge, + num_nodes=x_sub.size(0), + num_neg_samples=pos_edge.size(1), + method='sparse' + ) + + t_pos, t_neg = teacher_edge_logits( + teacher_enc, t_dec, x_sub, e_idx, pos_edge, neg_edge, device + ) + + z_s = student(x_sub.to(device), e_idx.to(device)) + s_pos = s_dec(z_s, pos_edge.to(device)) + s_neg = s_dec(z_s, neg_edge.to(device)) + + s_all = torch.cat([s_pos, s_neg], dim=0) + t_all = torch.cat([t_pos, t_neg], dim=0) + loss = kd_loss(s_all, t_all, kind=args.distill_loss) + + loss.backward(); opt.step() + + out_pt = f'{args.out_dir}/distill_lp_{arch}_{i:03d}.pt' + torch.save(student.state_dict(), out_pt) + with open(out_pt.replace('.pt', '.json'), 'w') as f: + json.dump({ + "task": 
"link_prediction", + "dataset": "CiteSeer", + "arch": arch, + "hidden": args.student_hidden, + "layers": args.student_layers, + "pos_kind": "distill", + "teacher_arch": arch_t, + "teacher_hidden": hidden_t, + "teacher_layers": layers_t, + "distill_loss": args.distill_loss + }, f, indent=2) + + saved.append(out_pt) + print(f"[distill] saved {out_pt}") + + print(f"Saved {len(saved)} distilled LP positives.") + + +if __name__ == '__main__': + main() diff --git a/examples/link_pred/eval_verifier_lp.py b/examples/link_pred/eval_verifier_lp.py new file mode 100644 index 0000000..d9c651a --- /dev/null +++ b/examples/link_pred/eval_verifier_lp.py @@ -0,0 +1,218 @@ +""" +Evaluate a trained Univerifier on LP positives ({target ∪ F+}) and negatives (F−) +using saved LP fingerprints. Produces Robustness/Uniqueness, Mean Test Accuracy and ARUC. +""" + +import argparse, glob, json, math, os, torch +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path +from torch_geometric.datasets import Planetoid +from torch_geometric.utils import dense_to_sparse, to_undirected + +from gcn_lp import get_encoder, DotProductDecoder + +import torch.nn as nn + +class FPVerifier(nn.Module): + def __init__(self, in_dim: int): + super().__init__() + self.net = nn.Sequential( + nn.Linear(in_dim, 128), + nn.LeakyReLU(), + nn.Linear(128, 64), + nn.LeakyReLU(), + nn.Linear(64, 32), + nn.LeakyReLU(), + nn.Linear(32, 1), + nn.Sigmoid(), + ) + def forward(self, x): + return self.net(x) + + +def list_paths_from_globs(globs_str): + globs = [g.strip() for g in globs_str.split(",") if g.strip()] + paths = [] + for g in globs: + paths.extend(glob.glob(g)) + return sorted(paths) + +def get_lp_encoder(arch: str, in_dim: int, hidden: int, layers: int = 3): + return get_encoder(arch, in_dim, hidden, num_layers=layers, dropout=0.5) + +def load_encoder_from_pt(pt_path: str, in_dim: int): + meta_path = pt_path.replace(".pt", ".json") + j = json.load(open(meta_path, "r")) + enc = 
get_lp_encoder(j["arch"], in_dim, j["hidden"], layers=j.get("layers", 3)) + enc.load_state_dict(torch.load(pt_path, map_location="cpu")) + enc.eval() + return enc + + +@torch.no_grad() +def forward_on_fp(encoder, decoder, fp): + X = fp["X"] + A = fp["A"] + n = X.size(0) + + # Binarize & symmetrize adjacency; build undirected edge_index + A_bin = (A > 0.5).float() + A_sym = torch.maximum(A_bin, A_bin.t()) + edge_index = dense_to_sparse(A_sym)[0] + if edge_index.numel() == 0: + idx = torch.arange(n, dtype=torch.long) + edge_index = torch.stack([idx, (idx + 1) % n], dim=0) + edge_index = to_undirected(edge_index) + + z = encoder(X, edge_index) + + sel = fp["node_idx"] + if sel.numel() == 1: + u = sel + v = torch.tensor([(sel.item() + 1) % n], dtype=torch.long) + else: + u = sel + v = torch.roll(sel, shifts=-1, dims=0) + probe_edge = torch.stack([u, v], dim=0) + + logits = decoder(z, probe_edge) + return logits + +@torch.no_grad() +def concat_for_model(encoder, decoder, fps): + parts = [forward_on_fp(encoder, decoder, fp) for fp in fps] + return torch.cat(parts, dim=0) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--fingerprints_path', type=str, default='fingerprints/fingerprints_lp.pt') + ap.add_argument('--verifier_path', type=str, default='fingerprints/univerifier_lp.pt') + ap.add_argument('--target_path', type=str, default='models/target_model_lp.pt') + ap.add_argument('--target_meta', type=str, default='models/target_meta_lp.json') + ap.add_argument('--positives_glob', type=str, + default='models/positives/lp_ftpr_*.pt,models/positives/distill_lp_*.pt') + ap.add_argument('--negatives_glob', type=str, default='models/negatives/negative_lp_*.pt') + ap.add_argument('--out_plot', type=str, default='plots/citeseer_lp_aruc.png') + ap.add_argument('--save_csv', type=str, default='', + help='Optional: path to save thresholds/robustness/uniqueness CSV') + args = ap.parse_args() + + ds = Planetoid(root="data", name="CiteSeer") + in_dim = 
ds.num_features + + # Load fingerprints (with node_idx) + pack = torch.load(args.fingerprints_path, map_location="cpu") + fps = pack["fingerprints"] + ver_in_dim_saved = int(pack.get("ver_in_dim", 0)) + + decoder = DotProductDecoder() + + tmeta = json.load(open(args.target_meta, "r")) + target_enc = get_lp_encoder(tmeta["arch"], in_dim, tmeta["hidden"], layers=tmeta.get("layers", 3)) + target_enc.load_state_dict(torch.load(args.target_path, map_location="cpu")) + target_enc.eval() + + pos_paths = list_paths_from_globs(args.positives_glob) + neg_paths = sorted(glob.glob(args.negatives_glob)) + + models_pos = [target_enc] + [load_encoder_from_pt(p, in_dim) for p in pos_paths] + models_neg = [load_encoder_from_pt(n, in_dim) for n in neg_paths] + + z0 = concat_for_model(models_pos[0], decoder, fps) + D = z0.numel() + if ver_in_dim_saved and ver_in_dim_saved != D: + raise RuntimeError(f"Verifier input mismatch: D={D} vs ver_in_dim_saved={ver_in_dim_saved}") + + V = FPVerifier(D) + ver_path = Path(args.verifier_path) + if ver_path.exists(): + V.load_state_dict(torch.load(str(ver_path), map_location='cpu')) + print(f"Loaded verifier from {ver_path}") + elif "verifier" in pack: + V.load_state_dict(pack["verifier"]) + print("Loaded verifier from fingerprints pack.") + else: + raise FileNotFoundError( + f"No verifier found at {args.verifier_path} and no 'verifier' key in {args.fingerprints_path}" + ) + V.eval() + + with torch.no_grad(): + pos_scores = [] + for enc in models_pos: + z = concat_for_model(enc, decoder, fps).unsqueeze(0) # [1, D] + pos_scores.append(float(V(z))) + neg_scores = [] + for enc in models_neg: + z = concat_for_model(enc, decoder, fps).unsqueeze(0) + neg_scores.append(float(V(z))) + + pos_scores = np.array(pos_scores) + neg_scores = np.array(neg_scores) + + ts = np.linspace(0.0, 1.0, 201) + robustness = np.array([(pos_scores >= t).mean() for t in ts]) # TPR on positives + uniqueness = np.array([(neg_scores < t).mean() for t in ts]) # TNR on negatives + 
overlap = np.minimum(robustness, uniqueness) + # Accuracy at each threshold + Npos, Nneg = len(pos_scores), len(neg_scores) + acc_curve = np.array([((pos_scores >= t).sum() + (neg_scores < t).sum()) / (Npos + Nneg) + for t in ts]) + mean_test_acc = float(acc_curve.mean()) + + + aruc = np.trapz(overlap, ts) + + # Best threshold (maximize min(robustness, uniqueness)) + idx_best = int(np.argmax(overlap)) + t_best = float(ts[idx_best]) + rob_best = float(robustness[idx_best]) + uniq_best = float(uniqueness[idx_best]) + acc_best = 0.5 * (rob_best + uniq_best) + + print(f"Mean Test Accuracy (avg over thresholds) = {mean_test_acc:.4f}") + print(f"Models: +{len(models_pos)} | -{len(models_neg)} | D={D}") + print(f"ARUC = {aruc:.4f}") + print(f"Best threshold = {t_best:.3f} | Robustness={rob_best:.3f} | Uniqueness={uniq_best:.3f} | Acc={acc_best:.3f}") + + if args.save_csv: + import csv + Path(os.path.dirname(args.save_csv)).mkdir(parents=True, exist_ok=True) + with open(args.save_csv, 'w', newline='') as f: + w = csv.writer(f) + w.writerow(['threshold', 'robustness', 'uniqueness', 'min_curve', 'accuracy']) + for t, r, u, s, a in zip(ts, robustness, uniqueness, shade, acc_curve): + w.writerow([f"{t:.5f}", f"{r:.6f}", f"{u:.6f}", f"{s:.6f}", f"{a:.6f}"]) + print(f"Saved CSV to {args.save_csv}") + + # Plot + os.makedirs(os.path.dirname(args.out_plot), exist_ok=True) + fig, ax = plt.subplots(figsize=(7.5, 4.8), dpi=160) + ax.set_title(f"CiteSeer link-prediction • ARUC={aruc:.3f}", fontsize=14) + ax.grid(True, which='both', linestyle=':', linewidth=0.8, alpha=0.6) + ax.plot(ts, robustness, color="#ff0000", linewidth=2.0, label="Robustness (TPR)") + ax.plot(ts, uniqueness, color="#0000ff", linestyle="--", linewidth=2.0, label="Uniqueness (TNR)") + overlap = np.minimum(robustness, uniqueness) + ax.fill_between(ts, overlap, color="#bbbbbb", alpha=0.25, label="Overlap (ARUC region)") + + # best-threshold vertical line + # ax.axvline(t_best, color="0.4", linewidth=2.0, alpha=0.6) + 
+ ax.set_xlabel("Threshold (τ)", fontsize=12) + ax.set_ylabel("Score", fontsize=12) + ax.set_xlim(0.0, 1.0) + ax.set_ylim(0.0, 1.0) + ax.tick_params(labelsize=11) + + leg = ax.legend(loc="lower left", frameon=True, framealpha=0.85, + facecolor="white", edgecolor="0.8") + + plt.tight_layout() + plt.savefig(args.out_plot, bbox_inches="tight") + print(f"Saved plot to {args.out_plot}") + + +if __name__ == "__main__": + main() diff --git a/examples/link_pred/fine_tune_pirate_lp.py b/examples/link_pred/fine_tune_pirate_lp.py new file mode 100644 index 0000000..7b6f95d --- /dev/null +++ b/examples/link_pred/fine_tune_pirate_lp.py @@ -0,0 +1,231 @@ +import argparse, torch, copy, random, json +from pathlib import Path +import torch.nn.functional as F +from sklearn.metrics import roc_auc_score, average_precision_score +from torch_geometric.datasets import Planetoid +from torch_geometric.transforms import RandomLinkSplit +from torch_geometric.utils import negative_sampling + +from gcn_lp import get_encoder, DotProductDecoder + + +def set_seed(seed: int): + random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed) + + +def save_model(state_dict, path, meta): + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + torch.save(state_dict, str(path)) + with open(str(path).replace('.pt', '.json'), 'w') as f: + json.dump(meta, f, indent=2) + + +def get_pos_neg_edges(d, split: str): + # positives + for name in (f"{split}_pos_edge_label_index", "pos_edge_label_index", f"{split}_pos_edge_index", "pos_edge_index"): + if hasattr(d, name): + pos = getattr(d, name) + break + else: + if hasattr(d, "edge_label_index") and hasattr(d, "edge_label"): + eli, el = d.edge_label_index, d.edge_label + pos = eli[:, el == 1] + elif split == "train" and hasattr(d, "edge_index"): + pos = d.edge_index + else: + raise AttributeError(f"No positive edges found for split='{split}'") + + # negatives + for name in (f"{split}_neg_edge_label_index", "neg_edge_label_index", 
f"{split}_neg_edge_index", "neg_edge_index"): + if hasattr(d, name): + neg = getattr(d, name) + break + else: + if hasattr(d, "edge_label_index") and hasattr(d, "edge_label"): + eli, el = d.edge_label_index, d.edge_label + neg = eli[:, el == 0] + else: + neg = None + + return pos, neg + + +def train_epoch_lp(encoder, decoder, data, optimizer, device): + encoder.train(); optimizer.zero_grad() + z = encoder(data.x.to(device), data.edge_index.to(device)) + + pos_e, neg_e = get_pos_neg_edges(data, "train") + if neg_e is None: + neg_e = negative_sampling( + edge_index=data.edge_index.to(device), + num_nodes=data.num_nodes, + num_neg_samples=pos_e.size(1), + method="sparse", + ) + + pos_logits = decoder(z, pos_e.to(device)) + neg_logits = decoder(z, neg_e.to(device)) + logits = torch.cat([pos_logits, neg_logits], dim=0) + labels = torch.cat( + [torch.ones(pos_logits.size(0), device=device), + torch.zeros(neg_logits.size(0), device=device)], + dim=0, + ) + loss = F.binary_cross_entropy_with_logits(logits, labels) + loss.backward(); optimizer.step() + return float(loss.item()) + + +@torch.no_grad() +def eval_split_auc_ap(encoder, decoder, data, split: str, device): + encoder.eval() + pos_e, neg_e = get_pos_neg_edges(data, split) + + z = encoder(data.x.to(device), data.edge_index.to(device)) + pos_logits = decoder(z, pos_e.to(device)) + if neg_e is None: + neg_e = negative_sampling( + edge_index=data.edge_index.to(device), + num_nodes=data.num_nodes, + num_neg_samples=pos_e.size(1), + method="sparse", + ) + neg_logits = decoder(z, neg_e.to(device)) + + logits = torch.cat([pos_logits, neg_logits], dim=0).cpu() + labels = torch.cat([torch.ones(pos_logits.size(0)), + torch.zeros(neg_logits.size(0))], dim=0) + probs = torch.sigmoid(logits) + auc = roc_auc_score(labels.numpy(), probs.numpy()) + ap = average_precision_score(labels.numpy(), probs.numpy()) + return float(auc), float(ap) + + +def freeze_all(encoder): + for p in encoder.parameters(): + p.requires_grad = False + + 
+def unfreeze_all(encoder): + for p in encoder.parameters(): + p.requires_grad = True + + +def unfreeze_last_gnn_layer(encoder): + if hasattr(encoder, "convs") and len(encoder.convs) > 0: + for p in encoder.convs[-1].parameters(): + p.requires_grad = True + + +def reinit_last_gnn_layer(encoder): + if hasattr(encoder, "convs") and len(encoder.convs) > 0: + m = encoder.convs[-1] + if hasattr(m, "reset_parameters"): + try: + m.reset_parameters() + except Exception: + pass + else: + for p in m.parameters(): + if p.dim() > 1: + torch.nn.init.xavier_uniform_(p) + else: + torch.nn.init.zeros_(p) + + +def reinit_all_gnn_layers(encoder): + for m in encoder.modules(): + if hasattr(m, "reset_parameters"): + try: + m.reset_parameters() + except Exception: + pass + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--target_path', type=str, default='models/target_model_lp.pt') + ap.add_argument('--meta_path', type=str, default='models/target_meta_lp.json') + ap.add_argument('--epochs', type=int, default=10) # 10 for FT/PR + ap.add_argument('--lr', type=float, default=0.01) + ap.add_argument('--wd', type=float, default=5e-4) + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--num_variants', type=int, default=100) + ap.add_argument('--out_dir', type=str, default='models/positives') + args = ap.parse_args() + + set_seed(args.seed) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + # Load meta about the target LP encoder + with open(args.meta_path, 'r') as f: + meta = json.load(f) + arch = meta.get("arch", "gcn") + hidden = meta.get("hidden", 64) + layers = meta.get("layers", 3) + + # Dataset & edge-level split for LP (CiteSeer) + dataset = Planetoid(root='data', name='CiteSeer') + base_data = dataset[0] + splitter = RandomLinkSplit(num_val=0.05, num_test=0.10, is_undirected=True, add_negative_train_samples=True) + train_data, val_data, test_data = splitter(base_data) + train_data, val_data, test_data = 
train_data.to(device), val_data.to(device), test_data.to(device) + + target = get_encoder(arch, dataset.num_node_features, hidden, num_layers=layers, dropout=0.5) + target.load_state_dict(torch.load(args.target_path, map_location='cpu')) + target.to(device) + decoder = DotProductDecoder().to(device) + + saved = [] + kinds = ["ft_last", "ft_all", "pr_last", "pr_all"] + + for i in range(args.num_variants): + kind = kinds[i % 4] + + enc = get_encoder(arch, dataset.num_node_features, hidden, num_layers=layers, dropout=0.5) + enc.load_state_dict(copy.deepcopy(target.state_dict())) + enc.to(device) + + if kind == "pr_last": + reinit_last_gnn_layer(enc) + elif kind == "pr_all": + reinit_all_gnn_layers(enc) + + if kind in ("ft_last", "pr_last"): + freeze_all(enc); unfreeze_last_gnn_layer(enc) + else: + unfreeze_all(enc) + + opt = torch.optim.Adam(filter(lambda p: p.requires_grad, enc.parameters()), + lr=args.lr, weight_decay=args.wd) + + best_val_auc, best_state = -1.0, None + for _ in range(args.epochs): + _ = train_epoch_lp(enc, decoder, train_data, opt, device) + val_auc, val_ap = eval_split_auc_ap(enc, decoder, val_data, "val", device) + if val_auc > best_val_auc: + best_val_auc = val_auc + best_state = {k: v.detach().cpu().clone() for k, v in enc.state_dict().items()} + + enc.load_state_dict(best_state) + + out_path = f"{args.out_dir}/lp_ftpr_{i:03d}.pt" + meta_out = { + "task": "link_prediction", + "dataset": "CiteSeer", + "arch": arch, + "hidden": hidden, + "layers": layers, + "pos_kind": kind, + "val_auc": float(best_val_auc), + } + save_model(enc.state_dict(), out_path, meta_out) + saved.append(out_path) + print(f"[ftpr:{kind}] Saved {out_path} val_AUC={best_val_auc:.4f}") + + print(f"Total LP FT/PR positives saved: {len(saved)}") + + +if __name__ == '__main__': + main() diff --git a/examples/link_pred/fingerprint_generator_lp.py b/examples/link_pred/fingerprint_generator_lp.py new file mode 100644 index 0000000..9403c0f --- /dev/null +++ 
b/examples/link_pred/fingerprint_generator_lp.py @@ -0,0 +1,265 @@ +# Fingerprint generation & Univerifier training for LINK PREDICTION on CiteSeer. +# - loads LP encoders + dot-product decoder +# - feature vector per model = concatenated EDGE logits over P fingerprints (each contributes m logits) + + +import argparse, glob, json, math, random, time, torch +import torch.nn as nn +import torch.nn.functional as F +from pathlib import Path +from typing import List +from torch_geometric.datasets import Planetoid +from torch_geometric.utils import dense_to_sparse, to_undirected +from gcn_lp import get_encoder, DotProductDecoder + +def set_seed(s): + random.seed(s); torch.manual_seed(s) + +def load_meta(path): + with open(path, 'r') as f: + return json.load(f) + +def list_paths_from_globs(globs: List[str]) -> List[str]: + out = [] + for g in globs: + out.extend(glob.glob(g)) + return sorted(out) + +class FPVerifier(nn.Module): + # Arch: [128, 64, 32] + LeakyReLU, sigmoid output + def __init__(self, in_dim: int): + super().__init__() + self.net = nn.Sequential( + nn.Linear(in_dim, 128), + nn.LeakyReLU(), + nn.Linear(128, 64), + nn.LeakyReLU(), + nn.Linear(64, 32), + nn.LeakyReLU(), + nn.Linear(32, 1), + nn.Sigmoid() + ) + + def forward(self, x): + return self.net(x) + +def get_lp_encoder(arch: str, in_dim: int, hidden: int, layers: int = 3): + return get_encoder(arch, in_dim, hidden, num_layers=layers, dropout=0.5) + +def load_encoder_from_pt(pt_path: str, in_dim: int): + meta = json.load(open(pt_path.replace('.pt', '.json'), 'r')) + enc = get_lp_encoder(meta["arch"], in_dim, meta["hidden"], layers=meta.get("layers", 3)) + enc.load_state_dict(torch.load(pt_path, map_location='cpu')) + enc.eval() + return enc, meta + +# LP fingerprint forward: encoder -> embeddings -> decoder over probe edges +def forward_on_fp(encoder, decoder, fp): + X = fp["X"] + A = fp["A"] + n = X.size(0) + + A_bin = (A > 0.5).float() + A_sym = torch.maximum(A_bin, A_bin.t()) + edge_index = 
dense_to_sparse(A_sym)[0] + if edge_index.numel() == 0: + idx = torch.arange(n, dtype=torch.long) + edge_index = torch.stack([idx, (idx + 1) % n], dim=0) + edge_index = to_undirected(edge_index) + + # node embeddings + z = encoder(X, edge_index) + + sel = fp["node_idx"] + if sel.numel() == 1: + u = sel + v = torch.tensor([(sel.item() + 1) % n], dtype=torch.long) + else: + u = sel + v = torch.roll(sel, shifts=-1, dims=0) + probe_edge = torch.stack([u, v], dim=0) + + logits = decoder(z, probe_edge) + return logits + +def concat_for_model(encoder, decoder, fingerprints): + vecs = [forward_on_fp(encoder, decoder, fp) for fp in fingerprints] + return torch.cat(vecs, dim=-1) + +def compute_loss(encoders_pos, encoders_neg, fingerprints, V, decoder): + z_pos = [concat_for_model(e, decoder, fingerprints) for e in encoders_pos] + z_neg = [concat_for_model(e, decoder, fingerprints) for e in encoders_neg] + if not z_pos or not z_neg: + raise RuntimeError("Need both positive and negative models.") + Zp = torch.stack(z_pos) + Zn = torch.stack(z_neg) + + yp = V(Zp).clamp(1e-6, 1-1e-6) + yn = V(Zn).clamp(1e-6, 1-1e-6) + L = torch.log(yp).mean() + torch.log(1 - yn).mean() + return L, Zp, Zn + +def feature_ascent_step(encoders_pos, encoders_neg, fingerprints, V, decoder, alpha=0.01): + # ascent on X only + for fp in fingerprints: + fp["X"].requires_grad_(True) + fp["A"].requires_grad_(False) + + L, _, _ = compute_loss(encoders_pos, encoders_neg, fingerprints, V, decoder) + grads = torch.autograd.grad( + L, [fp["X"] for fp in fingerprints], + retain_graph=False, create_graph=False, allow_unused=True + ) + with torch.no_grad(): + for fp, g in zip(fingerprints, grads): + if g is None: + g = torch.zeros_like(fp["X"]) + fp["X"].add_(alpha * g) + fp["X"].clamp_(-5.0, 5.0) + +def edge_flip_candidates(A: torch.Tensor, budget: int): + n = A.size(0) + tri_i, tri_j = torch.triu_indices(n, n, offset=1) + scores = torch.abs(0.5 - A[tri_i, tri_j]) + order = torch.argsort(scores) + picks = 
order[:min(budget, len(order))] + return tri_i[picks], tri_j[picks] + +def edge_flip_step(encoders_pos, encoders_neg, fingerprints, V, decoder, flip_k=8): + for fp_idx, fp in enumerate(fingerprints): + A = fp["A"] + i_idx, j_idx = edge_flip_candidates(A, flip_k * 4) # candidate pool + with torch.no_grad(): + base_L, _, _ = compute_loss(encoders_pos, encoders_neg, fingerprints, V, decoder) + + gains = [] + for i, j in zip(i_idx.tolist(), j_idx.tolist()): + with torch.no_grad(): + old = float(A[i, j]) + new = 1.0 - old + # toggle in place + A[i, j] = new; A[j, i] = new + L_try, _, _ = compute_loss(encoders_pos, encoders_neg, fingerprints, V, decoder) + gain = float(L_try - base_L) + gains.append((gain, i, j, old)) + # revert + A[i, j] = old; A[j, i] = old + + gains.sort(key=lambda x: x[0], reverse=True) + with torch.no_grad(): + for g, i, j, old in gains[:flip_k]: + new = 1.0 - old + A[i, j] = new; A[j, i] = new + A.clamp_(0.0, 1.0) + +def train_verifier_step(encoders_pos, encoders_neg, fingerprints, V, decoder, opt): + # maximize L wrt V (via minimizing -L) + L, Zp, Zn = compute_loss(encoders_pos, encoders_neg, fingerprints, V, decoder) + loss = -L + opt.zero_grad() + loss.backward() + opt.step() + with torch.no_grad(): + yp = (V(Zp) >= 0.5).float().mean().item() + yn = (V(Zn) < 0.5).float().mean().item() + acc = 0.5 * (yp + yn) + return float(L.item()), acc + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--target_path', default='models/target_model_lp.pt') + ap.add_argument('--target_meta', default='models/target_meta_lp.json') + ap.add_argument('--positives_glob', default='models/positives/lp_ftpr_*.pt,models/positives/distill_lp_*.pt') + ap.add_argument('--negatives_glob', default='models/negatives/negative_lp_*.pt') + + ap.add_argument('--P', type=int, default=64) # number of fingerprints + ap.add_argument('--n', type=int, default=32) # nodes per fingerprint + ap.add_argument('--iters', type=int, default=1000) # alternations + 
ap.add_argument('--verifier_lr', type=float, default=1e-3) + ap.add_argument('--e1', type=int, default=1) # fingerprint update epochs per alternation + ap.add_argument('--e2', type=int, default=1) # verifier update epochs per alternation + ap.add_argument('--alpha_x', type=float, default=0.01) # feature ascent step + ap.add_argument('--flip_k', type=int, default=8) # edges flipped per fp per step + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--m', type=int, default=4) # probed edges per fingerprint (via node_idx size) + args = ap.parse_args() + + start_time = time.time() + set_seed(args.seed) + Path('fingerprints').mkdir(parents=True, exist_ok=True) + + ds = Planetoid(root='data', name='CiteSeer') + in_dim = ds.num_features + + meta_t = load_meta(args.target_meta) + target_enc = get_lp_encoder(meta_t["arch"], in_dim, meta_t["hidden"], layers=meta_t.get("layers", 3)) + target_enc.load_state_dict(torch.load(args.target_path, map_location='cpu')) + target_enc.eval() + + pos_globs = [g.strip() for g in args.positives_glob.split(',') if g.strip()] + pos_paths = list_paths_from_globs(pos_globs) + neg_paths = sorted(glob.glob(args.negatives_glob)) + + enc_pos = [target_enc] + [load_encoder_from_pt(p, in_dim)[0] for p in pos_paths] + enc_neg = [load_encoder_from_pt(npath, in_dim)[0] for npath in neg_paths] + decoder = DotProductDecoder() + + print(f"[loaded] positives={len(enc_pos)} (incl. 
target) | negatives={len(enc_neg)}") + + if args.m > args.n: + raise ValueError(f"--m ({args.m}) must be <= --n ({args.n})") + + fingerprints = [] + for _ in range(args.P): + X = torch.randn(args.n, in_dim) * 0.1 + A = torch.rand(args.n, args.n) * 0.2 + 0.4 + A = torch.triu(A, diagonal=1) + A = A + A.t() + torch.diagonal(A).zero_() + idx = torch.randperm(args.n)[:args.m] + fingerprints.append({"X": X, "A": A, "node_idx": idx}) + + # Univerifier + ver_in_dim = args.P * args.m + V = FPVerifier(ver_in_dim) + optV = torch.optim.Adam(V.parameters(), lr=args.verifier_lr) + + flag = 0 + for it in range(1, args.iters + 1): + if flag == 0: + # Update fingerprints (features + edges) + for _ in range(args.e1): + feature_ascent_step(enc_pos, enc_neg, fingerprints, V, decoder, alpha=args.alpha_x) + edge_flip_step(enc_pos, enc_neg, fingerprints, V, decoder, flip_k=args.flip_k) + flag = 1 + else: + # Update verifier + diag_acc = None + for _ in range(args.e2): + Lval, acc = train_verifier_step(enc_pos, enc_neg, fingerprints, V, decoder, optV) + diag_acc = acc + flag = 0 + + if it % 10 == 0 and 'diag_acc' in locals() and diag_acc is not None: + print(f"[Iter {it}] verifier acc={diag_acc:.3f} (diagnostic)") + + clean_fps = [] + for fp in fingerprints: + clean_fps.append({ + "X": fp["X"].detach().clone(), + "A": fp["A"].detach().clone(), + "node_idx": fp["node_idx"].detach().clone(), + }) + torch.save( + {"fingerprints": clean_fps, "verifier": V.state_dict(), "ver_in_dim": ver_in_dim}, + "fingerprints/fingerprints_lp.pt" + ) + + print("Saved fingerprints/fingerprints_lp.pt") + end_time = time.time() + print("Time taken: ", (end_time - start_time)/60) + + +if __name__ == '__main__': + main() diff --git a/examples/link_pred/gcn_lp.py b/examples/link_pred/gcn_lp.py new file mode 100644 index 0000000..f3018dc --- /dev/null +++ b/examples/link_pred/gcn_lp.py @@ -0,0 +1,101 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.nn import GCNConv, 
SAGEConv, GATConv + +class MLPEncoder(nn.Module): + def __init__(self, in_dim: int, hidden: int, num_layers: int = 3, dropout: float = 0.5): + super().__init__() + layers = [] + h = hidden + if num_layers <= 1: + layers.append(nn.Linear(in_dim, h)) + else: + layers.append(nn.Linear(in_dim, h)) + for _ in range(num_layers - 2): + layers.append(nn.ReLU()) + layers.append(nn.Dropout(dropout)) + layers.append(nn.Linear(h, h)) + layers.append(nn.ReLU()) + layers.append(nn.Dropout(dropout)) + layers.append(nn.Linear(h, h)) + self.net = nn.Sequential(*layers) + self.dropout = dropout + + def forward(self, x, edge_index): + return self.net(x) + +class GCN(nn.Module): + def __init__(self, in_dim, hidden, num_layers=3, dropout=0.5): + super().__init__() + self.convs = nn.ModuleList() + self.convs.append(GCNConv(in_dim, hidden)) + for _ in range(num_layers - 1): + self.convs.append(GCNConv(hidden, hidden)) + self.dropout = dropout + + def forward(self, x, edge_index): + for i, conv in enumerate(self.convs): + x = conv(x, edge_index) + if i != len(self.convs) - 1: + x = F.relu(x) + x = F.dropout(x, p=self.dropout, training=self.training) + return x # node embeddings + + +class GraphSAGE(nn.Module): + def __init__(self, in_dim, hidden, num_layers=3, dropout=0.5): + super().__init__() + self.convs = nn.ModuleList() + self.convs.append(SAGEConv(in_dim, hidden)) + for _ in range(num_layers - 1): + self.convs.append(SAGEConv(hidden, hidden)) + self.dropout = dropout + + def forward(self, x, edge_index): + for i, conv in enumerate(self.convs): + x = conv(x, edge_index) + if i != len(self.convs) - 1: + x = F.relu(x) + x = F.dropout(x, p=self.dropout, training=self.training) + return x + + +class GAT(nn.Module): + def __init__(self, in_dim, hidden, num_layers=3, heads=2, dropout=0.5): + super().__init__() + self.convs = nn.ModuleList() + self.convs.append(GATConv(in_dim, hidden, heads=heads, concat=False)) + for _ in range(num_layers - 1): + self.convs.append(GATConv(hidden, hidden, 
heads=heads, concat=False)) + self.dropout = dropout + + def forward(self, x, edge_index): + for i, conv in enumerate(self.convs): + x = conv(x, edge_index) + if i != len(self.convs) - 1: + x = F.elu(x) + x = F.dropout(x, p=self.dropout, training=self.training) + return x + + +def get_encoder(arch: str, in_dim: int, hidden: int, num_layers: int = 3, dropout: float = 0.5): + arch = arch.lower() + if arch == "gcn": + return GCN(in_dim, hidden, num_layers=num_layers, dropout=dropout) + if arch in ("sage", "graphsage"): + return GraphSAGE(in_dim, hidden, num_layers=num_layers, dropout=dropout) + if arch == "gat": + return GAT(in_dim, hidden, num_layers=num_layers, dropout=dropout) + raise ValueError(f"Unknown arch: {arch}") + + +# Decoder for link prediction +class DotProductDecoder(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, z, edge_index): + # z: node embeddings [N, d] + src, dst = edge_index + return (z[src] * z[dst]).sum(dim=-1) # logits for edges diff --git a/examples/link_pred/generate_univerifier_dataset_lp.py b/examples/link_pred/generate_univerifier_dataset_lp.py new file mode 100644 index 0000000..12fc43e --- /dev/null +++ b/examples/link_pred/generate_univerifier_dataset_lp.py @@ -0,0 +1,127 @@ +""" +Build a Univerifier dataset from saved LP fingerprints. +Label 1 for positives ({target ∪ F+}) and 0 for negatives (F−). 
+Outputs a .pt with: + - X: [N_models, D] where D = P * m (m = probed edges per fingerprint) + - y: [N_models] float tensor with 1.0 (positive) or 0.0 (negative) +""" + +import argparse, glob, json, torch +from pathlib import Path +from torch_geometric.datasets import Planetoid +from torch_geometric.utils import dense_to_sparse, to_undirected + +from gcn_lp import get_encoder, DotProductDecoder + + +def list_paths_from_globs(globs_str): + globs = [g.strip() for g in globs_str.split(",") if g.strip()] + paths = [] + for g in globs: + paths.extend(glob.glob(g)) + return sorted(paths) + + +def get_lp_encoder(arch: str, in_dim: int, hidden: int, layers: int = 3): + return get_encoder(arch, in_dim, hidden, num_layers=layers, dropout=0.5) + + +def load_encoder_from_pt(pt_path: str, in_dim: int): + meta_path = pt_path.replace(".pt", ".json") + j = json.load(open(meta_path, "r")) + enc = get_lp_encoder(j["arch"], in_dim, j["hidden"], layers=j.get("layers", 3)) + enc.load_state_dict(torch.load(pt_path, map_location="cpu")) + enc.eval() + return enc + + +# LP fingerprint forward: encoder -> embeddings -> dot-product over probe edges +@torch.no_grad() +def forward_on_fp(encoder, decoder, fp): + X = fp["X"] + A = fp["A"] + n = X.size(0) + + # Binarize & symmetrize adjacency, build edge_index + A_bin = (A > 0.5).float() + A_sym = torch.maximum(A_bin, A_bin.t()) + edge_index = dense_to_sparse(A_sym)[0] + if edge_index.numel() == 0: + idx = torch.arange(n, dtype=torch.long) + edge_index = torch.stack([idx, (idx + 1) % n], dim=0) + edge_index = to_undirected(edge_index) + + z = encoder(X, edge_index) + sel = fp["node_idx"] + if sel.numel() == 1: + u = sel + v = torch.tensor([(sel.item() + 1) % n], dtype=torch.long) + else: + u = sel + v = torch.roll(sel, shifts=-1, dims=0) + probe_edge = torch.stack([u, v], dim=0) + + logits = decoder(z, probe_edge) + return logits + + +@torch.no_grad() +def concat_for_model(encoder, decoder, fps): + parts = [forward_on_fp(encoder, decoder, fp) 
for fp in fps] + return torch.cat(parts, dim=0) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--fingerprints_path", type=str, default="fingerprints/fingerprints_lp.pt") + ap.add_argument("--target_path", type=str, default="models/target_model_lp.pt") + ap.add_argument("--target_meta", type=str, default="models/target_meta_lp.json") + ap.add_argument("--positives_glob", type=str, + default="models/positives/lp_ftpr_*.pt,models/positives/distill_lp_*.pt") + ap.add_argument("--negatives_glob", type=str, default="models/negatives/negative_lp_*.pt") + ap.add_argument("--out", type=str, default="fingerprints/univerifier_dataset_lp.pt") + args = ap.parse_args() + + ds = Planetoid(root="data", name="CiteSeer") + in_dim = ds.num_features + + pack = torch.load(args.fingerprints_path, map_location="cpu") + fps = pack["fingerprints"] + ver_in_dim_saved = pack.get("ver_in_dim", None) + + decoder = DotProductDecoder() + + tmeta = json.load(open(args.target_meta, "r")) + target_enc = get_lp_encoder(tmeta["arch"], in_dim, tmeta["hidden"], layers=tmeta.get("layers", 3)) + target_enc.load_state_dict(torch.load(args.target_path, map_location="cpu")) + target_enc.eval() + + # Positives & negatives + pos_paths = list_paths_from_globs(args.positives_glob) + neg_paths = sorted(glob.glob(args.negatives_glob)) + + encoders = [target_enc] + [load_encoder_from_pt(p, in_dim) for p in pos_paths] + \ + [load_encoder_from_pt(n, in_dim) for n in neg_paths] + labels = [1.0] * (1 + len(pos_paths)) + [0.0] * len(neg_paths) + + # Build feature matrix X and labels y + with torch.no_grad(): + z0 = concat_for_model(encoders[0], decoder, fps) + D = z0.numel() + if ver_in_dim_saved is not None and D != int(ver_in_dim_saved): + raise RuntimeError( + f"Verifier input mismatch: dataset dim {D} vs saved ver_in_dim {ver_in_dim_saved}" + ) + + X_rows = [z0] + [concat_for_model(enc, decoder, fps) for enc in encoders[1:]] + X = torch.stack(X_rows, dim=0).float() # [N, D] + y = 
torch.tensor(labels, dtype=torch.float32) # [N] + + Path(Path(args.out).parent).mkdir(parents=True, exist_ok=True) + torch.save({"X": X, "y": y}, args.out) + print(f"Saved {args.out} with {X.shape[0]} rows; dim={X.shape[1]}") + print(f"Positives: {int(sum(labels))} | Negatives: {len(labels) - int(sum(labels))}") + + +if __name__ == "__main__": + main() diff --git a/examples/link_pred/train_lp.py b/examples/link_pred/train_lp.py new file mode 100644 index 0000000..0ce62ee --- /dev/null +++ b/examples/link_pred/train_lp.py @@ -0,0 +1,183 @@ +import argparse +import json +import os +import random + +import torch +import torch.nn.functional as F +from sklearn.metrics import roc_auc_score, average_precision_score +from torch_geometric.datasets import Planetoid +from torch_geometric.transforms import RandomLinkSplit +from torch_geometric.utils import negative_sampling + +from gcn_lp import get_encoder, DotProductDecoder + + +def set_seed(seed: int): + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_pos_neg_edges(d, split: str): + # positive + if hasattr(d, f"{split}_pos_edge_label_index"): + pos = getattr(d, f"{split}_pos_edge_label_index") + elif hasattr(d, "pos_edge_label_index"): + pos = d.pos_edge_label_index + elif hasattr(d, "edge_label_index") and hasattr(d, "edge_label"): + eli, el = d.edge_label_index, d.edge_label + pos = eli[:, el == 1] + elif split == "train" and hasattr(d, "edge_index"): + pos = d.edge_index + else: + raise AttributeError(f"No positive edge indices found for split='{split}'") + + # negative (may be absent for some versions/splits) + if hasattr(d, f"{split}_neg_edge_label_index"): + neg = getattr(d, f"{split}_neg_edge_label_index") + elif hasattr(d, "neg_edge_label_index"): + neg = d.neg_edge_label_index + elif hasattr(d, "edge_label_index") and hasattr(d, "edge_label"): + eli, el = d.edge_label_index, d.edge_label + neg = eli[:, el == 0] + else: + neg = None + return pos, neg + + +def 
train_step(encoder, decoder, data, device): + z = encoder(data.x.to(device), data.edge_index.to(device)) + + pos_edge, neg_edge = get_pos_neg_edges(data, "train") + if neg_edge is None: + neg_edge = negative_sampling( + edge_index=data.edge_index.to(device), + num_nodes=data.num_nodes, + num_neg_samples=pos_edge.size(1), + method="sparse", + ) + + pos_logits = decoder(z, pos_edge.to(device)) + neg_logits = decoder(z, neg_edge.to(device)) + + logits = torch.cat([pos_logits, neg_logits], dim=0) + labels = torch.cat( + [torch.ones(pos_logits.size(0), device=device), + torch.zeros(neg_logits.size(0), device=device)], + dim=0, + ) + return F.binary_cross_entropy_with_logits(logits, labels) + + +@torch.no_grad() +def evaluate(encoder, decoder, data, split: str, device): + pos_edge, neg_edge = get_pos_neg_edges(data, split) + + z = encoder(data.x.to(device), data.edge_index.to(device)) + + pos_logits = decoder(z, pos_edge.to(device)) + if neg_edge is None: + neg_edge = negative_sampling( + edge_index=data.edge_index.to(device), + num_nodes=data.num_nodes, + num_neg_samples=pos_edge.size(1), + method="sparse", + ) + neg_logits = decoder(z, neg_edge.to(device)) + + logits = torch.cat([pos_logits, neg_logits], dim=0).cpu() + labels = torch.cat( + [torch.ones(pos_logits.size(0)), + torch.zeros(neg_logits.size(0))], + dim=0, + ) + probs = torch.sigmoid(logits) + auc = roc_auc_score(labels.numpy(), probs.numpy()) + ap = average_precision_score(labels.numpy(), probs.numpy()) + return auc, ap + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--arch", default="gcn", choices=["gcn", "graphsage", "sage", "gat"]) + ap.add_argument("--hidden", type=int, default=64) + ap.add_argument("--layers", type=int, default=3) + ap.add_argument("--dropout", type=float, default=0.5) + ap.add_argument("--lr", type=float, default=1e-3) + ap.add_argument("--epochs", type=int, default=200) + ap.add_argument("--weight_decay", type=float, default=5e-4) + ap.add_argument("--seed", 
type=int, default=0) + ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") + ap.add_argument("--val_ratio", type=float, default=0.05) + ap.add_argument("--test_ratio", type=float, default=0.10) + args = ap.parse_args() + + set_seed(args.seed) + device = torch.device(args.device) + + dataset = Planetoid(root="data", name="CiteSeer") + data = dataset[0] + + splitter = RandomLinkSplit( + num_val=args.val_ratio, + num_test=args.test_ratio, + is_undirected=True, + add_negative_train_samples=True, + ) + train_data, val_data, test_data = splitter(data) + train_data, val_data, test_data = train_data.to(device), val_data.to(device), test_data.to(device) + + encoder = get_encoder( + args.arch, + dataset.num_node_features, + hidden=args.hidden, + num_layers=args.layers, + dropout=args.dropout, + ).to(device) + decoder = DotProductDecoder().to(device) + + opt = torch.optim.Adam(encoder.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + os.makedirs("models", exist_ok=True) + best_val_auc, best_state = 0.0, None + + for epoch in range(1, args.epochs + 1): + encoder.train() + opt.zero_grad() + loss = train_step(encoder, decoder, train_data, device) + loss.backward() + opt.step() + + if epoch % 20 == 0 or epoch == args.epochs: + encoder.eval() + val_auc, val_ap = evaluate(encoder, decoder, val_data, "val", device) + if val_auc > best_val_auc: + best_val_auc = val_auc + best_state = {k: v.detach().cpu().clone() for k, v in encoder.state_dict().items()} + print(f"Epoch {epoch:03d} | loss {loss.item():.4f} | val AUC {val_auc:.4f} | val AP {val_ap:.4f}") + + if best_state is not None: + encoder.load_state_dict(best_state) + + test_auc, test_ap = evaluate(encoder, decoder, test_data, "test", device) + print(f"Best Val AUC: {best_val_auc:.4f} | Test AUC: {test_auc:.4f} | Test AP: {test_ap:.4f}") + + torch.save(encoder.state_dict(), "models/target_model_lp.pt") + with open("models/target_meta_lp.json", "w") as f: + json.dump( + { + "task": 
"link_prediction", + "dataset": "CiteSeer", + "arch": args.arch, + "hidden": args.hidden, + "layers": args.layers, + "metrics": {"AUC": float(test_auc), "AP": float(test_ap)}, + }, + f, + indent=2, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/link_pred/train_unrelated_lp.py b/examples/link_pred/train_unrelated_lp.py new file mode 100644 index 0000000..d7216ae --- /dev/null +++ b/examples/link_pred/train_unrelated_lp.py @@ -0,0 +1,199 @@ +# Train NEGATIVE LINK-PREDICTION models on CiteSeer from scratch. + +import argparse +import json +import os +import random +from pathlib import Path + +import torch +import torch.nn as nn +import torch.nn.functional as F +from sklearn.metrics import roc_auc_score, average_precision_score +from torch_geometric.datasets import Planetoid +from torch_geometric.transforms import RandomLinkSplit +from torch_geometric.utils import negative_sampling + +from gcn_lp import get_encoder, DotProductDecoder + + +def set_seed(seed: int): + random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed) + + +def get_pos_neg_edges(d, split: str): + # positives + for name in (f"{split}_pos_edge_label_index", "pos_edge_label_index", + f"{split}_pos_edge_index", "pos_edge_index"): + if hasattr(d, name): + pos = getattr(d, name) + break + else: + if hasattr(d, "edge_label_index") and hasattr(d, "edge_label"): + eli, el = d.edge_label_index, d.edge_label + pos = eli[:, el == 1] + elif split == "train" and hasattr(d, "edge_index"): + pos = d.edge_index + else: + raise AttributeError(f"No positive edges found for split='{split}'") + + # negatives + for name in (f"{split}_neg_edge_label_index", "neg_edge_label_index", + f"{split}_neg_edge_index", "neg_edge_index"): + if hasattr(d, name): + neg = getattr(d, name) + break + else: + if hasattr(d, "edge_label_index") and hasattr(d, "edge_label"): + eli, el = d.edge_label_index, d.edge_label + neg = eli[:, el == 0] + else: + neg = None + + return pos, neg + + +def 
get_lp_encoder(arch: str, in_dim: int, hidden: int, layers: int, dropout: float): + a = arch.lower().strip() + if a in ("gcn", "sage", "graphsage", "gat"): + return get_encoder(a, in_dim, hidden, num_layers=layers, dropout=dropout) + raise ValueError(f"Unknown arch: {arch}") + + +def train_step(encoder, decoder, data, device): + z = encoder(data.x.to(device), data.edge_index.to(device)) + + pos_edge, neg_edge = get_pos_neg_edges(data, "train") + if neg_edge is None: + neg_edge = negative_sampling( + edge_index=data.edge_index.to(device), + num_nodes=data.num_nodes, + num_neg_samples=pos_edge.size(1), + method="sparse", + ) + + pos_logits = decoder(z, pos_edge.to(device)) + neg_logits = decoder(z, neg_edge.to(device)) + logits = torch.cat([pos_logits, neg_logits], dim=0) + labels = torch.cat( + [torch.ones(pos_logits.size(0), device=device), + torch.zeros(neg_logits.size(0), device=device)], + dim=0, + ) + return F.binary_cross_entropy_with_logits(logits, labels) + + +@torch.no_grad() +def evaluate(encoder, decoder, data, split: str, device): + pos_edge, neg_edge = get_pos_neg_edges(data, split) + + z = encoder(data.x.to(device), data.edge_index.to(device)) + pos_logits = decoder(z, pos_edge.to(device)) + if neg_edge is None: + neg_edge = negative_sampling( + edge_index=data.edge_index.to(device), + num_nodes=data.num_nodes, + num_neg_samples=pos_edge.size(1), + method="sparse", + ) + neg_logits = decoder(z, neg_edge.to(device)) + + logits = torch.cat([pos_logits, neg_logits], dim=0).cpu() + labels = torch.cat([torch.ones(pos_logits.size(0)), + torch.zeros(neg_logits.size(0))], dim=0) + probs = torch.sigmoid(logits) + auc = roc_auc_score(labels.numpy(), probs.numpy()) + ap = average_precision_score(labels.numpy(), probs.numpy()) + return float(auc), float(ap) + + +def main(): + ap = argparse.ArgumentParser(description="Train unrelated LP (negative) models on CiteSeer") + ap.add_argument('--count', type=int, default=100) + ap.add_argument('--archs', type=str, 
default='gcn,sage,gat') + ap.add_argument('--epochs', type=int, default=200) + ap.add_argument('--lr', type=float, default=1e-3) + ap.add_argument('--wd', type=float, default=5e-4) + ap.add_argument('--hidden', type=int, default=64) + ap.add_argument('--layers', type=int, default=3) + ap.add_argument('--dropout', type=float, default=0.5) + ap.add_argument('--seed', type=int, default=123) + ap.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu') + ap.add_argument('--val_ratio', type=float, default=0.05) + ap.add_argument('--test_ratio', type=float, default=0.10) + ap.add_argument('--start_index', type=int, default=0) + + args = ap.parse_args() + + device = torch.device(args.device) + os.makedirs("models/negatives", exist_ok=True) + + # Dataset & edge-level split + dataset = Planetoid(root='data', name='CiteSeer') + data_full = dataset[0] + splitter = RandomLinkSplit( + num_val=args.val_ratio, + num_test=args.test_ratio, + is_undirected=True, + add_negative_train_samples=True, + ) + train_data, val_data, test_data = splitter(data_full) + train_data, val_data, test_data = train_data.to(device), val_data.to(device), test_data.to(device) + + arch_list = [a.strip() for a in args.archs.split(',') if a.strip()] + saved = [] + + for i in range(args.count): + idx = args.start_index + i + seed_i = args.seed + idx + arch = arch_list[idx % len(arch_list)] + + arch = arch_list[i % len(arch_list)] + encoder = get_lp_encoder(arch, dataset.num_node_features, args.hidden, args.layers, args.dropout).to(device) + decoder = DotProductDecoder().to(device) + + opt = torch.optim.Adam(encoder.parameters(), lr=args.lr, weight_decay=args.wd) + + best_val_auc, best_state = -1.0, None + for ep in range(1, args.epochs + 1): + encoder.train(); opt.zero_grad() + loss = train_step(encoder, decoder, train_data, device) + loss.backward(); opt.step() + + if ep % 20 == 0 or ep == args.epochs: + encoder.eval() + val_auc, val_ap = evaluate(encoder, decoder, val_data, "val", 
device) + if val_auc > best_val_auc: + best_val_auc = val_auc + best_state = {k: v.detach().cpu().clone() for k, v in encoder.state_dict().items()} + print(f"[neg {i:03d} | {arch}] epoch {ep:03d} | loss {loss.item():.4f} | val AUC {val_auc:.4f} | val AP {val_ap:.4f}") + + if best_state is not None: + encoder.load_state_dict(best_state) + + test_auc, test_ap = evaluate(encoder, decoder, test_data, "test", device) + + out_path = f"models/negatives/negative_lp_{idx:03d}.pt" + torch.save(encoder.state_dict(), out_path) + meta = { + "task": "link_prediction", + "dataset": "CiteSeer", + "arch": arch, + "hidden": args.hidden, + "layers": args.layers, + "dropout": args.dropout, + "seed": seed_i, + "best_val_auc": float(best_val_auc), + "test_auc": float(test_auc), + "test_ap": float(test_ap), + } + with open(out_path.replace('.pt', '.json'), 'w') as f: + json.dump(meta, f, indent=2) + + saved.append(out_path) + print(f"Saved NEGATIVE {i:03d} arch={arch} best_val_AUC={best_val_auc:.4f} " + f"test AUC={test_auc:.4f} AP={test_ap:.4f} -> {out_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/node_class/distill_students_nc.py b/examples/node_class/distill_students_nc.py new file mode 100644 index 0000000..1f26764 --- /dev/null +++ b/examples/node_class/distill_students_nc.py @@ -0,0 +1,88 @@ +import argparse, json, random, torch, torch.nn.functional as F +from pathlib import Path +from torch_geometric.datasets import Planetoid +from torch_geometric.utils import subgraph +from gcn_nc import get_model + +def set_seed(s): + random.seed(s); torch.manual_seed(s); torch.cuda.manual_seed_all(s) + +def make_masks_like(train_seed=0): + ds = Planetoid(root='data/cora', name='Cora') + data = ds[0] + g = torch.Generator().manual_seed(train_seed) + idx = torch.randperm(data.num_nodes, generator=g) + n_tr = int(0.7*data.num_nodes); n_va = int(0.1*data.num_nodes) + tr, va, te = idx[:n_tr], idx[n_tr:n_tr+n_va], idx[n_tr+n_va:] + mtr = torch.zeros(data.num_nodes, 
dtype=torch.bool); mtr[tr]=True + mva = torch.zeros(data.num_nodes, dtype=torch.bool); mva[va]=True + mte = torch.zeros(data.num_nodes, dtype=torch.bool); mte[te]=True + data.train_mask, data.val_mask, data.test_mask = mtr, mva, mte + return ds, data + +@torch.no_grad() +def teacher_logits_on_nodes(model, x, edge_index, nodes): + model.eval() + out = model(x, edge_index) + return out[nodes] + +def sample_node_subgraph(num_nodes, low=0.5, high=0.8): + k = int(random.uniform(low, high) * num_nodes) + idx = torch.randperm(num_nodes)[:k] + return idx.sort().values + +def kd_loss(student_logits, teacher_logits): + return F.mse_loss(student_logits, teacher_logits) + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--meta_path', default='models/target_meta_nc.json') + ap.add_argument('--target_path', default='models/target_model_nc.pt') + ap.add_argument('--archs', default='gat,sage') + ap.add_argument('--epochs', type=int, default=10) + ap.add_argument('--lr', type=float, default=0.01) + ap.add_argument('--wd', type=float, default=5e-4) + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--count_per_arch', type=int, default=50) + ap.add_argument('--out_dir', type=str, default='models/positives') + args = ap.parse_args() + + Path(args.out_dir).mkdir(parents=True, exist_ok=True) + + with open(args.meta_path,'r') as f: + meta = json.load(f) + ds, data = make_masks_like(train_seed=args.seed) + in_dim, num_classes = ds.num_features, ds.num_classes + + teacher = get_model(meta['arch'], in_dim, meta['hidden'], num_classes) + teacher.load_state_dict(torch.load(args.target_path, map_location='cpu')) + teacher.eval() + + archs = [a.strip() for a in args.archs.split(',') if a.strip()] + saved = [] + for arch in archs: + for i in range(args.count_per_arch): + student = get_model(arch, in_dim, 64, num_classes) + opt = torch.optim.Adam(student.parameters(), lr=args.lr, weight_decay=args.wd) + + for _ in range(args.epochs): + student.train(); 
opt.zero_grad() + idx = sample_node_subgraph(data.num_nodes, 0.5, 0.8) + e_idx, _ = subgraph(idx, data.edge_index, relabel_nodes=True) + x_sub = data.x[idx] + with torch.no_grad(): + t_logits = teacher_logits_on_nodes(teacher, data.x, data.edge_index, idx) + s_logits = student(x_sub, e_idx) + loss = kd_loss(s_logits, t_logits) + loss.backward(); opt.step() + + out_pt = f'{args.out_dir}/distill_nc_{arch}_{i:03d}.pt' + torch.save(student.state_dict(), out_pt) + with open(out_pt.replace('.pt','.json'),'w') as f: + json.dump({"arch": arch, "hidden": 64, "num_classes": num_classes, "pos_kind": "distill"}, f) + saved.append(out_pt) + print(f"[distill-nc] saved {out_pt}") + print(f"Saved {len(saved)} distilled positives.") + +if __name__ == '__main__': + main() diff --git a/examples/node_class/eval_verifier_nc.py b/examples/node_class/eval_verifier_nc.py new file mode 100644 index 0000000..3565ceb --- /dev/null +++ b/examples/node_class/eval_verifier_nc.py @@ -0,0 +1,257 @@ +""" +Evaluate a trained Univerifier on positives ({target ∪ F+}) and negatives (F−) +using saved fingerprints. Produces Robustness/Uniqueness, ARUC, Mean Test Accuracy, KL Divergence. 
+""" + +import argparse, glob, json, math, torch, os +import numpy as np +import matplotlib.pyplot as plt +from torch_geometric.datasets import Planetoid +from torch_geometric.utils import dense_to_sparse +from gcn_nc import get_model +import torch.nn.functional as F + +import torch.nn as nn +class FPVerifier(nn.Module): + def __init__(self, in_dim: int): + super().__init__() + self.net = nn.Sequential( + nn.Linear(in_dim, 128), + nn.LeakyReLU(), + nn.Linear(128, 64), + nn.LeakyReLU(), + nn.Linear(64, 32), + nn.LeakyReLU(), + nn.Linear(32, 1), + nn.Sigmoid(), + ) + def forward(self, x): + return self.net(x) + + +@torch.no_grad() +def forward_on_fp(model, fp): + + X = fp["X"] + A = fp["A"] + idx = fp["node_idx"] + + A_bin = (A > 0.5).float() + A_sym = torch.triu(A_bin, diagonal=1) + A_sym = A_sym + A_sym.t() + edge_index = dense_to_sparse(A_sym)[0] + + if edge_index.numel() == 0: + n = X.size(0) + edge_index = torch.arange(n, dtype=torch.long).repeat(2, 1) + + logits = model(X, edge_index) + sel = logits[idx, :] + return sel.reshape(-1) + + +@torch.no_grad() +def concat_for_model(model, fps): + parts = [forward_on_fp(model, fp) for fp in fps] + return torch.cat(parts, dim=0) + + +def list_paths_from_globs(globs_str): + globs = [g.strip() for g in globs_str.split(",") if g.strip()] + paths = [] + for g in globs: + paths.extend(glob.glob(g)) + return sorted(paths) + + +def load_model_from_pt(pt_path, in_dim): + meta_path = pt_path.replace(".pt", ".json") + j = json.load(open(meta_path, "r")) + m = get_model(j["arch"], in_dim, j["hidden"], j["num_classes"]) + m.load_state_dict(torch.load(pt_path, map_location="cpu")) + m.eval() + return m + +# KL divergence helpers +def softmax_logits(x): + return F.softmax(x, dim=-1) + +@torch.no_grad() +def forward_nc_logits(model, fp): + X, A, idx = fp["X"], fp["A"], fp["node_idx"] + A_bin = (A > 0.5).float() + A_sym = torch.triu(A_bin, diagonal=1); A_sym = A_sym + A_sym.t() + edge_index = dense_to_sparse(A_sym)[0] + if 
edge_index.numel() == 0: + n = X.size(0) + edge_index = torch.arange(n, dtype=torch.long).repeat(2, 1) + logits = model(X, edge_index) + return logits[idx, :] + +def sym_kl(p, q, eps=1e-8): + """ + Symmetric KL + """ + p = p.clamp(min=eps); q = q.clamp(min=eps) + kl1 = (p * (p.log() - q.log())).sum(dim=-1) + kl2 = (q * (q.log() - p.log())).sum(dim=-1) + return 0.5 * (kl1 + kl2) + +@torch.no_grad() +def model_nc_kl_to_target(suspect, target, fps): + """ + Average symmetric KL over all fingerprints. + """ + vals = [] + for fp in fps: + t = softmax_logits(forward_nc_logits(target, fp)) + s = softmax_logits(forward_nc_logits(suspect, fp)) + d = sym_kl(s, t) + vals.append(d.mean().item()) + return float(np.mean(vals)) + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--fingerprints_path', type=str, default='fingerprints/fingerprints_nc.pt') + ap.add_argument('--verifier_path', type=str, default='fingerprints/univerifier_nc.pt') + ap.add_argument('--target_path', type=str, default='models/target_model_nc.pt') + ap.add_argument('--target_meta', type=str, default='models/target_meta_nc.json') + ap.add_argument('--positives_glob', type=str, + default='models/positives/nc_ftpr_*.pt,models/positives/distill_nc_*.pt') + ap.add_argument('--negatives_glob', type=str, default='models/negatives/negative_nc_*.pt') + ap.add_argument('--out_plot', type=str, default='plots/cora_nc_aruc.png') + ap.add_argument('--out_plot_kl', type=str, default='plots/cora_nc_kl.png') + ap.add_argument('--save_csv', type=str, default='', + help='Optional: path to save thresholds/robustness/uniqueness CSV') + args = ap.parse_args() + + ds = Planetoid(root="data/cora", name="Cora") + in_dim = ds.num_features + num_classes = ds.num_classes + + # Load fingerprints (with node_idx) + pack = torch.load(args.fingerprints_path, map_location="cpu") + fps = pack["fingerprints"] + ver_in_dim_saved = int(pack.get("ver_in_dim", 0)) + + # Load models (target + positives + negatives) + tmeta = 
json.load(open(args.target_meta, "r")) + target = get_model(tmeta["arch"], in_dim, tmeta["hidden"], tmeta["num_classes"]) + target.load_state_dict(torch.load(args.target_path, map_location="cpu")) + target.eval() + + pos_paths = list_paths_from_globs(args.positives_glob) + neg_paths = sorted(glob.glob(args.negatives_glob)) + + models_pos = [target] + [load_model_from_pt(p, in_dim) for p in pos_paths] + models_neg = [load_model_from_pt(n, in_dim) for n in neg_paths] + + # Infer verifier input dim from a probe concat + z0 = concat_for_model(models_pos[0], fps) + D = z0.numel() + if ver_in_dim_saved and ver_in_dim_saved != D: + raise RuntimeError(f"Verifier input mismatch: D={D} vs ver_in_dim_saved={ver_in_dim_saved}") + + V = FPVerifier(D) + V.load_state_dict(torch.load(args.verifier_path, map_location='cpu')) + V.eval() + + with torch.no_grad(): + pos_scores = [] + for m in models_pos: + z = concat_for_model(m, fps).unsqueeze(0) + pos_scores.append(float(V(z))) + neg_scores = [] + for m in models_neg: + z = concat_for_model(m, fps).unsqueeze(0) + neg_scores.append(float(V(z))) + + pos_scores = np.array(pos_scores) + neg_scores = np.array(neg_scores) + + # Sweep thresholds + ts = np.linspace(0.0, 1.0, 201) + robustness = np.array([(pos_scores >= t).mean() for t in ts]) # TPR on positives + uniqueness = np.array([(neg_scores < t).mean() for t in ts]) # TNR on negatives + overlap = np.minimum(robustness, uniqueness) + + # Mean Test Accuracy at each threshold + Npos = len(pos_scores) + Nneg = len(neg_scores) + acc_curve = np.array([((pos_scores >= t).sum() + (neg_scores < t).sum()) / (Npos + Nneg) + for t in ts]) + + mean_test_acc = float(acc_curve.mean()) + + aruc = np.trapz(overlap, ts) + + idx_best = int(np.argmax(overlap)) + t_best = float(ts[idx_best]) + rob_best = float(robustness[idx_best]) + uniq_best = float(uniqueness[idx_best]) + acc_best = 0.5 * (rob_best + uniq_best) + + print(f"Mean Test Accuracy (avg over thresholds) = {mean_test_acc:.4f}") + 
print(f"Models: +{len(models_pos)} | -{len(models_neg)} | D={D}") + print(f"ARUC = {aruc:.4f}") + print(f"Best threshold = {t_best:.3f} | Robustness={rob_best:.3f} | Uniqueness={uniq_best:.3f} | Acc={acc_best:.3f}") + + if args.save_csv: + import csv + with open(args.save_csv, 'w', newline='') as f: + w = csv.writer(f) + w.writerow(['threshold', 'robustness', 'uniqueness', 'min_curve', 'accuracy']) + for t, r, u, s, a in zip(ts, robustness, uniqueness, overlap, acc_curve): + w.writerow([f"{t:.5f}", f"{r:.6f}", f"{u:.6f}", f"{s:.6f}", f"{a:.6f}"]) + print(f"Saved CSV to {args.save_csv}") + + # ARUC Plot + os.makedirs(os.path.dirname(args.out_plot), exist_ok=True) + fig, ax = plt.subplots(figsize=(7.5, 4.8), dpi=160) + ax.set_title(f"CiteSeer link-prediction • ARUC={aruc:.3f}", fontsize=14) + ax.grid(True, which='both', linestyle=':', linewidth=0.8, alpha=0.6) + ax.plot(ts, robustness, color="#ff0000", linewidth=2.0, label="Robustness (TPR)") + ax.plot(ts, uniqueness, color="#0000ff", linestyle="--", linewidth=2.0, label="Uniqueness (TNR)") + overlap = np.minimum(robustness, uniqueness) + ax.fill_between(ts, overlap, color="#bbbbbb", alpha=0.25, label="Overlap (ARUC region)") + + # best-threshold vertical line + # ax.axvline(t_best, color="0.4", linewidth=2.0, alpha=0.6) + + ax.set_xlabel("Threshold (τ)", fontsize=12) + ax.set_ylabel("Score", fontsize=12) + ax.set_xlim(0.0, 1.0) + ax.set_ylim(0.0, 1.0) + ax.tick_params(labelsize=11) + + leg = ax.legend(loc="lower left", frameon=True, framealpha=0.85, + facecolor="white", edgecolor="0.8") + + plt.tight_layout() + plt.savefig(args.out_plot, bbox_inches="tight") + print(f"Saved plot to {args.out_plot}") + + + # KL divergence + pos_divs = [model_nc_kl_to_target(m, target, fps) for m in models_pos[1:]] # exclude target itself + neg_divs = [model_nc_kl_to_target(m, target, fps) for m in models_neg] + pos_divs, neg_divs = np.array(pos_divs), np.array(neg_divs) + + print(f"[KL] F+ mean±std = 
{pos_divs.mean():.4f}±{pos_divs.std():.4f} | " + f"F- mean±std = {neg_divs.mean():.4f}±{neg_divs.std():.4f}") + + os.makedirs(os.path.dirname(args.out_plot_kl), exist_ok=True) + plt.figure(figsize=(4.8, 3.2), dpi=160) + bins = 30 + plt.hist(pos_divs, bins=bins, density=True, alpha=0.35, color="r", label="Surrogate GNN") + plt.hist(neg_divs, bins=bins, density=True, alpha=0.35, color="b", label="Irrelevant GNN") + plt.title("Node Classification") + plt.xlabel("KL Divergence"); plt.ylabel("Density") + plt.legend() + plt.tight_layout() + plt.savefig(args.out_plot_kl, bbox_inches="tight") + print(f"Saved KL plot to {args.out_plot_kl}") + + +if __name__ == "__main__": + main() diff --git a/examples/node_class/fine_tune_pirate_nc.py b/examples/node_class/fine_tune_pirate_nc.py new file mode 100644 index 0000000..fcb2cd4 --- /dev/null +++ b/examples/node_class/fine_tune_pirate_nc.py @@ -0,0 +1,116 @@ +import argparse, torch, copy, random, json +from pathlib import Path +import torch.nn.functional as F +from torch_geometric.datasets import Planetoid +from gcn_nc import get_model + +def set_seed(seed): + random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed) + +def make_masks(num_nodes, train_p=0.7, val_p=0.1, seed=0): + g = torch.Generator().manual_seed(seed) + idx = torch.randperm(num_nodes, generator=g) + n_train = int(train_p * num_nodes) + n_val = int(val_p * num_nodes) + train_idx = idx[:n_train]; val_idx = idx[n_train:n_train+n_val]; test_idx = idx[n_train+n_val:] + train_mask = torch.zeros(num_nodes, dtype=torch.bool); train_mask[train_idx]=True + val_mask = torch.zeros(num_nodes, dtype=torch.bool); val_mask[val_idx]=True + test_mask = torch.zeros(num_nodes, dtype=torch.bool); test_mask[test_idx]=True + return train_mask, val_mask, test_mask + +def train_epoch(model, data, optimizer, mask): + model.train(); optimizer.zero_grad() + out = model(data.x, data.edge_index) + loss = F.cross_entropy(out[mask], data.y[mask]) + loss.backward(); 
optimizer.step() + return float(loss.item()) + +@torch.no_grad() +def eval_mask(model, data, mask): + model.eval(); out = model(data.x, data.edge_index) + pred = out.argmax(dim=-1) + return float((pred[mask]==data.y[mask]).float().mean()) + +def reinit_last_layer(model): + last = None + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear): + last = module + if last is not None: + for p in last.parameters(): + if p.dim() > 1: torch.nn.init.xavier_uniform_(p) + else: torch.nn.init.zeros_(p) + +def reinit_all(model): + for m in model.modules(): + if isinstance(m, torch.nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + if m.bias is not None: torch.nn.init.zeros_(m.bias) + if hasattr(m, 'reset_parameters'): + try: m.reset_parameters() + except: pass + +def save_model(model, path, meta): + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + torch.save(model.state_dict(), str(path)) + with open(str(path).replace('.pt','.json'),'w') as f: + json.dump(meta, f) + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--target_path', type=str, default='models/target_model_nc.pt') + ap.add_argument('--meta_path', type=str, default='models/target_meta_nc.json') + ap.add_argument('--epochs', type=int, default=10) + ap.add_argument('--lr', type=float, default=0.01) + ap.add_argument('--wd', type=float, default=5e-4) + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--num_variants', type=int, default=100) + ap.add_argument('--out_dir', type=str, default='models/positives') + args = ap.parse_args() + + set_seed(args.seed) + with open(args.meta_path,'r') as f: + meta = json.load(f) + + dataset = Planetoid(root='data/cora', name='Cora') + data = dataset[0] + train_mask, val_mask, test_mask = make_masks(data.num_nodes, 0.7, 0.1, seed=args.seed) + data.train_mask, data.val_mask, data.test_mask = train_mask, val_mask, test_mask + + target = get_model(meta["arch"], data.num_features, meta["hidden"], 
meta["num_classes"]) + target.load_state_dict(torch.load(args.target_path, map_location='cpu')) + + saved = [] + for i in range(args.num_variants): + kind = i % 4 # 0:FT-last,1:FT-all,2:PR-last,3:PR-all + m = get_model(meta["arch"], data.num_features, meta["hidden"], meta["num_classes"]) + m.load_state_dict(target.state_dict()) + + if kind == 2: reinit_last_layer(m) + elif kind == 3: reinit_all(m) + + if kind in (0,2): + for name,p in m.named_parameters(): + p.requires_grad = ('conv2' in name) or ('mlp.3' in name) + else: + for p in m.parameters(): p.requires_grad=True + + opt = torch.optim.Adam(filter(lambda p: p.requires_grad, m.parameters()), lr=args.lr, weight_decay=args.wd) + best_val, best_state = -1, None + for _ in range(args.epochs): + _ = train_epoch(m, data, opt, data.train_mask) + val = eval_mask(m, data, data.val_mask) + if val > best_val: + best_val, best_state = val, {k:v.cpu().clone() for k,v in m.state_dict().items()} + m.load_state_dict(best_state) + out_path = f"{args.out_dir}/nc_ftpr_{i:03d}.pt" + meta_out = {"arch": meta["arch"], "hidden": meta["hidden"], "num_classes": meta["num_classes"], "pos_kind": ["ft_last","ft_all","pr_last","pr_all"][kind]} + save_model(m, out_path, meta_out) + saved.append(out_path) + print(f"Saved {out_path} val={best_val:.4f}") + + print(f"Total FT/PR positives saved: {len(saved)}") + +if __name__ == '__main__': + main() diff --git a/examples/node_class/fingerprint_generator_nc.py b/examples/node_class/fingerprint_generator_nc.py new file mode 100644 index 0000000..3a62216 --- /dev/null +++ b/examples/node_class/fingerprint_generator_nc.py @@ -0,0 +1,262 @@ +import argparse, glob, json, math, random, torch +import torch.nn as nn +import torch.nn.functional as F +from pathlib import Path +from typing import List, Dict +from torch_geometric.datasets import Planetoid +from torch_geometric.utils import dense_to_sparse, to_undirected +import time + + +def set_seed(s): + random.seed(s); torch.manual_seed(s) + +def 
load_meta(path): + with open(path, 'r') as f: + return json.load(f) + +def get_model(arch: str, in_dim: int, hidden: int, num_classes: int): + from gcn_nc import get_model as _get + return _get(arch, in_dim, hidden, num_classes) + +def list_paths_from_globs(globs: List[str]) -> List[str]: + out = [] + for g in globs: + out.extend(glob.glob(g)) + return sorted(out) + +class FPVerifier(nn.Module): + # Arch: [128, 64, 32] + LeakyReLU, sigmoid output + def __init__(self, in_dim: int): + super().__init__() + self.net = nn.Sequential( + nn.Linear(in_dim, 128), + nn.LeakyReLU(), + nn.Linear(128, 64), + nn.LeakyReLU(), + nn.Linear(64, 32), + nn.LeakyReLU(), + nn.Linear(32, 1), + nn.Sigmoid() + ) + + def forward(self, x): + return self.net(x) + +def load_model_from_pair(pt_path: str, in_dim: int): + meta = json.load(open(pt_path.replace('.pt', '.json'), 'r')) + m = get_model(meta["arch"], in_dim, meta["hidden"], meta["num_classes"]) + m.load_state_dict(torch.load(pt_path, map_location='cpu')) + m.eval() + return m, meta + +def forward_on_fp(model, fp): + A = fp["A"] + A_bin = (A > 0.5).float() + A_sym = torch.maximum(A_bin, A_bin.t()) + edge_index = dense_to_sparse(A_sym)[0] + if edge_index.numel() == 0: + edge_index = torch.arange(fp["X"].size(0)).repeat(2,1) + edge_index = to_undirected(edge_index) + logits = model(fp["X"], edge_index) + return logits.mean(dim=0) + +def concat_for_model(model, fingerprints): + vecs = [forward_on_fp(model, fp) for fp in fingerprints] + return torch.cat(vecs, dim=-1) + +def compute_loss(models_pos, models_neg, fingerprints, V): + z_pos = [] + for m in models_pos: + z_pos.append(concat_for_model(m, fingerprints)) + z_neg = [] + for m in models_neg: + z_neg.append(concat_for_model(m, fingerprints)) + if len(z_pos) == 0 or len(z_neg) == 0: + raise RuntimeError("Need both positive and negative models.") + Zp = torch.stack(z_pos) + Zn = torch.stack(z_neg) + + yp = V(Zp).clamp(1e-6, 1-1e-6) + yn = V(Zn).clamp(1e-6, 1-1e-6) + + L = 
torch.log(yp).mean() + torch.log(1 - yn).mean() + return L, Zp, Zn + +def feature_ascent_step(models_pos, models_neg, fingerprints, V, alpha=0.01): + for fp in fingerprints: + fp["X"].requires_grad_(True) + fp["A"].requires_grad_(False) + + L, _, _ = compute_loss(models_pos, models_neg, fingerprints, V) + grads = torch.autograd.grad( + L, [fp["X"] for fp in fingerprints], + retain_graph=False, create_graph=False, allow_unused=True + ) + with torch.no_grad(): + for fp, g in zip(fingerprints, grads): + if g is None: + g = torch.zeros_like(fp["X"]) + fp["X"].add_(alpha * g) + fp["X"].clamp_(-5.0, 5.0) + +def edge_flip_candidates(A: torch.Tensor, budget: int): + n = A.size(0) + tri_i, tri_j = torch.triu_indices(n, n, offset=1) + scores = torch.abs(0.5 - A[tri_i, tri_j]) + order = torch.argsort(scores) + picks = order[:min(budget, len(order))] + return tri_i[picks], tri_j[picks] + +def edge_flip_step(models_pos, models_neg, fingerprints, V, flip_k=8): + # Rank-and-flip edges by gain in the full loss L when flipping entries in ONE fp at a time + for fp_idx, fp in enumerate(fingerprints): + A = fp["A"] + i_idx, j_idx = edge_flip_candidates(A, flip_k * 4) # candidate pool + with torch.no_grad(): + base_L, _, _ = compute_loss(models_pos, models_neg, fingerprints, V) + + gains = [] + for i, j in zip(i_idx.tolist(), j_idx.tolist()): + with torch.no_grad(): + old = float(A[i, j]) + new = 1.0 - old + # toggle in place + A[i, j] = new; A[j, i] = new + L_try, _, _ = compute_loss(models_pos, models_neg, fingerprints, V) + gain = float(L_try - base_L) + gains.append((gain, i, j, old)) + # revert + A[i, j] = old; A[j, i] = old + + # Flip the best k edges for this fingerprint + gains.sort(key=lambda x: x[0], reverse=True) + with torch.no_grad(): + for g, i, j, old in gains[:flip_k]: + new = 1.0 - old + A[i, j] = new; A[j, i] = new + A.clamp_(0.0, 1.0) + +def train_verifier_step(models_pos, models_neg, fingerprints, V, opt): + L, Zp, Zn = compute_loss(models_pos, models_neg, 
fingerprints, V) + loss = -L + opt.zero_grad() + loss.backward() + opt.step() + with torch.no_grad(): + yp = (V(Zp) >= 0.5).float().mean().item() + yn = (V(Zn) < 0.5).float().mean().item() + acc = 0.5 * (yp + yn) + return float(L.item()), acc + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--target_path', default='models/target_model_nc.pt') + ap.add_argument('--target_meta', default='models/target_meta_nc.json') + ap.add_argument('--positives_glob', default='models/positives/nc_ftpr_*.pt,models/positives/distill_nc_*.pt') + ap.add_argument('--negatives_glob', default='models/negatives/negative_nc_*.pt') + + # Hyperparams + ap.add_argument('--P', type=int, default=64) + ap.add_argument('--n', type=int, default=32) # nodes per fingerprint + ap.add_argument('--iters', type=int, default=1000) # alternating iterations + ap.add_argument('--verifier_lr', type=float, default=1e-3) # learning rate for V + ap.add_argument('--e1', type=int, default=1) # epochs for fingerprint updates per alternation + ap.add_argument('--e2', type=int, default=1) # epochs for verifier updates per alternation + ap.add_argument('--alpha_x', type=float, default=0.01) # step size for feature ascent + ap.add_argument('--flip_k', type=int, default=4) # edges flipped per step per fingerprint + ap.add_argument('--seed', type=int, default=0) + ap.add_argument('--m', type=int, default=64) # sampled nodes per fingerprint + + + args = ap.parse_args() + + set_seed(args.seed) + Path('fingerprints').mkdir(parents=True, exist_ok=True) + + # Dataset dims + ds = Planetoid(root='data/cora', name='Cora') + in_dim = ds.num_features + num_classes = ds.num_classes + + # Load {f} and F+ into "positives"; F- separately + meta_t = load_meta(args.target_meta) + target = get_model(meta_t["arch"], in_dim, meta_t["hidden"], meta_t["num_classes"]) + target.load_state_dict(torch.load(args.target_path, map_location='cpu')) + target.eval() + + pos_globs = [g.strip() for g in args.positives_glob.split(',') 
if g.strip()] + pos_paths = list_paths_from_globs(pos_globs) + neg_paths = sorted(glob.glob(args.negatives_glob)) + + models_pos = [target] + for p in pos_paths: + m,_ = load_model_from_pair(p, in_dim) + models_pos.append(m) + + models_neg = [] + for npath in neg_paths: + m,_ = load_model_from_pair(npath, in_dim) + models_neg.append(m) + + print(f"[loaded] positives={len(models_pos)} (incl. target) | negatives={len(models_neg)}") + + # Initialize fingerprints with small random X, sparse A near 0.5 + fingerprints = [] + if args.m > args.n: + raise ValueError(f"--m ({args.m}) must be <= --n ({args.n})") + + for _ in range(args.P): + X = torch.randn(args.n, in_dim) * 0.1 + A = torch.rand(args.n, args.n) * 0.2 + 0.4 + A = torch.triu(A, diagonal=1) + A = A + A.t() + torch.diagonal(A).zero_() + idx = torch.randperm(args.n)[:args.m] + fingerprints.append({"X": X, "A": A, "node_idx": idx}) + + + ver_in_dim = args.P * args.m * num_classes + V = FPVerifier(ver_in_dim) + optV = torch.optim.Adam(V.parameters(), lr=args.verifier_lr) + + flag = 0 + for it in range(1, args.iters + 1): + if flag == 0: + # Update fingerprints (features + edges), e1 times + for _ in range(args.e1): + feature_ascent_step(models_pos, models_neg, fingerprints, V, alpha=args.alpha_x) + edge_flip_step(models_pos, models_neg, fingerprints, V, flip_k=args.flip_k) + flag = 1 + else: + # Update verifier, e2 times + diag_acc = None + for _ in range(args.e2): + Lval, acc = train_verifier_step(models_pos, models_neg, fingerprints, V, optV) + diag_acc = acc + flag = 0 + + if it % 10 == 0 and 'diag_acc' in locals() and diag_acc is not None: + print(f"[Iter {it}] verifier acc={diag_acc:.3f} (diagnostic)") + + clean_fps = [] + for fp in fingerprints: + clean_fps.append({ + "X": fp["X"].detach().clone(), + "A": fp["A"].detach().clone(), + "node_idx": fp["node_idx"].detach().clone(), + }) + torch.save( + {"fingerprints": clean_fps, "verifier": V.state_dict(), "ver_in_dim": ver_in_dim}, + 
"fingerprints/fingerprints_nc.pt" + ) + + print("Saved fingerprints/fingerprints_nc.pt") + +if __name__ == '__main__': + start_time = time.time() + main() + end_time = time.time() + + print("Time taken: ", (end_time - start_time)/60) + diff --git a/examples/node_class/gcn_nc.py b/examples/node_class/gcn_nc.py new file mode 100644 index 0000000..0142be8 --- /dev/null +++ b/examples/node_class/gcn_nc.py @@ -0,0 +1,85 @@ + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.nn import GCNConv, SAGEConv, GATConv, global_mean_pool + +# Common heads +class MLPHead(nn.Module): + def __init__(self, in_dim, out_dim, hidden=64, dropout=0.5): + super().__init__() + self.net = nn.Sequential( + nn.Linear(in_dim, hidden), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden, out_dim) + ) + def forward(self, x): + return self.net(x) + +class GCN(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5): + super().__init__() + self.conv1 = GCNConv(in_channels, hidden_channels, cached=False, add_self_loops=True, normalize=True) + self.conv2 = GCNConv(hidden_channels, out_channels, cached=False, add_self_loops=True, normalize=True) + self.dropout = dropout + + def forward(self, x, edge_index): + x = self.conv1(x, edge_index) + x = F.relu(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.conv2(x, edge_index) + return x # logits for node classes + +class GraphSAGE(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5): + super().__init__() + self.conv1 = SAGEConv(in_channels, hidden_channels) + self.conv2 = SAGEConv(hidden_channels, out_channels) + self.dropout = dropout + + def forward(self, x, edge_index): + x = self.conv1(x, edge_index) + x = F.relu(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.conv2(x, edge_index) + return x + +class GAT(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, heads=8, 
dropout=0.6): + super().__init__() + self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout) + self.conv2 = GATConv(hidden_channels*heads, out_channels, heads=1, concat=False, dropout=dropout) + self.dropout = dropout + + def forward(self, x, edge_index): + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.conv1(x, edge_index) + x = F.elu(x) + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.conv2(x, edge_index) + return x + +class NodeMLP(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, dropout=0.5): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(in_channels, hidden_channels), + nn.ReLU(), + nn.Dropout(dropout), + nn.Linear(hidden_channels, out_channels) + ) + def forward(self, x, edge_index): + return self.mlp(x) + +def get_model(arch:str, in_dim:int, hidden:int, out_dim:int): + arch = arch.lower() + if arch == 'gcn': + return GCN(in_dim, hidden, out_dim) + if arch == 'sage' or arch == 'graphsage': + return GraphSAGE(in_dim, hidden, out_dim) + if arch == 'gat': + return GAT(in_dim, hidden, out_dim) + if arch == 'mlp': + return NodeMLP(in_dim, hidden, out_dim) + raise ValueError(f"Unknown arch: {arch}") diff --git a/examples/node_class/generate_univerifier_dataset_nc.py b/examples/node_class/generate_univerifier_dataset_nc.py new file mode 100644 index 0000000..4d6f269 --- /dev/null +++ b/examples/node_class/generate_univerifier_dataset_nc.py @@ -0,0 +1,108 @@ +""" +Build a Univerifier dataset from saved fingerprints. +Label 1 for positives ({target ∪ F+}) and 0 for negatives (F−). 
+Outputs: a .pt file with: + - X: [N_models, D] tensor, where D = P * m * num_classes + - y: [N_models] float tensor with 1.0 (positive) or 0.0 (negative) +""" + +import argparse, glob, json, torch +from torch_geometric.datasets import Planetoid +from torch_geometric.utils import dense_to_sparse +from gcn_nc import get_model + +@torch.no_grad() +def forward_on_fp(model, fp): + X = fp["X"] + A = fp["A"] + idx = fp["node_idx"] + + # Binarize & symmetrize adjacency + A_bin = (A > 0.5).float() + A_sym = torch.triu(A_bin, diagonal=1) + A_sym = A_sym + A_sym.t() + edge_index = dense_to_sparse(A_sym)[0] + + logits = model(X, edge_index) + sel = logits[idx, :] + return sel.reshape(-1) + + +@torch.no_grad() +def concat_for_model(model, fps): + parts = [forward_on_fp(model, fp) for fp in fps] + return torch.cat(parts, dim=0) + + +def list_paths_from_globs(globs_str): + globs = [g.strip() for g in globs_str.split(",") if g.strip()] + paths = [] + for g in globs: + paths.extend(glob.glob(g)) + return sorted(paths) + + +def load_model_from_pt(pt_path, in_dim): + meta_path = pt_path.replace(".pt", ".json") + j = json.load(open(meta_path, "r")) + m = get_model(j["arch"], in_dim, j["hidden"], j["num_classes"]) + m.load_state_dict(torch.load(pt_path, map_location="cpu")) + m.eval() + return m + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--fingerprints_path", type=str, default="fingerprints/fingerprints_nc.pt") + ap.add_argument("--target_path", type=str, default="models/target_model_nc.pt") + ap.add_argument("--target_meta", type=str, default="models/target_meta_nc.json") + ap.add_argument("--positives_glob", type=str, + default="models/positives/nc_ftpr_*.pt,models/positives/distill_nc_*.pt") + ap.add_argument("--negatives_glob", type=str, default="models/negatives/negative_nc_*.pt") + ap.add_argument("--out", type=str, default="fingerprints/univerifier_dataset_nc.pt") + args = ap.parse_args() + + # Dataset dims (for model reconstruction) + ds = 
Planetoid(root="data/cora", name="Cora") + in_dim = ds.num_features + num_classes = ds.num_classes + + pack = torch.load(args.fingerprints_path, map_location="cpu") + fps = pack["fingerprints"] + ver_in_dim_saved = pack.get("ver_in_dim", None) + + + tmeta = json.load(open(args.target_meta, "r")) + target = get_model(tmeta["arch"], in_dim, tmeta["hidden"], tmeta["num_classes"]) + target.load_state_dict(torch.load(args.target_path, map_location="cpu")) + target.eval() + + pos_paths = list_paths_from_globs(args.positives_glob) + neg_paths = sorted(glob.glob(args.negatives_glob)) + + models = [target] + labels = [1.0] + + for p in pos_paths: + models.append(load_model_from_pt(p, in_dim)); labels.append(1.0) + for n in neg_paths: + models.append(load_model_from_pt(n, in_dim)); labels.append(0.0) + + with torch.no_grad(): + z0 = concat_for_model(models[0], fps) + D = z0.numel() + if ver_in_dim_saved is not None and D != int(ver_in_dim_saved): + raise RuntimeError( + f"Verifier input mismatch: dataset dim {D} vs saved ver_in_dim {ver_in_dim_saved}" + ) + + X_rows = [z0] + [concat_for_model(m, fps) for m in models[1:]] + X = torch.stack(X_rows, dim=0).float() + y = torch.tensor(labels, dtype=torch.float32) + + torch.save({"X": X, "y": y}, args.out) + print(f"Saved {args.out} with {X.shape[0]} rows; dim={X.shape[1]} (num_classes={num_classes})") + print(f"Positives: {int(sum(labels))} | Negatives: {len(labels) - int(sum(labels))}") + +if __name__ == "__main__": + main() diff --git a/examples/node_class/make_suspect_nc.py b/examples/node_class/make_suspect_nc.py new file mode 100644 index 0000000..cabcbeb --- /dev/null +++ b/examples/node_class/make_suspect_nc.py @@ -0,0 +1,74 @@ +import argparse, json, random, torch +import torch.nn.functional as F +from pathlib import Path +from torch_geometric.datasets import Planetoid +from gcn_nc import get_model + +def set_seed(s): + random.seed(s); torch.manual_seed(s); torch.cuda.manual_seed_all(s) + +def make_masks(n, train_p=0.7, 
val_p=0.1, seed=0): + g = torch.Generator().manual_seed(seed) + idx = torch.randperm(n, generator=g) + n_tr = int(train_p*n); n_va = int(val_p*n) + tr, va, te = idx[:n_tr], idx[n_tr:n_tr+n_va], idx[n_tr+n_va:] + mtr = torch.zeros(n, dtype=torch.bool); mtr[tr]=True + mva = torch.zeros(n, dtype=torch.bool); mva[va]=True + mte = torch.zeros(n, dtype=torch.bool); mte[te]=True + return mtr, mva, mte + +def train_epoch(model, data, opt, mask): + model.train(); opt.zero_grad() + out = model(data.x, data.edge_index) + loss = F.cross_entropy(out[mask], data.y[mask]); loss.backward(); opt.step() + return float(loss.item()) + +@torch.no_grad() +def eval_mask(model, data, mask): + model.eval(); out = model(data.x, data.edge_index) + pred = out.argmax(dim=-1) + return float((pred[mask]==data.y[mask]).float().mean()) + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--arch', default='sage', help='gcn or sage (unrelated to target)') + ap.add_argument('--hidden', type=int, default=64) + ap.add_argument('--epochs', type=int, default=200) + ap.add_argument('--lr', type=float, default=0.01) + ap.add_argument('--wd', type=float, default=5e-4) + ap.add_argument('--seed', type=int, default=9999, help='use a NEW seed unseen by the Univerifier') + ap.add_argument('--out_dir', default='models/suspects') + ap.add_argument('--name', default='neg_nc_seed9999') + args = ap.parse_args() + + set_seed(args.seed) + ds = Planetoid(root='data/cora', name='Cora') + data = ds[0] + mtr, mva, mte = make_masks(data.num_nodes, 0.7, 0.1, seed=args.seed) + data.train_mask, data.val_mask, data.test_mask = mtr, mva, mte + + model = get_model(args.arch, ds.num_features, args.hidden, ds.num_classes) + opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd) + + best_val, best_state = -1.0, None + for _ in range(args.epochs): + _ = train_epoch(model, data, opt, data.train_mask) + val = eval_mask(model, data, data.val_mask) + if val > best_val: + best_val = val + best_state = 
{k: v.cpu().clone() for k, v in model.state_dict().items()} + model.load_state_dict(best_state) + + Path(args.out_dir).mkdir(parents=True, exist_ok=True) + pt = f"{args.out_dir}/{args.name}.pt" + meta = { + "arch": args.arch, "hidden": args.hidden, + "in_dim": ds.num_features, "num_classes": ds.num_classes, + "seed": args.seed, "note": "never-seen negative suspect" + } + torch.save(model.state_dict(), pt) + with open(pt.replace('.pt','.json'), 'w') as f: json.dump(meta, f) + print(f"[saved] {pt} (val_acc={best_val:.4f})") + +if __name__ == '__main__': + main() diff --git a/examples/node_class/score_suspect_nc.py b/examples/node_class/score_suspect_nc.py new file mode 100644 index 0000000..ca0005f --- /dev/null +++ b/examples/node_class/score_suspect_nc.py @@ -0,0 +1,87 @@ +import argparse, json, os, torch +import torch.nn as nn +from torch_geometric.utils import dense_to_sparse + +# Univerifier +class FPVerifier(nn.Module): + def __init__(self, in_dim: int): + super().__init__() + self.net = nn.Sequential( + nn.Linear(in_dim, 128), nn.LeakyReLU(), + nn.Linear(128, 64), nn.LeakyReLU(), + nn.Linear(64, 32), nn.LeakyReLU(), + nn.Linear(32, 1), nn.Sigmoid(), + ) + def forward(self, x): return self.net(x) + +def load_json(p): + with open(p, "r") as f: return json.load(f) + +def edge_index_from_A(A: torch.Tensor) -> torch.Tensor: + A_bin = (A > 0.5).float() + A_sym = torch.triu(A_bin, diagonal=1); A_sym = A_sym + A_sym.t() + ei = dense_to_sparse(A_sym)[0] + if ei.numel() == 0: + n = A.size(0); ei = torch.arange(n).repeat(2,1) + return ei + +@torch.no_grad() +def build_z_nc(model: nn.Module, fps): + parts = [] + for fp in fps: + X, A, idx = fp["X"], fp["A"], fp["node_idx"] + ei = edge_index_from_A(A) + logits = model(X, ei) + parts.append(logits[idx, :].reshape(-1)) + return torch.cat(parts, dim=0) + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--fingerprints_path", default="fingerprints/fingerprints_nc.pt") + ap.add_argument("--verifier_path", 
default="fingerprints/univerifier_nc.pt", + help="If missing, load 'verifier' from fingerprints pack.") + ap.add_argument("--suspect_pt", required=True) + ap.add_argument("--suspect_meta", required=False, default="") + ap.add_argument("--threshold", type=float, default=0.5) + ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") + ap.add_argument("--in_dim", type=int, default=1433) + ap.add_argument("--num_classes", type=int, default=7) + args = ap.parse_args() + + device = torch.device(args.device) + pack = torch.load(args.fingerprints_path, map_location="cpu") + fps = pack["fingerprints"]; ver_in_dim = int(pack.get("ver_in_dim", 0)) + + # Build suspect NC model + from gcn_nc import get_model + meta = load_json(args.suspect_meta) if args.suspect_meta else {} + arch = meta.get("arch", "gcn") + hidden = int(meta.get("hidden", 64)) + in_dim = int(meta.get("in_dim", args.in_dim)) + num_classes = int(meta.get("num_classes", args.num_classes)) + model = get_model(arch, in_dim, hidden, num_classes).to(device) + model.load_state_dict(torch.load(args.suspect_pt, map_location="cpu")) + model.eval() + + z = build_z_nc(model, fps) + D = z.numel() + if ver_in_dim and ver_in_dim != D: + raise RuntimeError(f"Dim mismatch: verifier expects {ver_in_dim}, got {D}.") + + # Load verifier + V = FPVerifier(D).to(device) + if os.path.isfile(args.verifier_path): + V.load_state_dict(torch.load(args.verifier_path, map_location="cpu")) + src = args.verifier_path + else: + if "verifier" not in pack: raise FileNotFoundError("No verifier found.") + V.load_state_dict(pack["verifier"]); src = f"{args.fingerprints_path}:[verifier]" + V.eval() + + with torch.no_grad(): + s = float(V(z.view(1, -1).to(device)).item()) + verdict = "OWNED (positive)" if s >= args.threshold else "NOT-OWNED (negative)" + print(f"Score={s:.6f} | τ={args.threshold:.3f} -> {verdict}") + +if __name__ == "__main__": + main() diff --git a/examples/node_class/train_nc.py 
b/examples/node_class/train_nc.py new file mode 100644 index 0000000..74529ed --- /dev/null +++ b/examples/node_class/train_nc.py @@ -0,0 +1,79 @@ +import argparse, torch, random +import torch.nn.functional as F +from torch_geometric.datasets import Planetoid +from torch_geometric.utils import add_self_loops +from torch_geometric.loader import NeighborLoader +from gcn_nc import get_model + +def set_seed(seed): + random.seed(seed); + torch.manual_seed(seed); + torch.cuda.manual_seed_all(seed) + +def make_masks(num_nodes, train_p=0.7, val_p=0.1, seed=0): + g = torch.Generator().manual_seed(seed) + idx = torch.randperm(num_nodes, generator=g) + n_train = int(train_p * num_nodes) + n_val = int(val_p * num_nodes) + train_idx = idx[:n_train]; val_idx = idx[n_train:n_train+n_val]; test_idx = idx[n_train+n_val:] + train_mask = torch.zeros(num_nodes, dtype=torch.bool); train_mask[train_idx]=True + val_mask = torch.zeros(num_nodes, dtype=torch.bool); val_mask[val_idx]=True + test_mask = torch.zeros(num_nodes, dtype=torch.bool); test_mask[test_idx]=True + return train_mask, val_mask, test_mask + +def train_epoch(model, data, optimizer, train_mask): + model.train() + optimizer.zero_grad() + out = model(data.x, data.edge_index) + loss = F.cross_entropy(out[train_mask], data.y[train_mask]) + loss.backward() + optimizer.step() + return float(loss.item()) + +@torch.no_grad() +def eval_masks(model, data, mask): + model.eval() + out = model(data.x, data.edge_index) + pred = out.argmax(dim=-1) + correct = int((pred[mask] == data.y[mask]).sum()) + total = int(mask.sum()) + return correct/total if total>0 else 0.0 + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--arch', type=str, default='gcn') + ap.add_argument('--hidden', type=int, default=64) + ap.add_argument('--epochs', type=int, default=200) + ap.add_argument('--lr', type=float, default=0.01) + ap.add_argument('--weight_decay', type=float, default=5e-4) + ap.add_argument('--seed', type=int, default=0) + args = 
ap.parse_args()
+
+    set_seed(args.seed)
+    dataset = Planetoid(root='data/cora', name='Cora')
+    data = dataset[0]
+
+    train_mask, val_mask, test_mask = make_masks(data.num_nodes, 0.7, 0.1, seed=args.seed)
+    data.train_mask, data.val_mask, data.test_mask = train_mask, val_mask, test_mask
+
+    model = get_model(args.arch, data.num_features, args.hidden, dataset.num_classes)
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+
+    best_val, best_state = -1, None
+    for epoch in range(1, args.epochs+1):
+        loss = train_epoch(model, data, optimizer, data.train_mask)
+        val_acc = eval_masks(model, data, data.val_mask)
+        if val_acc > best_val:
+            best_val, best_state = val_acc, {k:v.cpu().clone() for k,v in model.state_dict().items()}
+        if epoch % 20 == 0 or epoch == args.epochs:
+            print(f"Epoch {epoch:03d} | loss {loss:.4f} | val {val_acc:.4f}")
+
+    model.load_state_dict(best_state)
+    test_acc = eval_masks(model, data, data.test_mask)
+    print(f"Best Val Acc: {best_val:.4f} | Test Acc: {test_acc:.4f}")
+    torch.save(model.state_dict(), 'models/target_model_nc.pt')
+    with open('models/target_meta_nc.json','w') as f:
+        f.write(f'{{"arch":"{args.arch}","hidden":{args.hidden},"num_classes":{dataset.num_classes}}}')
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/node_class/train_univerifier_nc.py b/examples/node_class/train_univerifier_nc.py
new file mode 100644
index 0000000..989bb66
--- /dev/null
+++ b/examples/node_class/train_univerifier_nc.py
@@ -0,0 +1,98 @@
+"""
+Trains the Univerifier on features built from fingerprints (MLP: [128,64,32] + LeakyReLU).
+Loads X,y from generate_univerifier_dataset_nc.py and saves weights + a tiny meta JSON.
+""" + +import argparse, json, torch, time +import torch.nn as nn +import torch.nn.functional as F +from pathlib import Path + +class FPVerifier(nn.Module): + def __init__(self, in_dim: int): + super().__init__() + self.net = nn.Sequential( + nn.Linear(in_dim, 128), + nn.LeakyReLU(), + nn.Linear(128, 64), + nn.LeakyReLU(), + nn.Linear(64, 32), + nn.LeakyReLU(), + nn.Linear(32, 1), + nn.Sigmoid(), + ) + def forward(self, x): + return self.net(x) + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--dataset', type=str, default='fingerprints/univerifier_dataset_nc.pt') + ap.add_argument('--epochs', type=int, default=200) + ap.add_argument('--lr', type=float, default=1e-3) + ap.add_argument('--weight_decay', type=float, default=0.0) + ap.add_argument('--val_split', type=float, default=0.2) + ap.add_argument('--fingerprints_path', type=str, default='fingerprints/fingerprints_nc.pt') + ap.add_argument('--out', type=str, default='fingerprints/univerifier_nc.pt') + args = ap.parse_args() + + # Load dataset + pack = torch.load(args.dataset, map_location='cpu') + X = pack['X'].float().detach() + y = pack['y'].float().view(-1, 1).detach() + N, D = X.shape + + try: + fp_pack = torch.load(args.fingerprints_path, map_location='cpu') + ver_in_dim = int(fp_pack.get('ver_in_dim', D)) + if ver_in_dim != D: + raise RuntimeError(f'Input dim mismatch: dataset dim {D} vs ver_in_dim {ver_in_dim}') + except FileNotFoundError: + pass + + # Train/val split + n_val = max(1, int(args.val_split * N)) + perm = torch.randperm(N) + idx_tr, idx_val = perm[:-n_val], perm[-n_val:] + X_tr, y_tr = X[idx_tr], y[idx_tr] + X_val, y_val = X[idx_val], y[idx_val] + + # Model/optim + V = FPVerifier(D) + opt = torch.optim.Adam(V.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + best_acc, best_state = 0.0, None + for ep in range(1, args.epochs + 1): + V.train(); opt.zero_grad() + p = V(X_tr) + loss = F.binary_cross_entropy(p, y_tr) + loss.backward(); opt.step() + + with 
torch.no_grad():
+            V.eval()
+            pv = V(X_val)
+            val_loss = F.binary_cross_entropy(pv, y_val)
+            val_acc = ((pv >= 0.5).float() == y_val).float().mean().item()
+
+        if val_acc > best_acc:
+            best_acc = val_acc
+            best_state = {k: v.cpu().clone() for k, v in V.state_dict().items()}
+
+        if ep % 20 == 0 or ep == args.epochs:
+            print(f'Epoch {ep:03d} | train_bce {loss.item():.4f} '
+                  f'| val_bce {val_loss.item():.4f} | val_acc {val_acc:.4f}')
+
+    # Save best
+    if best_state is None:
+        best_state = V.state_dict()
+    Path('fingerprints').mkdir(exist_ok=True, parents=True)
+    torch.save(best_state, args.out)
+    with open(args.out.replace('.pt', '_meta.json'), 'w') as f:
+        json.dump({'in_dim': D, 'hidden': [128, 64, 32], 'act': 'LeakyReLU'}, f)
+    print(f'Saved {args.out} | Best Val Acc {best_acc:.4f} | Input dim D={D}')
+
+if __name__ == '__main__':
+    start_time = time.time()
+    main()
+    end_time = time.time()
+    print("time taken: ", (end_time-start_time)/60 )
+
diff --git a/examples/node_class/train_unrelated_nc.py b/examples/node_class/train_unrelated_nc.py
new file mode 100644
index 0000000..157b72b
--- /dev/null
+++ b/examples/node_class/train_unrelated_nc.py
@@ -0,0 +1,88 @@
+
+"""
+Negative models: different random seeds and/or architectures trained from scratch, each with its own seed-dependent train split.
+""" +import argparse, torch, random, json +from pathlib import Path +import torch.nn.functional as F +from torch_geometric.datasets import Planetoid +from gcn_nc import get_model + +def set_seed(seed): + random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed) + +def make_masks(num_nodes, train_p=0.7, val_p=0.1, seed=0): + g = torch.Generator().manual_seed(seed) + idx = torch.randperm(num_nodes, generator=g) + n_train = int(train_p * num_nodes) + n_val = int(val_p * num_nodes) + train_idx = idx[:n_train]; val_idx = idx[n_train:n_train+n_val]; test_idx = idx[n_train+n_val:] + train_mask = torch.zeros(num_nodes, dtype=torch.bool); train_mask[train_idx]=True + val_mask = torch.zeros(num_nodes, dtype=torch.bool); val_mask[val_idx]=True + test_mask = torch.zeros(num_nodes, dtype=torch.bool); test_mask[test_idx]=True + return train_mask, val_mask, test_mask + +def train_epoch(model, data, optimizer, mask): + model.train(); optimizer.zero_grad() + out = model(data.x, data.edge_index) + loss = F.cross_entropy(out[mask], data.y[mask]) + loss.backward(); optimizer.step() + return float(loss.item()) + +@torch.no_grad() +def eval_mask(model, data, mask): + model.eval(); out = model(data.x, data.edge_index) + pred = out.argmax(dim=-1) + return float((pred[mask]==data.y[mask]).float().mean()) + +def save_model(model, path, meta): + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) # <-- ensure folder exists + torch.save(model.state_dict(), str(path)) + with open(str(path).replace('.pt', '.json'), 'w') as f: + json.dump(meta, f) + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--count', type=int, default=50) + ap.add_argument('--archs', type=str, default='gcn,sage') + ap.add_argument('--epochs', type=int, default=200) + ap.add_argument('--lr', type=float, default=0.01) + ap.add_argument('--wd', type=float, default=5e-4) + ap.add_argument('--seed', type=int, default=123) + ap.add_argument('--out_dir', type=str, 
default='models/negatives') # <-- where to save + args = ap.parse_args() + + dataset = Planetoid(root='data/cora', name='Cora') + data = dataset[0] + + saved = [] + arch_list = args.archs.split(',') + + for i in range(args.count): + seed_i = args.seed + i + set_seed(seed_i) + train_mask, val_mask, test_mask = make_masks(data.num_nodes, 0.7, 0.1, seed=seed_i) + data.train_mask, data.val_mask, data.test_mask = train_mask, val_mask, test_mask + + arch = arch_list[i % len(arch_list)] + model = get_model(arch, data.num_features, 64, dataset.num_classes) + opt = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd) + + best_val, best_state = -1, None + for ep in range(args.epochs): + loss = train_epoch(model, data, opt, data.train_mask) + val = eval_mask(model, data, data.val_mask) + if val > best_val: + best_val, best_state = val, {k:v.cpu().clone() for k,v in model.state_dict().items()} + model.load_state_dict(best_state) + + out_path = Path(args.out_dir) / f"negative_nc_{i:03d}.pt" + meta = {"arch": arch, "hidden": 64, "num_classes": dataset.num_classes, "seed": seed_i} + save_model(model, out_path, meta) + + saved.append(str(out_path)) + print(f"Saved negative {i} arch={arch} val={best_val:.4f} -> {out_path}") + +if __name__ == '__main__': + main() diff --git a/examples/plots/citeseer_lp_aruc.png b/examples/plots/citeseer_lp_aruc.png new file mode 100644 index 0000000..f7f2e74 Binary files /dev/null and b/examples/plots/citeseer_lp_aruc.png differ diff --git a/examples/plots/cora_nc_aruc.csv b/examples/plots/cora_nc_aruc.csv new file mode 100644 index 0000000..f1a17ec --- /dev/null +++ b/examples/plots/cora_nc_aruc.csv @@ -0,0 +1,202 @@ +threshold,robustness,uniqueness,min_curve,accuracy +0.00000,1.000000,0.000000,0.000000,0.501247 +0.00500,1.000000,0.845000,0.845000,0.922693 +0.01000,1.000000,0.875000,0.875000,0.937656 +0.01500,1.000000,0.890000,0.890000,0.945137 +0.02000,1.000000,0.890000,0.890000,0.945137 
+0.02500,1.000000,0.895000,0.895000,0.947631 +0.03000,1.000000,0.905000,0.905000,0.952618 +0.03500,1.000000,0.920000,0.920000,0.960100 +0.04000,1.000000,0.920000,0.920000,0.960100 +0.04500,1.000000,0.920000,0.920000,0.960100 +0.05000,1.000000,0.920000,0.920000,0.960100 +0.05500,1.000000,0.925000,0.925000,0.962594 +0.06000,1.000000,0.930000,0.930000,0.965087 +0.06500,1.000000,0.930000,0.930000,0.965087 +0.07000,1.000000,0.935000,0.935000,0.967581 +0.07500,1.000000,0.935000,0.935000,0.967581 +0.08000,1.000000,0.935000,0.935000,0.967581 +0.08500,1.000000,0.935000,0.935000,0.967581 +0.09000,1.000000,0.935000,0.935000,0.967581 +0.09500,1.000000,0.935000,0.935000,0.967581 +0.10000,1.000000,0.935000,0.935000,0.967581 +0.10500,1.000000,0.940000,0.940000,0.970075 +0.11000,1.000000,0.940000,0.940000,0.970075 +0.11500,1.000000,0.945000,0.945000,0.972569 +0.12000,1.000000,0.945000,0.945000,0.972569 +0.12500,1.000000,0.945000,0.945000,0.972569 +0.13000,1.000000,0.945000,0.945000,0.972569 +0.13500,1.000000,0.945000,0.945000,0.972569 +0.14000,1.000000,0.945000,0.945000,0.972569 +0.14500,0.995025,0.945000,0.945000,0.970075 +0.15000,0.995025,0.950000,0.950000,0.972569 +0.15500,0.995025,0.955000,0.955000,0.975062 +0.16000,0.995025,0.955000,0.955000,0.975062 +0.16500,0.995025,0.955000,0.955000,0.975062 +0.17000,0.995025,0.955000,0.955000,0.975062 +0.17500,0.995025,0.955000,0.955000,0.975062 +0.18000,0.995025,0.960000,0.960000,0.977556 +0.18500,0.995025,0.965000,0.965000,0.980050 +0.19000,0.995025,0.965000,0.965000,0.980050 +0.19500,0.995025,0.970000,0.970000,0.982544 +0.20000,0.995025,0.970000,0.970000,0.982544 +0.20500,0.995025,0.970000,0.970000,0.982544 +0.21000,0.995025,0.970000,0.970000,0.982544 +0.21500,0.995025,0.970000,0.970000,0.982544 +0.22000,0.995025,0.975000,0.975000,0.985037 +0.22500,0.995025,0.975000,0.975000,0.985037 +0.23000,0.995025,0.975000,0.975000,0.985037 +0.23500,0.995025,0.975000,0.975000,0.985037 +0.24000,0.995025,0.975000,0.975000,0.985037 
+0.24500,0.995025,0.975000,0.975000,0.985037 +0.25000,0.995025,0.975000,0.975000,0.985037 +0.25500,0.995025,0.975000,0.975000,0.985037 +0.26000,0.995025,0.980000,0.980000,0.987531 +0.26500,0.995025,0.980000,0.980000,0.987531 +0.27000,0.995025,0.980000,0.980000,0.987531 +0.27500,0.995025,0.980000,0.980000,0.987531 +0.28000,0.995025,0.980000,0.980000,0.987531 +0.28500,0.995025,0.980000,0.980000,0.987531 +0.29000,0.995025,0.980000,0.980000,0.987531 +0.29500,0.995025,0.980000,0.980000,0.987531 +0.30000,0.995025,0.980000,0.980000,0.987531 +0.30500,0.995025,0.980000,0.980000,0.987531 +0.31000,0.995025,0.980000,0.980000,0.987531 +0.31500,0.995025,0.980000,0.980000,0.987531 +0.32000,0.995025,0.980000,0.980000,0.987531 +0.32500,0.995025,0.980000,0.980000,0.987531 +0.33000,0.995025,0.980000,0.980000,0.987531 +0.33500,0.995025,0.980000,0.980000,0.987531 +0.34000,0.995025,0.980000,0.980000,0.987531 +0.34500,0.995025,0.980000,0.980000,0.987531 +0.35000,0.995025,0.980000,0.980000,0.987531 +0.35500,0.995025,0.980000,0.980000,0.987531 +0.36000,0.995025,0.980000,0.980000,0.987531 +0.36500,0.995025,0.980000,0.980000,0.987531 +0.37000,0.995025,0.980000,0.980000,0.987531 +0.37500,0.995025,0.980000,0.980000,0.987531 +0.38000,0.995025,0.980000,0.980000,0.987531 +0.38500,0.995025,0.980000,0.980000,0.987531 +0.39000,0.995025,0.980000,0.980000,0.987531 +0.39500,0.995025,0.980000,0.980000,0.987531 +0.40000,0.995025,0.980000,0.980000,0.987531 +0.40500,0.995025,0.980000,0.980000,0.987531 +0.41000,0.995025,0.980000,0.980000,0.987531 +0.41500,0.990050,0.980000,0.980000,0.985037 +0.42000,0.990050,0.980000,0.980000,0.985037 +0.42500,0.990050,0.980000,0.980000,0.985037 +0.43000,0.990050,0.980000,0.980000,0.985037 +0.43500,0.990050,0.980000,0.980000,0.985037 +0.44000,0.990050,0.980000,0.980000,0.985037 +0.44500,0.990050,0.980000,0.980000,0.985037 +0.45000,0.990050,0.980000,0.980000,0.985037 +0.45500,0.990050,0.980000,0.980000,0.985037 +0.46000,0.990050,0.980000,0.980000,0.985037 
+0.46500,0.990050,0.985000,0.985000,0.987531 +0.47000,0.990050,0.985000,0.985000,0.987531 +0.47500,0.990050,0.985000,0.985000,0.987531 +0.48000,0.990050,0.985000,0.985000,0.987531 +0.48500,0.990050,0.985000,0.985000,0.987531 +0.49000,0.990050,0.985000,0.985000,0.987531 +0.49500,0.990050,0.985000,0.985000,0.987531 +0.50000,0.990050,0.990000,0.990000,0.990025 +0.50500,0.990050,0.990000,0.990000,0.990025 +0.51000,0.990050,0.990000,0.990000,0.990025 +0.51500,0.990050,0.990000,0.990000,0.990025 +0.52000,0.985075,0.990000,0.985075,0.987531 +0.52500,0.985075,0.990000,0.985075,0.987531 +0.53000,0.985075,0.990000,0.985075,0.987531 +0.53500,0.985075,0.990000,0.985075,0.987531 +0.54000,0.985075,0.990000,0.985075,0.987531 +0.54500,0.985075,0.990000,0.985075,0.987531 +0.55000,0.985075,0.990000,0.985075,0.987531 +0.55500,0.985075,0.990000,0.985075,0.987531 +0.56000,0.985075,0.990000,0.985075,0.987531 +0.56500,0.980100,0.990000,0.980100,0.985037 +0.57000,0.980100,0.990000,0.980100,0.985037 +0.57500,0.980100,0.990000,0.980100,0.985037 +0.58000,0.980100,0.990000,0.980100,0.985037 +0.58500,0.980100,0.990000,0.980100,0.985037 +0.59000,0.980100,0.990000,0.980100,0.985037 +0.59500,0.980100,0.990000,0.980100,0.985037 +0.60000,0.980100,0.990000,0.980100,0.985037 +0.60500,0.980100,0.990000,0.980100,0.985037 +0.61000,0.980100,0.990000,0.980100,0.985037 +0.61500,0.980100,0.990000,0.980100,0.985037 +0.62000,0.975124,0.990000,0.975124,0.982544 +0.62500,0.975124,0.995000,0.975124,0.985037 +0.63000,0.975124,0.995000,0.975124,0.985037 +0.63500,0.975124,0.995000,0.975124,0.985037 +0.64000,0.975124,0.995000,0.975124,0.985037 +0.64500,0.975124,0.995000,0.975124,0.985037 +0.65000,0.975124,0.995000,0.975124,0.985037 +0.65500,0.975124,0.995000,0.975124,0.985037 +0.66000,0.975124,0.995000,0.975124,0.985037 +0.66500,0.975124,0.995000,0.975124,0.985037 +0.67000,0.975124,0.995000,0.975124,0.985037 +0.67500,0.975124,0.995000,0.975124,0.985037 +0.68000,0.975124,0.995000,0.975124,0.985037 
+0.68500,0.975124,0.995000,0.975124,0.985037 +0.69000,0.975124,0.995000,0.975124,0.985037 +0.69500,0.970149,0.995000,0.970149,0.982544 +0.70000,0.970149,0.995000,0.970149,0.982544 +0.70500,0.970149,0.995000,0.970149,0.982544 +0.71000,0.970149,0.995000,0.970149,0.982544 +0.71500,0.965174,0.995000,0.965174,0.980050 +0.72000,0.965174,0.995000,0.965174,0.980050 +0.72500,0.965174,0.995000,0.965174,0.980050 +0.73000,0.965174,0.995000,0.965174,0.980050 +0.73500,0.965174,0.995000,0.965174,0.980050 +0.74000,0.965174,0.995000,0.965174,0.980050 +0.74500,0.965174,0.995000,0.965174,0.980050 +0.75000,0.965174,0.995000,0.965174,0.980050 +0.75500,0.965174,0.995000,0.965174,0.980050 +0.76000,0.965174,0.995000,0.965174,0.980050 +0.76500,0.965174,0.995000,0.965174,0.980050 +0.77000,0.960199,0.995000,0.960199,0.977556 +0.77500,0.960199,0.995000,0.960199,0.977556 +0.78000,0.955224,0.995000,0.955224,0.975062 +0.78500,0.955224,0.995000,0.955224,0.975062 +0.79000,0.950249,0.995000,0.950249,0.972569 +0.79500,0.950249,0.995000,0.950249,0.972569 +0.80000,0.950249,0.995000,0.950249,0.972569 +0.80500,0.950249,0.995000,0.950249,0.972569 +0.81000,0.945274,0.995000,0.945274,0.970075 +0.81500,0.945274,1.000000,0.945274,0.972569 +0.82000,0.940299,1.000000,0.940299,0.970075 +0.82500,0.940299,1.000000,0.940299,0.970075 +0.83000,0.940299,1.000000,0.940299,0.970075 +0.83500,0.940299,1.000000,0.940299,0.970075 +0.84000,0.935323,1.000000,0.935323,0.967581 +0.84500,0.930348,1.000000,0.930348,0.965087 +0.85000,0.930348,1.000000,0.930348,0.965087 +0.85500,0.930348,1.000000,0.930348,0.965087 +0.86000,0.930348,1.000000,0.930348,0.965087 +0.86500,0.930348,1.000000,0.930348,0.965087 +0.87000,0.930348,1.000000,0.930348,0.965087 +0.87500,0.930348,1.000000,0.930348,0.965087 +0.88000,0.930348,1.000000,0.930348,0.965087 +0.88500,0.930348,1.000000,0.930348,0.965087 +0.89000,0.930348,1.000000,0.930348,0.965087 +0.89500,0.930348,1.000000,0.930348,0.965087 +0.90000,0.910448,1.000000,0.910448,0.955112 
+0.90500,0.905473,1.000000,0.905473,0.952618
+0.91000,0.905473,1.000000,0.905473,0.952618
+0.91500,0.900498,1.000000,0.900498,0.950125
+0.92000,0.895522,1.000000,0.895522,0.947631
+0.92500,0.895522,1.000000,0.895522,0.947631
+0.93000,0.895522,1.000000,0.895522,0.947631
+0.93500,0.895522,1.000000,0.895522,0.947631
+0.94000,0.895522,1.000000,0.895522,0.947631
+0.94500,0.895522,1.000000,0.895522,0.947631
+0.95000,0.875622,1.000000,0.875622,0.937656
+0.95500,0.875622,1.000000,0.875622,0.937656
+0.96000,0.875622,1.000000,0.875622,0.937656
+0.96500,0.855721,1.000000,0.855721,0.927681
+0.97000,0.855721,1.000000,0.855721,0.927681
+0.97500,0.820896,1.000000,0.820896,0.910224
+0.98000,0.776119,1.000000,0.776119,0.887781
+0.98500,0.711443,1.000000,0.711443,0.855362
+0.99000,0.616915,1.000000,0.616915,0.807980
+0.99500,0.477612,1.000000,0.477612,0.738155
+1.00000,0.000000,1.000000,0.000000,0.498753
diff --git a/examples/plots/cora_nc_aruc.png b/examples/plots/cora_nc_aruc.png
new file mode 100644
index 0000000..9488254
Binary files /dev/null and b/examples/plots/cora_nc_aruc.png differ
diff --git a/examples/plots/enzymes_gc_aruc.png b/examples/plots/enzymes_gc_aruc.png
new file mode 100644
index 0000000..95eb240
Binary files /dev/null and b/examples/plots/enzymes_gc_aruc.png differ
diff --git a/examples/plots/note.txt b/examples/plots/note.txt
new file mode 100644
index 0000000..c028b32
--- /dev/null
+++ b/examples/plots/note.txt
@@ -0,0 +1,4 @@
+The plots are generated on small values of hyperparameters for testing purposes.
+Hence, the graphs are not identical to the paper ones.
+Nevertheless, they depict the pattern.
+Use default params for better results.
\ No newline at end of file diff --git a/examples/requirements.txt b/examples/requirements.txt new file mode 100644 index 0000000..feb0786 --- /dev/null +++ b/examples/requirements.txt @@ -0,0 +1,50 @@ +aiohappyeyeballs==2.6.1 +aiohttp==3.12.15 +aiosignal==1.4.0 +async-timeout==5.0.1 +attrs==25.3.0 +certifi==2025.8.3 +charset-normalizer==3.4.2 +contourpy==1.3.2 +cycler==0.12.1 +filelock==3.18.0 +fonttools==4.59.0 +frozenlist==1.7.0 +fsspec==2025.7.0 +idna==3.10 +Jinja2==3.1.6 +joblib==1.5.1 +kiwisolver==1.4.8 +MarkupSafe==3.0.2 +matplotlib==3.10.5 +mpmath==1.3.0 +multidict==6.6.3 +networkx==3.4.2 +numpy==1.26.4 +packaging==25.0 +pandas==2.3.1 +pillow==11.3.0 +propcache==0.3.2 +psutil==7.0.0 +pyparsing==3.2.3 +python-dateutil==2.9.0.post0 +pytz==2025.2 +requests==2.32.4 +scikit-learn==1.7.1 +scipy==1.15.3 +six==1.17.0 +sympy==1.14.0 +threadpoolctl==3.6.0 +torch==2.2.0 +torch-geometric==2.6.1 +torch_cluster==1.6.3 +torch_scatter==2.1.2 +torch_sparse==0.6.18 +torch_spline_conv==1.2.2 +torchaudio==2.2.0 +torchvision==0.17.0 +tqdm==4.67.1 +typing_extensions==4.14.1 +tzdata==2025.2 +urllib3==2.5.0 +yarl==1.20.1 diff --git a/examples/train_univerifier.py b/examples/train_univerifier.py new file mode 100644 index 0000000..989bb66 --- /dev/null +++ b/examples/train_univerifier.py @@ -0,0 +1,98 @@ +""" +Trains the Univerifier on features built from fingerprints (MLP: [128,64,32] + LeakyReLU). +Loads X,y from generate_univerifier_dataset.py and saves weights + a tiny meta JSON. 
+""" + +import argparse, json, torch, time +import torch.nn as nn +import torch.nn.functional as F +from pathlib import Path + +class FPVerifier(nn.Module): + def __init__(self, in_dim: int): + super().__init__() + self.net = nn.Sequential( + nn.Linear(in_dim, 128), + nn.LeakyReLU(), + nn.Linear(128, 64), + nn.LeakyReLU(), + nn.Linear(64, 32), + nn.LeakyReLU(), + nn.Linear(32, 1), + nn.Sigmoid(), + ) + def forward(self, x): + return self.net(x) + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument('--dataset', type=str, default='fingerprints/univerifier_dataset_nc.pt') + ap.add_argument('--epochs', type=int, default=200) + ap.add_argument('--lr', type=float, default=1e-3) + ap.add_argument('--weight_decay', type=float, default=0.0) + ap.add_argument('--val_split', type=float, default=0.2) + ap.add_argument('--fingerprints_path', type=str, default='fingerprints/fingerprints_nc.pt') + ap.add_argument('--out', type=str, default='fingerprints/univerifier_nc.pt') + args = ap.parse_args() + + # Load dataset + pack = torch.load(args.dataset, map_location='cpu') + X = pack['X'].float().detach() + y = pack['y'].float().view(-1, 1).detach() + N, D = X.shape + + try: + fp_pack = torch.load(args.fingerprints_path, map_location='cpu') + ver_in_dim = int(fp_pack.get('ver_in_dim', D)) + if ver_in_dim != D: + raise RuntimeError(f'Input dim mismatch: dataset dim {D} vs ver_in_dim {ver_in_dim}') + except FileNotFoundError: + pass + + # Train/val split + n_val = max(1, int(args.val_split * N)) + perm = torch.randperm(N) + idx_tr, idx_val = perm[:-n_val], perm[-n_val:] + X_tr, y_tr = X[idx_tr], y[idx_tr] + X_val, y_val = X[idx_val], y[idx_val] + + # Model/optim + V = FPVerifier(D) + opt = torch.optim.Adam(V.parameters(), lr=args.lr, weight_decay=args.weight_decay) + + best_acc, best_state = 0.0, None + for ep in range(1, args.epochs + 1): + V.train(); opt.zero_grad() + p = V(X_tr) + loss = F.binary_cross_entropy(p, y_tr) + loss.backward(); opt.step() + + with 
torch.no_grad(): + V.eval() + pv = V(X_val) + val_loss = F.binary_cross_entropy(pv, y_val) + val_acc = ((pv >= 0.5).float() == y_val).float().mean().item() + + if val_acc > best_acc: + best_acc = val_acc + best_state = {k: v.cpu().clone() for k, v in V.state_dict().items()} + + if ep % 20 == 0 or ep == args.epochs: + print(f'Epoch {ep:03d} | train_bce {loss.item():.4f} ' + f'| val_bce {val_loss.item():.4f} | val_acc {val_acc:.4f}') + + # Save best + if best_state is None: + best_state = V.state_dict() + Path('fingerprints').mkdir(exist_ok=True, parents=True) + torch.save(best_state, args.out) + with open(args.out.replace('.pt', '_meta.json'), 'w') as f: + json.dump({'in_dim': D, 'hidden': [128, 64, 32], 'act': 'LeakyReLU'}, f) + print(f'Saved {args.out} | Best Val Acc {best_acc:.4f} | Input dim D={D}') + +if __name__ == '__main__': + start_time = time.time() + main() + end_time = time.time() + print("time taken: ", (end_time-start_time)/60 ) +