diff --git a/.gitignore b/.gitignore index e0afa39..117afbc 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,7 @@ dist/ #virtual environments folder .venv +temp_results/ +robustness_results/ +*.pth +*.pt diff --git a/README.md b/README.md index 2b9fa53..41a87fc 100644 --- a/README.md +++ b/README.md @@ -1,118 +1,18 @@ -# PyGIP +# PyGIP - GNN Ownership Verification Module -[![PyPI - Version](https://img.shields.io/pypi/v/PyGIP)](https://pypi.org/project/PyGIP) -[![Build Status](https://img.shields.io/github/actions/workflow/status/LabRAI/PyGIP/docs.yml)](https://github.com/LabRAI/PyGIP/actions) -[![License](https://img.shields.io/github/license/LabRAI/PyGIP.svg)](https://github.com/LabRAI/PyGIP/blob/main/LICENSE) -[![PyPI - Downloads](https://img.shields.io/pypi/dm/pygip)](https://github.com/LabRAI/PyGIP) -[![Issues](https://img.shields.io/github/issues/LabRAI/PyGIP)](https://github.com/LabRAI/PyGIP) -[![Pull Requests](https://img.shields.io/github/issues-pr/LabRAI/PyGIP)](https://github.com/LabRAI/PyGIP) -[![Stars](https://img.shields.io/github/stars/LabRAI/PyGIP)](https://github.com/LabRAI/PyGIP) -[![GitHub forks](https://img.shields.io/github/forks/LabRAI/PyGIP)](https://github.com/LabRAI/PyGIP) +This repository contains the integration of **Graph Neural Network (GNN) ownership verification** experiments into the PyGIP framework. It provides modular and extensible implementations for attacks and defenses on GNNs, following the guidelines of the PyGIP framework. -PyGIP is a Python library designed for experimenting with graph-based model extraction attacks and defenses. It provides -a modular framework to implement and test attack and defense strategies on graph datasets. 
+--- -## How to Cite +## 📋 Overview -If you find it useful, please considering cite the following work: +This module allows users to: -```bibtex -@article{li2025intellectual, - title={Intellectual Property in Graph-Based Machine Learning as a Service: Attacks and Defenses}, - author={Li, Lincan and Shen, Bolin and Zhao, Chenxi and Sun, Yuxiang and Zhao, Kaixiang and Pan, Shirui and Dong, Yushun}, - journal={arXiv preprint arXiv:2508.19641}, - year={2025} -} -``` +- Evaluate ownership verification on GNN models (GCN, GAT, GraphSAGE). +- Run experiments under **inductive** and **transductive** settings. +- Easily extend the framework with new datasets, attacks, or defenses. +> **Note:** Large model weights (`benign_model.pth`) and result folders (`temp_results/`, `robustness_results/`) are excluded to keep the repository clean. -## Installation +--- -PyGIP supports both CPU and GPU environments. Make sure you have Python installed (version >= 3.8, <3.13). - -### Base Installation - -First, install the core package: - -```bash -pip install PyGIP -``` - -This will install PyGIP with minimal dependencies. 
- -### CPU Version - -```bash -pip install "PyGIP[torch,dgl]" \ - --index-url https://download.pytorch.org/whl/cpu \ - --extra-index-url https://pypi.org/simple \ - -f https://data.dgl.ai/wheels/repo.html -``` - -### GPU Version (CUDA 12.1) - -```bash -pip install "PyGIP[torch,dgl]" \ - --index-url https://download.pytorch.org/whl/cu121 \ - --extra-index-url https://pypi.org/simple \ - -f https://data.dgl.ai/wheels/torch-2.3/cu121/repo.html -``` - -## Quick Start - -Here’s a simple example to launch a Model Extraction Attack using PyGIP: - -```python -from datasets import Cora -from models.attack import ModelExtractionAttack0 - -# Load the Cora dataset -dataset = Cora() - -# Initialize the attack with a sampling ratio of 0.25 -mea = ModelExtractionAttack0(dataset, 0.25) - -# Execute the attack -mea.attack() -``` - -This code loads the Cora dataset, initializes a basic model extraction attack (`ModelExtractionAttack0`), and runs the -attack with a specified sampling ratio. - -And a simple example to run a Defense method against Model Extraction Attack: - -```python -from datasets import Cora -from models.defense import RandomWM - -# Load the Cora dataset -dataset = Cora() - -# Initialize the attack with a sampling ratio of 0.25 -med = RandomWM(dataset, 0.25) - -# Execute the defense -med.defend() -``` - -which runs the Random Watermarking Graph to defend against MEA. - -If you want to use cuda, please set environment variable: - -```shell -export PYGIP_DEVICE=cuda:0 -``` - -## Implementation & Contributors Guideline - -Refer to [Implementation Guideline](.github/IMPLEMENTATION.md) - -Refer to [Contributors Guideline](.github/CONTRIBUTING.md) - -## License - -[BSD 2-Clause License](LICENSE) - -## Contact - -For questions or contributions, please contact blshen@fsu.edu. 
diff --git a/config/global_cfg.yaml b/config/global_cfg.yaml new file mode 100644 index 0000000..cce7809 --- /dev/null +++ b/config/global_cfg.yaml @@ -0,0 +1,13 @@ +target_model: gat +target_hidden_dims: [224, 128] +dataset: Cora +train_setting: "" +test_setting: 1 +embedding_dim: 128 +train_process: "train" +test_process: "test" +n_run: 3 + +train_save_root: ../temp_results/diff/model_states/ +test_save_root: ../temp_results/diff/model_states/ +res_path: ../temp_results/diff/res/ diff --git a/config/test_setting1.yaml b/config/test_setting1.yaml new file mode 100644 index 0000000..6a49358 --- /dev/null +++ b/config/test_setting1.yaml @@ -0,0 +1,4 @@ +model_arches: ["gat", "gcn", "sage"] +layer_dims: [96, 160, 224, 288, 352] +num_hidden_layers: [2] +num_model_per_arch: 10 \ No newline at end of file diff --git a/config/test_setting2.yaml b/config/test_setting2.yaml new file mode 100644 index 0000000..109046f --- /dev/null +++ b/config/test_setting2.yaml @@ -0,0 +1,4 @@ +model_arches: ["gat", "gcn", "sage"] +layer_dims: [128, 192, 256, 320, 384] +num_hidden_layers: [1, 3] +num_model_per_arch: 10 \ No newline at end of file diff --git a/config/test_setting3.yaml b/config/test_setting3.yaml new file mode 100644 index 0000000..0c805f6 --- /dev/null +++ b/config/test_setting3.yaml @@ -0,0 +1,4 @@ +model_arches: ["gin", 'sgc'] +layer_dims: [96, 160, 224, 288, 352] +num_hidden_layers: [2] +num_model_per_arch: 15 \ No newline at end of file diff --git a/config/test_setting4.yaml b/config/test_setting4.yaml new file mode 100644 index 0000000..82f0aec --- /dev/null +++ b/config/test_setting4.yaml @@ -0,0 +1,4 @@ +model_arches: ["gin", 'sgc'] +layer_dims: [128, 192, 256, 320, 384] +num_hidden_layers: [1, 3] +num_model_per_arch: 15 \ No newline at end of file diff --git a/config/train_setting.yaml b/config/train_setting.yaml new file mode 100644 index 0000000..1a58b02 --- /dev/null +++ b/config/train_setting.yaml @@ -0,0 +1,4 @@ +model_arches: ["gat", "gcn", "sage"] 
+layer_dims: [96, 160, 224, 288, 352] +num_hidden_layers: [2] +num_model_per_arch: 20 # per arch \ No newline at end of file diff --git a/data/Cora/raw/ind.cora.allx b/data/Cora/raw/ind.cora.allx new file mode 100644 index 0000000..44d53b1 Binary files /dev/null and b/data/Cora/raw/ind.cora.allx differ diff --git a/data/Cora/raw/ind.cora.ally b/data/Cora/raw/ind.cora.ally new file mode 100644 index 0000000..04fbd0b Binary files /dev/null and b/data/Cora/raw/ind.cora.ally differ diff --git a/data/Cora/raw/ind.cora.graph b/data/Cora/raw/ind.cora.graph new file mode 100644 index 0000000..4d3bf85 Binary files /dev/null and b/data/Cora/raw/ind.cora.graph differ diff --git a/data/Cora/raw/ind.cora.test.index b/data/Cora/raw/ind.cora.test.index new file mode 100644 index 0000000..ded8092 --- /dev/null +++ b/data/Cora/raw/ind.cora.test.index @@ -0,0 +1,1000 @@ +2692 +2532 +2050 +1715 +2362 +2609 +2622 +1975 +2081 +1767 +2263 +1725 +2588 +2259 +2357 +1998 +2574 +2179 +2291 +2382 +1812 +1751 +2422 +1937 +2631 +2510 +2378 +2589 +2345 +1943 +1850 +2298 +1825 +2035 +2507 +2313 +1906 +1797 +2023 +2159 +2495 +1886 +2122 +2369 +2461 +1925 +2565 +1858 +2234 +2000 +1846 +2318 +1723 +2559 +2258 +1763 +1991 +1922 +2003 +2662 +2250 +2064 +2529 +1888 +2499 +2454 +2320 +2287 +2203 +2018 +2002 +2632 +2554 +2314 +2537 +1760 +2088 +2086 +2218 +2605 +1953 +2403 +1920 +2015 +2335 +2535 +1837 +2009 +1905 +2636 +1942 +2193 +2576 +2373 +1873 +2463 +2509 +1954 +2656 +2455 +2494 +2295 +2114 +2561 +2176 +2275 +2635 +2442 +2704 +2127 +2085 +2214 +2487 +1739 +2543 +1783 +2485 +2262 +2472 +2326 +1738 +2170 +2100 +2384 +2152 +2647 +2693 +2376 +1775 +1726 +2476 +2195 +1773 +1793 +2194 +2581 +1854 +2524 +1945 +1781 +1987 +2599 +1744 +2225 +2300 +1928 +2042 +2202 +1958 +1816 +1916 +2679 +2190 +1733 +2034 +2643 +2177 +1883 +1917 +1996 +2491 +2268 +2231 +2471 +1919 +1909 +2012 +2522 +1865 +2466 +2469 +2087 +2584 +2563 +1924 +2143 +1736 +1966 +2533 +2490 +2630 +1973 +2568 +1978 +2664 +2633 +2312 +2178 +1754 
+2307 +2480 +1960 +1742 +1962 +2160 +2070 +2553 +2433 +1768 +2659 +2379 +2271 +1776 +2153 +1877 +2027 +2028 +2155 +2196 +2483 +2026 +2158 +2407 +1821 +2131 +2676 +2277 +2489 +2424 +1963 +1808 +1859 +2597 +2548 +2368 +1817 +2405 +2413 +2603 +2350 +2118 +2329 +1969 +2577 +2475 +2467 +2425 +1769 +2092 +2044 +2586 +2608 +1983 +2109 +2649 +1964 +2144 +1902 +2411 +2508 +2360 +1721 +2005 +2014 +2308 +2646 +1949 +1830 +2212 +2596 +1832 +1735 +1866 +2695 +1941 +2546 +2498 +2686 +2665 +1784 +2613 +1970 +2021 +2211 +2516 +2185 +2479 +2699 +2150 +1990 +2063 +2075 +1979 +2094 +1787 +2571 +2690 +1926 +2341 +2566 +1957 +1709 +1955 +2570 +2387 +1811 +2025 +2447 +2696 +2052 +2366 +1857 +2273 +2245 +2672 +2133 +2421 +1929 +2125 +2319 +2641 +2167 +2418 +1765 +1761 +1828 +2188 +1972 +1997 +2419 +2289 +2296 +2587 +2051 +2440 +2053 +2191 +1923 +2164 +1861 +2339 +2333 +2523 +2670 +2121 +1921 +1724 +2253 +2374 +1940 +2545 +2301 +2244 +2156 +1849 +2551 +2011 +2279 +2572 +1757 +2400 +2569 +2072 +2526 +2173 +2069 +2036 +1819 +1734 +1880 +2137 +2408 +2226 +2604 +1771 +2698 +2187 +2060 +1756 +2201 +2066 +2439 +1844 +1772 +2383 +2398 +1708 +1992 +1959 +1794 +2426 +2702 +2444 +1944 +1829 +2660 +2497 +2607 +2343 +1730 +2624 +1790 +1935 +1967 +2401 +2255 +2355 +2348 +1931 +2183 +2161 +2701 +1948 +2501 +2192 +2404 +2209 +2331 +1810 +2363 +2334 +1887 +2393 +2557 +1719 +1732 +1986 +2037 +2056 +1867 +2126 +1932 +2117 +1807 +1801 +1743 +2041 +1843 +2388 +2221 +1833 +2677 +1778 +2661 +2306 +2394 +2106 +2430 +2371 +2606 +2353 +2269 +2317 +2645 +2372 +2550 +2043 +1968 +2165 +2310 +1985 +2446 +1982 +2377 +2207 +1818 +1913 +1766 +1722 +1894 +2020 +1881 +2621 +2409 +2261 +2458 +2096 +1712 +2594 +2293 +2048 +2359 +1839 +2392 +2254 +1911 +2101 +2367 +1889 +1753 +2555 +2246 +2264 +2010 +2336 +2651 +2017 +2140 +1842 +2019 +1890 +2525 +2134 +2492 +2652 +2040 +2145 +2575 +2166 +1999 +2434 +1711 +2276 +2450 +2389 +2669 +2595 +1814 +2039 +2502 +1896 +2168 +2344 +2637 +2031 +1977 +2380 +1936 +2047 +2460 +2102 +1745 
+2650 +2046 +2514 +1980 +2352 +2113 +1713 +2058 +2558 +1718 +1864 +1876 +2338 +1879 +1891 +2186 +2451 +2181 +2638 +2644 +2103 +2591 +2266 +2468 +1869 +2582 +2674 +2361 +2462 +1748 +2215 +2615 +2236 +2248 +2493 +2342 +2449 +2274 +1824 +1852 +1870 +2441 +2356 +1835 +2694 +2602 +2685 +1893 +2544 +2536 +1994 +1853 +1838 +1786 +1930 +2539 +1892 +2265 +2618 +2486 +2583 +2061 +1796 +1806 +2084 +1933 +2095 +2136 +2078 +1884 +2438 +2286 +2138 +1750 +2184 +1799 +2278 +2410 +2642 +2435 +1956 +2399 +1774 +2129 +1898 +1823 +1938 +2299 +1862 +2420 +2673 +1984 +2204 +1717 +2074 +2213 +2436 +2297 +2592 +2667 +2703 +2511 +1779 +1782 +2625 +2365 +2315 +2381 +1788 +1714 +2302 +1927 +2325 +2506 +2169 +2328 +2629 +2128 +2655 +2282 +2073 +2395 +2247 +2521 +2260 +1868 +1988 +2324 +2705 +2541 +1731 +2681 +2707 +2465 +1785 +2149 +2045 +2505 +2611 +2217 +2180 +1904 +2453 +2484 +1871 +2309 +2349 +2482 +2004 +1965 +2406 +2162 +1805 +2654 +2007 +1947 +1981 +2112 +2141 +1720 +1758 +2080 +2330 +2030 +2432 +2089 +2547 +1820 +1815 +2675 +1840 +2658 +2370 +2251 +1908 +2029 +2068 +2513 +2549 +2267 +2580 +2327 +2351 +2111 +2022 +2321 +2614 +2252 +2104 +1822 +2552 +2243 +1798 +2396 +2663 +2564 +2148 +2562 +2684 +2001 +2151 +2706 +2240 +2474 +2303 +2634 +2680 +2055 +2090 +2503 +2347 +2402 +2238 +1950 +2054 +2016 +1872 +2233 +1710 +2032 +2540 +2628 +1795 +2616 +1903 +2531 +2567 +1946 +1897 +2222 +2227 +2627 +1856 +2464 +2241 +2481 +2130 +2311 +2083 +2223 +2284 +2235 +2097 +1752 +2515 +2527 +2385 +2189 +2283 +2182 +2079 +2375 +2174 +2437 +1993 +2517 +2443 +2224 +2648 +2171 +2290 +2542 +2038 +1855 +1831 +1759 +1848 +2445 +1827 +2429 +2205 +2598 +2657 +1728 +2065 +1918 +2427 +2573 +2620 +2292 +1777 +2008 +1875 +2288 +2256 +2033 +2470 +2585 +2610 +2082 +2230 +1915 +1847 +2337 +2512 +2386 +2006 +2653 +2346 +1951 +2110 +2639 +2520 +1939 +2683 +2139 +2220 +1910 +2237 +1900 +1836 +2197 +1716 +1860 +2077 +2519 +2538 +2323 +1914 +1971 +1845 +2132 +1802 +1907 +2640 +2496 +2281 +2198 +2416 +2285 +1755 +2431 +2071 
+2249 +2123 +1727 +2459 +2304 +2199 +1791 +1809 +1780 +2210 +2417 +1874 +1878 +2116 +1961 +1863 +2579 +2477 +2228 +2332 +2578 +2457 +2024 +1934 +2316 +1841 +1764 +1737 +2322 +2239 +2294 +1729 +2488 +1974 +2473 +2098 +2612 +1834 +2340 +2423 +2175 +2280 +2617 +2208 +2560 +1741 +2600 +2059 +1747 +2242 +2700 +2232 +2057 +2147 +2682 +1792 +1826 +2120 +1895 +2364 +2163 +1851 +2391 +2414 +2452 +1803 +1989 +2623 +2200 +2528 +2415 +1804 +2146 +2619 +2687 +1762 +2172 +2270 +2678 +2593 +2448 +1882 +2257 +2500 +1899 +2478 +2412 +2107 +1746 +2428 +2115 +1800 +1901 +2397 +2530 +1912 +2108 +2206 +2091 +1740 +2219 +1976 +2099 +2142 +2671 +2668 +2216 +2272 +2229 +2666 +2456 +2534 +2697 +2688 +2062 +2691 +2689 +2154 +2590 +2626 +2390 +1813 +2067 +1952 +2518 +2358 +1789 +2076 +2049 +2119 +2013 +2124 +2556 +2105 +2093 +1885 +2305 +2354 +2135 +2601 +1770 +1995 +2504 +1749 +2157 diff --git a/data/Cora/raw/ind.cora.tx b/data/Cora/raw/ind.cora.tx new file mode 100644 index 0000000..6e856d7 Binary files /dev/null and b/data/Cora/raw/ind.cora.tx differ diff --git a/data/Cora/raw/ind.cora.ty b/data/Cora/raw/ind.cora.ty new file mode 100644 index 0000000..da1734a Binary files /dev/null and b/data/Cora/raw/ind.cora.ty differ diff --git a/data/Cora/raw/ind.cora.x b/data/Cora/raw/ind.cora.x new file mode 100644 index 0000000..c4a91d0 Binary files /dev/null and b/data/Cora/raw/ind.cora.x differ diff --git a/data/Cora/raw/ind.cora.y b/data/Cora/raw/ind.cora.y new file mode 100644 index 0000000..58e30ef Binary files /dev/null and b/data/Cora/raw/ind.cora.y differ diff --git a/data/README.md b/data/README.md new file mode 100644 index 0000000..e69de29 diff --git a/eval/README.md b/eval/README.md new file mode 100644 index 0000000..e69de29 diff --git a/eval/eval1.py b/eval/eval1.py new file mode 100644 index 0000000..f2a881f --- /dev/null +++ b/eval/eval1.py @@ -0,0 +1,181 @@ +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.datasets import 
Planetoid +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +from sklearn.metrics import accuracy_score +import yaml +import warnings +warnings.filterwarnings("ignore") + +# Ensure src/ is in PYTHONPATH +import sys +sys.path.append(os.path.join(os.getcwd(), "src")) + +from models.gnn import GCN, GAT, GraphSAGE +from models.ownership_classifier import OwnershipClassifier +from datasets.datareader import create_splits +from datasets.graph_operator import apply_masking + +# ----------------------------- +# Seed for reproducibility +# ----------------------------- +torch.manual_seed(42) +np.random.seed(42) + +# ----------------------------- +# Config Loader +# ----------------------------- +def load_config(config_path="config/global_cfg.yaml"): + if not os.path.exists(config_path): + raise FileNotFoundError(f"Config file not found: {config_path}") + with open(config_path, 'r') as f: + return yaml.safe_load(f) + +# ----------------------------- +# GNN Evaluation +# ----------------------------- +def evaluate_gnn(model, x, edge_index, mask, data): + model.eval() + with torch.no_grad(): + out = model(x, edge_index) + pred = out.argmax(dim=1) + return accuracy_score(data.y[mask].cpu(), pred[mask].cpu()) + +# ----------------------------- +# Extract model posteriors +# ----------------------------- +def get_posteriors(model, x, edge_index, mask): + model.eval() + with torch.no_grad(): + out = model(x, edge_index) + return out[mask].detach().cpu().flatten() + +# ----------------------------- +# Ownership Classifier +# ----------------------------- +def train_classifier(X, y, input_dim, hidden_dim=64, epochs=100, lr=0.01, device="cpu"): + classifier = OwnershipClassifier(input_dim, hidden_dim).to(device) + optimizer = torch.optim.Adam(classifier.parameters(), lr=lr) + X, y = X.to(device), y.to(device) + for _ in range(epochs): + classifier.train() + optimizer.zero_grad() + out = classifier(X).squeeze() + loss = F.binary_cross_entropy(out, y) + 
loss.backward() + optimizer.step() + return classifier + +def evaluate_classifier(classifier, X, y, device="cpu"): + classifier.eval() + X, y = X.to(device), y.to(device) + with torch.no_grad(): + pred = (classifier(X) > 0.5).float() + acc = accuracy_score(y.cpu(), pred.cpu()) + # False positive and false negative rates + fp = ((pred == 1) & (y == 0)).sum().float() / max((y == 0).sum().float(), 1) + fn = ((pred == 0) & (y == 1)).sum().float() / max((y == 1).sum().float(), 1) + return acc, fp.item(), fn.item() + +# ----------------------------- +# Main Evaluation +# ----------------------------- +def main(): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using device: {device}") + + # Load config + cfg = load_config() + dataset_name = cfg['dataset']['name'] + architectures = cfg['model']['architectures'] + mask_ratios = { + 'inductive': cfg['training']['mask_ratio_inductive'], + 'transductive': cfg['training']['mask_ratio_transductive'] + } + + # Load dataset + dataset = Planetoid(root='./data', name=dataset_name) + data = dataset[0].to(device) + num_classes = dataset.num_classes + num_features = dataset.num_features + + results = [] + + for mode in ['inductive', 'transductive']: + mask_ratio = mask_ratios[mode] + masks = create_splits(data, mode=mode) + masked_x = apply_masking(data.x, mask_ratio=mask_ratio).to(device) + + for arch in architectures: + model_class = {'gcn': GCN, 'gat': GAT, 'sage': GraphSAGE}[arch] + + for setting in ['I', 'II', 'III', 'IV']: + target_path = f"temp_results/diff/model_states/{dataset_name}/{mode}/mask_models/random_mask/1.0_{mask_ratio}/{arch}_224_128.pt" + if not os.path.exists(target_path): + print(f"Skipping {arch} {mode} Setting {setting}: Target model not found") + continue + + # Load target model + target_model = model_class(num_features, 224, num_classes).to(device) + target_model.load_state_dict(torch.load(target_path, map_location=device)) + target_acc = evaluate_gnn(target_model, masked_x, 
data.edge_index, masks['test'], data) + + # Collect posteriors + posteriors = [] + labels = [] + + # Independent models + ind_dir = f"temp_results/diff/model_states/{dataset_name}/{mode}/independent_models" + if os.path.exists(ind_dir): + for f in os.listdir(ind_dir): + if arch in f and f.endswith('.pt'): + model = model_class(num_features, 224, num_classes).to(device) + model.load_state_dict(torch.load(os.path.join(ind_dir, f), map_location=device)) + post = get_posteriors(model, masked_x, data.edge_index, masks['train']) + posteriors.append(post) + labels.append(0) + + # Surrogate models + surr_dir = f"temp_results/diff/model_states/{dataset_name}/{mode}/extraction_models/random_mask/{arch}_224_128/1.0_{mask_ratio}" + if os.path.exists(surr_dir): + for f in os.listdir(surr_dir): + if arch in f and f.endswith('.pt'): + model = model_class(num_features, 224, num_classes).to(device) + model.load_state_dict(torch.load(os.path.join(surr_dir, f), map_location=device)) + post = get_posteriors(model, masked_x, data.edge_index, masks['train']) + posteriors.append(post) + labels.append(1) + + if len(posteriors) < 2: + print(f"Skipping {arch} {mode} Setting {setting}: Insufficient models") + continue + + # Train and evaluate classifier + X = torch.stack(posteriors).to(device) + y = torch.tensor(labels, dtype=torch.float32).to(device) + classifier = train_classifier(X, y, X.shape[1], device=device) + ver_acc, fpr, fnr = evaluate_classifier(classifier, X, y, device=device) + + results.append({ + 'dataset': dataset_name, + 'mode': mode, + 'model': arch, + 'setting': setting, + 'target_acc': target_acc, + 'ver_acc': ver_acc, + 'fpr': fpr, + 'fnr': fnr + }) + + # Save results + os.makedirs('experiments/results', exist_ok=True) + df = pd.DataFrame(results) + df.to_csv('experiments/results/results.csv', index=False) + print("Results saved to experiments/results/results.csv") + +if __name__ == "__main__": + main() diff --git a/eval/evaluate.py b/eval/evaluate.py new file mode 
100644 index 0000000..5130808 --- /dev/null +++ b/eval/evaluate.py @@ -0,0 +1,603 @@ +import os +import re +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.datasets import Planetoid +from torch_geometric.nn import GCNConv, GATConv, SAGEConv +from torch.nn import Linear +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns +from sklearn.metrics import accuracy_score +import yaml +import copy +import sys +from matplotlib.ticker import PercentFormatter + +# Ensure project root is in sys.path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Set seed for reproducibility +torch.manual_seed(42) +np.random.seed(42) + +def load_config(config_path): + with open(config_path, 'r') as f: + return yaml.safe_load(f) + +def parse_dims_from_filename(filename): + """ + Parse hidden dimensions from filenames like 'gcn_224_128.pt' + Returns list of dimensions + """ + numbers = re.findall(r'\d+', filename) + return [int(num) for num in numbers] if numbers else [128, 128] + +def parse_architecture_from_filename(filename): + """Parse architecture from filename""" + if 'gcn' in filename: + return 'gcn' + elif 'gat' in filename: + return 'gat' + elif 'sage' in filename: + return 'sage' + return 'gcn' # default + +# ------------------- Flexible Model Definitions ------------------- +class FlexibleGCN(torch.nn.Module): + def __init__(self, in_dim, out_dim, hidden_dims): + super(FlexibleGCN, self).__init__() + if isinstance(hidden_dims, int): + hidden_dims = [hidden_dims] + + self.layers = torch.nn.ModuleList() + dims = [in_dim] + hidden_dims + + for i in range(len(dims) - 1): + self.layers.append(GCNConv(dims[i], dims[i + 1])) + + self.fc = Linear(dims[-1], out_dim) + + def forward(self, data): + x, edge_index = data + for layer in self.layers: + x = F.relu(layer(x, edge_index)) + embedding = x + x = self.fc(x) + return embedding, x + +class FlexibleGAT(torch.nn.Module): + 
def __init__(self, in_dim, out_dim, hidden_dims, heads=8): + super(FlexibleGAT, self).__init__() + if isinstance(hidden_dims, int): + hidden_dims = [hidden_dims] + + self.layers = torch.nn.ModuleList() + dims = [in_dim] + hidden_dims + + for i in range(len(dims) - 1): + if i == 0: + # First layer + self.layers.append(GATConv(dims[i], dims[i + 1] // heads, heads=heads, concat=True)) + elif i == len(dims) - 2: + # Last layer + self.layers.append(GATConv(dims[i], dims[i + 1], heads=1, concat=False)) + else: + # Middle layers + self.layers.append(GATConv(dims[i], dims[i + 1] // heads, heads=heads, concat=True)) + + self.fc = Linear(dims[-1], out_dim) + self.heads = heads + + def forward(self, data): + x, edge_index = data + for layer in self.layers: + x = F.relu(layer(x, edge_index)) + embedding = x + x = self.fc(x) + return embedding, x + +class FlexibleGraphSage(torch.nn.Module): + def __init__(self, in_dim, out_dim, hidden_dims): + super(FlexibleGraphSage, self).__init__() + if isinstance(hidden_dims, int): + hidden_dims = [hidden_dims] + + self.layers = torch.nn.ModuleList() + dims = [in_dim] + hidden_dims + + for i in range(len(dims) - 1): + self.layers.append(SAGEConv(dims[i], dims[i + 1])) + + self.fc = Linear(dims[-1], out_dim) + + def forward(self, data): + x, edge_index = data + for layer in self.layers: + x = F.relu(layer(x, edge_index)) + embedding = x + x = self.fc(x) + return embedding, x + +def create_flexible_model(arch, in_dim, out_dim, filename): + """ + Create a model with flexible architecture based on filename + """ + dims = parse_dims_from_filename(filename) + hidden_dims = dims[:-1] if len(dims) > 1 else dims + final_dim = dims[-1] if dims else out_dim + + print(f"Creating {arch} model with hidden_dims={hidden_dims}, final_dim={final_dim}") + + if arch == 'gcn': + return FlexibleGCN(in_dim, out_dim, hidden_dims) + elif arch == 'gat': + return FlexibleGAT(in_dim, out_dim, hidden_dims, heads=8) + elif arch == 'sage': + return 
FlexibleGraphSage(in_dim, out_dim, hidden_dims) + +def evaluate_gnn(model, x, edge_index, test_mask, data): + model.eval() + with torch.no_grad(): + _, out = model((x, edge_index)) + pred = out.argmax(dim=1) + return accuracy_score(data.y[test_mask].cpu(), pred[test_mask].cpu()) + +def get_posteriors(model, x, edge_index, mask): + model.eval() + with torch.no_grad(): + _, out = model((x, edge_index)) + selected = out[mask] + return selected.detach().cpu() + +def train_classifier(X, y, input_dim, hidden_layers=None, epochs=100, lr=0.01): + if hidden_layers is None: + hidden_layers = [64] + + class MLPClassifier(nn.Module): + def __init__(self, input_dim, hidden_layers): + super(MLPClassifier, self).__init__() + layers = [] + prev_dim = input_dim + for hidden_dim in hidden_layers: + layers.append(nn.Linear(prev_dim, hidden_dim)) + layers.append(nn.ReLU()) + layers.append(nn.Dropout(0.1)) + prev_dim = hidden_dim + layers.append(nn.Linear(prev_dim, 2)) + self.network = nn.Sequential(*layers) + + def forward(self, x): + return self.network(x) + + classifier = MLPClassifier(input_dim, hidden_layers) + optimizer = torch.optim.Adam(classifier.parameters(), lr=lr, weight_decay=1e-4) + loss_fn = torch.nn.CrossEntropyLoss() + X = X.float() + y = y.long() + + for epoch in range(epochs): + classifier.train() + optimizer.zero_grad() + logits = classifier(X) + loss = loss_fn(logits, y) + loss.backward() + optimizer.step() + + return classifier + +def evaluate_classifier(classifier, X, y): + classifier.eval() + with torch.no_grad(): + logits = classifier(X.float()) + preds = logits.argmax(dim=1) + y_long = y.long() + acc = accuracy_score(y_long.cpu(), preds.cpu()) + neg_mask = (y_long == 0) + pos_mask = (y_long == 1) + fp = ((preds == 1) & neg_mask).sum().float() / neg_mask.sum().float() if neg_mask.sum().item() > 0 else torch.tensor(0.0) + fn = ((preds == 0) & pos_mask).sum().float() / pos_mask.sum().float() if pos_mask.sum().item() > 0 else torch.tensor(0.0) + return float(acc), 
float(fp), float(fn) + +def safe_load_state(model, path): + """ + Load state_dict safely, handling pickle issues + """ + if not os.path.exists(path): + print(f"File not found: {path}") + return False + + try: + # Try to load with weights_only first (safer) + try: + state = torch.load(path, map_location='cpu', weights_only=True) + except: + # Fallback to non-weights_only + state = torch.load(path, map_location='cpu', weights_only=False) + + # Handle different checkpoint formats + if isinstance(state, torch.nn.Module): + state = state.state_dict() + elif isinstance(state, dict): + for key in ['state_dict', 'model_state', 'model']: + if key in state: + state = state[key] + break + + # Load compatible parameters + model_state = model.state_dict() + filtered_state = {} + + for key, value in state.items(): + if key in model_state and value.shape == model_state[key].shape: + filtered_state[key] = value + else: + # Try to find matching parameter by pattern + for model_key in model_state.keys(): + if (key.endswith('.weight') and model_key.endswith('.weight') and + value.shape == model_state[model_key].shape): + filtered_state[model_key] = value + break + elif (key.endswith('.bias') and model_key.endswith('.bias') and + value.shape == model_state[model_key].shape): + filtered_state[model_key] = value + break + + model.load_state_dict(filtered_state, strict=False) + print(f"Loaded {len(filtered_state)}/{len(state)} parameters from {path}") + return True + + except Exception as e: + print(f"Failed to load {path}: {str(e)}") + return False + +def simple_mask_graph_data(args, data): + """ + Simple masking function for evaluation + """ + masked_data = copy.deepcopy(data) + if args.mask_feat_ratio > 0: + mask = torch.rand_like(masked_data.x) < args.mask_feat_ratio + masked_data.x[mask] = 0 + return masked_data + +def pad_posteriors(posteriors, target_size): + """Pad or truncate posteriors to target size""" + padded = [] + for p in posteriors: + if p.numel() < target_size: + # Pad 
with zeros + pad_size = target_size - p.numel() + padded.append(torch.cat([p, torch.zeros(pad_size)])) + elif p.numel() > target_size: + # Truncate + padded.append(p[:target_size]) + else: + padded.append(p) + return torch.stack(padded) + +# ------------------- Plotting Functions ------------------- +def plot_verification_accuracy_by_architecture(df, dataset_name): + """Plot verification accuracy by architecture type""" + plt.figure(figsize=(10, 6)) + + # Group by architecture and calculate mean verification accuracy + arch_results = df.groupby(['model', 'mode'])['ver_acc'].mean().reset_index() + + # Create bar plot + ax = sns.barplot(x='model', y='ver_acc', hue='mode', data=arch_results) + + plt.title(f'Verification Accuracy by Architecture on {dataset_name}') + plt.xlabel('Model Architecture') + plt.ylabel('Verification Accuracy') + plt.ylim(0, 1) + + # Format y-axis as percentage + plt.gca().yaxis.set_major_formatter(PercentFormatter(1.0)) + + # Add value labels on bars + for p in ax.patches: + ax.annotate(f'{p.get_height():.1%}', + (p.get_x() + p.get_width() / 2., p.get_height()), + ha='center', va='center', xytext=(0, 10), + textcoords='offset points') + + plt.legend(title='Mode') + plt.tight_layout() + plt.savefig(f'experiments/results/verification_accuracy_by_architecture.png') + plt.close() + print("Plot saved: verification_accuracy_by_architecture.png") + +def plot_target_accuracy_vs_mask_ratio(df, dataset_name): + """Plot target accuracy vs mask ratio""" + plt.figure(figsize=(10, 6)) + + # Extract mask ratio from setting and add to dataframe + df['mask_ratio'] = df['mode'].map({'inductive': 0.05, 'transductive': 0.1}) + + # Group by architecture and mask ratio + accuracy_results = df.groupby(['model', 'mask_ratio'])['target_acc'].mean().reset_index() + + # Create line plot + sns.lineplot(x='mask_ratio', y='target_acc', hue='model', + style='model', markers=True, data=accuracy_results) + + plt.title(f'Target Accuracy vs Mask Ratio on {dataset_name}') + 
def plot_verification_performance_by_setting(df, dataset_name):
    """Bar-plot mean verification accuracy per (hidden, output) configuration.

    The "<hidden>_<output>" dimension pair is parsed out of the 'setting'
    column and one bar group is drawn per configuration.

    Args:
        df: results DataFrame with 'model', 'setting' and 'ver_acc' columns.
        dataset_name: dataset label used in the plot title.
    """
    plt.figure(figsize=(12, 7))

    # Pull the "<hidden>_<output>" dimension pair out of the setting string.
    df['config'] = df['setting'].str.extract(r'(\d+_\d+)')

    # Mean verification accuracy per (architecture, configuration).
    by_config = df.groupby(['model', 'config'])['ver_acc'].mean().reset_index()

    sns.barplot(x='config', y='ver_acc', hue='model', data=by_config)

    plt.title(f'Verification Performance by Model Configuration on {dataset_name}')
    plt.xlabel('Model Configuration (Hidden_Output dimensions)')
    plt.ylabel('Verification Accuracy')
    plt.ylim(0, 1)
    # Accuracies are in 0..1; render the axis as percentages.
    plt.gca().yaxis.set_major_formatter(PercentFormatter(1.0))
    plt.xticks(rotation=45)  # tilt config labels for readability

    plt.legend(title='Architecture')
    plt.tight_layout()
    plt.savefig('experiments/results/verification_performance_by_setting.png')
    plt.close()
    print("Plot saved: verification_performance_by_setting.png")


def plot_false_rates(df, dataset_name):
    """Bar-plot mean false-positive and false-negative rates per architecture.

    Args:
        df: results DataFrame with 'model', 'fpr' and 'fnr' columns.
        dataset_name: dataset label used in the plot title.
    """
    plt.figure(figsize=(10, 6))

    # Mean FPR/FNR per architecture, melted to long form for seaborn.
    mean_rates = df.groupby('model')[['fpr', 'fnr']].mean().reset_index()
    long_rates = mean_rates.melt(id_vars='model',
                                 value_vars=['fpr', 'fnr'],
                                 var_name='rate_type',
                                 value_name='rate')

    axes = sns.barplot(x='model', y='rate', hue='rate_type',
                       data=long_rates)

    plt.title(f'False Positive/Negative Rates by Architecture on {dataset_name}')
    plt.xlabel('Model Architecture')
    plt.ylabel('Rate')
    plt.ylim(0, 0.5)  # Assuming rates are between 0-0.5
    plt.gca().yaxis.set_major_formatter(PercentFormatter(1.0))

    # Label each bar at its mid-height, bold white for contrast.
    for bar in axes.patches:
        midpoint = (bar.get_x() + bar.get_width() / 2., bar.get_height() / 2)
        axes.annotate(f'{bar.get_height():.1%}', midpoint,
                      ha='center', va='center', xytext=(0, 0),
                      textcoords='offset points', color='white', weight='bold')

    plt.legend(title='Rate Type')
    plt.tight_layout()
    plt.savefig('experiments/results/false_rates.png')
    plt.close()
    print("Plot saved: false_rates.png")
def plot_all_results(df, dataset_name):
    """Generate every summary plot for the verification results.

    Args:
        df: results DataFrame (one row per evaluated target model).
        dataset_name: dataset label used in the plot titles.
    """
    if df.empty:
        print("No data to plot")
        return

    print("\nGenerating all plots...")

    # Create all plots
    plot_verification_accuracy_by_architecture(df, dataset_name)
    plot_target_accuracy_vs_mask_ratio(df, dataset_name)
    plot_verification_performance_by_setting(df, dataset_name)
    plot_false_rates(df, dataset_name)
    plot_comparative_performance(df, dataset_name)

    print("All plots generated successfully!")

def main():
    """Run the ownership-verification evaluation over all saved checkpoints.

    For every (mode, architecture, target checkpoint) combination: load the
    target model, mask node features, collect posteriors from independent
    models, train a binary verification classifier, and record target /
    verification accuracy plus false rates. Results are written to
    experiments/results/results.csv and all plots are generated.
    """
    global_cfg = load_config('config/global_cfg.yaml')
    dataset_name = global_cfg['dataset']
    architectures = ['gcn', 'gat', 'sage']
    # Feature-mask ratios that were used when the mask models were produced;
    # also used below to locate the matching checkpoint directories.
    mask_ratios = {'inductive': 0.05, 'transductive': 0.1}

    dataset = Planetoid(root='./data', name=dataset_name)
    data = dataset[0]
    num_classes = dataset.num_classes
    num_features = dataset.num_features

    results = []

    for mode in ['inductive', 'transductive']:
        mask_ratio = mask_ratios[mode]
        masks = {'train': data.train_mask, 'val': data.val_mask, 'test': data.test_mask}

        # Lightweight stand-in for the CLI namespace consumed by
        # simple_mask_graph_data; captures the current mode/ratio.
        class MaskArgs:
            mask_node_ratio = 0
            mask_feat_ratio = mask_ratio
            mask_feat_type = 'random_mask'
            mask_method = 'fix'
            mask_node_type = 'overall'
            feature_random_seed = 42  # fixed seed -> reproducible masking
            task_type = mode

        for arch in architectures:
            # Get all target model files for the current architecture
            target_dir = f"temp_results/diff/model_states/{dataset_name}/{mode}/mask_models/random_mask/1.0_{mask_ratio}"
            if not os.path.exists(target_dir):
                print(f"Directory not found: {target_dir}")
                continue

            target_files = [f for f in os.listdir(target_dir) if f.startswith(arch) and f.endswith('.pt')]

            for target_file in target_files:
                target_path = os.path.join(target_dir, target_file)
                print(f"\nProcessing: {target_path}")

                # Create flexible model
                target_model = create_flexible_model(arch, num_features, num_classes, target_file)

                if not safe_load_state(target_model, target_path):
                    print(f"Skipping {arch} {mode} {target_file}: Failed to load")
                    continue

                # Use simple masking
                masked_data = simple_mask_graph_data(MaskArgs(), data)
                masked_x = masked_data.x

                target_acc = evaluate_gnn(target_model, masked_x, data.edge_index, masks['test'], data)
                print(f"Target accuracy: {target_acc:.4f}")

                # Posterior vectors and their labels (0 = independent model;
                # surrogates, which would get label 1, are disabled below).
                posteriors = []
                labels = []

                # Independent models
                ind_dir = f"temp_results/diff/model_states/{dataset_name}/{mode}/independent_models"
                if os.path.exists(ind_dir):
                    for fname in os.listdir(ind_dir):
                        if arch in fname and fname.endswith('.pt'):
                            file_arch = parse_architecture_from_filename(fname)
                            ind_model = create_flexible_model(file_arch, num_features, num_classes, fname)
                            if safe_load_state(ind_model, os.path.join(ind_dir, fname)):
                                post = get_posteriors(ind_model, masked_x, data.edge_index, masks['train'])
                                if post.numel() > 0:
                                    posteriors.append(post.flatten())
                                    labels.append(0)
                                    print(f"Loaded independent model: {fname}")

                # Surrogate models - skip for now due to import issues
                # hidden_dims = parse_dims_from_filename(target_file)
                # if len(hidden_dims) >= 2:
                #     surr_dir = f"temp_results/diff/model_states/{dataset_name}/{mode}/extraction_models/random_mask/{arch}_{hidden_dims[0]}_{hidden_dims[-1]}/1.0_{mask_ratio}"
                #     if os.path.exists(surr_dir):
                #         for fname in os.listdir(surr_dir):
                #             if fname.endswith('.pt'):
                #                 file_arch = parse_architecture_from_filename(fname)
                #                 surr_model = create_flexible_model(file_arch, num_features, num_classes, fname)
                #                 if safe_load_state(surr_model, os.path.join(surr_dir, fname)):
                #                     post = get_posteriors(surr_model, masked_x, data.edge_index, masks['train'])
                #                     if post.numel() > 0:
                #                         posteriors.append(post.flatten())
                #                         labels.append(1)
                #                         print(f"Loaded surrogate model: {fname}")

                if len(posteriors) < 2:
                    print(f"Skipping {arch} {mode} {target_file}: Insufficient models ({len(posteriors)})")
                    continue

                # Align all posteriors to the SMALLEST length (pad_posteriors
                # truncates longer vectors to target_size) so they can stack.
                min_size = min(p.numel() for p in posteriors)
                X = pad_posteriors(posteriors, min_size)
                y = torch.tensor(labels, dtype=torch.long)

                if X.shape[0] != y.shape[0]:
                    print(f"Shape mismatch: X={X.shape}, y={y.shape}")
                    continue

                try:
                    # NOTE(review): the classifier is evaluated on its own
                    # training set (X, y) - no held-out split; verify intent.
                    classifier = train_classifier(X, y, X.shape[1], hidden_layers=[64], epochs=100, lr=0.01)
                    ver_acc, fpr, fnr = evaluate_classifier(classifier, X, y)

                    results.append({
                        'dataset': dataset_name,
                        'mode': mode,
                        'model': arch,
                        'setting': target_file,
                        'target_acc': target_acc,
                        'ver_acc': ver_acc,
                        'fpr': fpr,
                        'fnr': fnr
                    })
                    print(f"Result: {arch}, {mode}, {target_file}, Target Acc: {target_acc:.4f}, Ver Acc: {ver_acc:.4f}")
                except Exception as e:
                    print(f"Error training classifier: {e}")

    # Save results and generate plots
    if results:
        df = pd.DataFrame(results)
        os.makedirs('experiments/results', exist_ok=True)
        df.to_csv('experiments/results/results.csv', index=False)
        print(f"Results saved to experiments/results/results.csv")

        # Generate all comprehensive plots
        plot_all_results(df, dataset_name)
    else:
        print("No results to save or plot")

if __name__ == '__main__':
    main()
file mode 100644 index 0000000..b96ae60 Binary files /dev/null and b/experiments/results/comparative_performance.png differ diff --git a/experiments/results/false_rates.png b/experiments/results/false_rates.png new file mode 100644 index 0000000..c088d17 Binary files /dev/null and b/experiments/results/false_rates.png differ diff --git a/experiments/results/results.csv b/experiments/results/results.csv new file mode 100644 index 0000000..df608a3 --- /dev/null +++ b/experiments/results/results.csv @@ -0,0 +1,19 @@ +dataset,mode,model,setting,target_acc,ver_acc,fpr,fnr +Cora,inductive,gcn,gcn_352_128.pt,0.256,1.0,0.0,0.0 +Cora,inductive,gcn,gcn_288_128.pt,0.101,1.0,0.0,0.0 +Cora,inductive,gcn,gcn_224_128.pt,0.236,1.0,0.0,0.0 +Cora,inductive,gat,gat_288_128.pt,0.103,1.0,0.0,0.0 +Cora,inductive,gat,gat_352_128.pt,0.164,1.0,0.0,0.0 +Cora,inductive,gat,gat_224_128.pt,0.102,1.0,0.0,0.0 +Cora,inductive,sage,sage_224_128.pt,0.079,1.0,0.0,0.0 +Cora,inductive,sage,sage_288_128.pt,0.089,1.0,0.0,0.0 +Cora,inductive,sage,sage_352_128.pt,0.132,1.0,0.0,0.0 +Cora,transductive,gcn,gcn_352_128.pt,0.32,1.0,0.0,0.0 +Cora,transductive,gcn,gcn_288_128.pt,0.298,1.0,0.0,0.0 +Cora,transductive,gcn,gcn_224_128.pt,0.211,1.0,0.0,0.0 +Cora,transductive,gat,gat_288_128.pt,0.101,1.0,0.0,0.0 +Cora,transductive,gat,gat_352_128.pt,0.146,1.0,0.0,0.0 +Cora,transductive,gat,gat_224_128.pt,0.257,1.0,0.0,0.0 +Cora,transductive,sage,sage_224_128.pt,0.071,1.0,0.0,0.0 +Cora,transductive,sage,sage_288_128.pt,0.07,1.0,0.0,0.0 +Cora,transductive,sage,sage_352_128.pt,0.091,1.0,0.0,0.0 diff --git a/experiments/results/target_accuracy_vs_mask_ratio.png b/experiments/results/target_accuracy_vs_mask_ratio.png new file mode 100644 index 0000000..775dfe0 Binary files /dev/null and b/experiments/results/target_accuracy_vs_mask_ratio.png differ diff --git a/experiments/results/ver_acc_gat_inductive.png b/experiments/results/ver_acc_gat_inductive.png new file mode 100644 index 0000000..09fee4a Binary files /dev/null and 
b/experiments/results/ver_acc_gat_inductive.png differ diff --git a/experiments/results/ver_acc_gat_transductive.png b/experiments/results/ver_acc_gat_transductive.png new file mode 100644 index 0000000..c7a909e Binary files /dev/null and b/experiments/results/ver_acc_gat_transductive.png differ diff --git a/experiments/results/ver_acc_gcn_inductive.png b/experiments/results/ver_acc_gcn_inductive.png new file mode 100644 index 0000000..680dbad Binary files /dev/null and b/experiments/results/ver_acc_gcn_inductive.png differ diff --git a/experiments/results/ver_acc_gcn_transductive.png b/experiments/results/ver_acc_gcn_transductive.png new file mode 100644 index 0000000..542b1fd Binary files /dev/null and b/experiments/results/ver_acc_gcn_transductive.png differ diff --git a/experiments/results/ver_acc_sage_inductive.png b/experiments/results/ver_acc_sage_inductive.png new file mode 100644 index 0000000..0b9101a Binary files /dev/null and b/experiments/results/ver_acc_sage_inductive.png differ diff --git a/experiments/results/ver_acc_sage_transductive.png b/experiments/results/ver_acc_sage_transductive.png new file mode 100644 index 0000000..bc1b1aa Binary files /dev/null and b/experiments/results/ver_acc_sage_transductive.png differ diff --git a/experiments/results/verification_accuracy_by_architecture.png b/experiments/results/verification_accuracy_by_architecture.png new file mode 100644 index 0000000..2fa4239 Binary files /dev/null and b/experiments/results/verification_accuracy_by_architecture.png differ diff --git a/experiments/results/verification_performance_by_setting.png b/experiments/results/verification_performance_by_setting.png new file mode 100644 index 0000000..be9b6bc Binary files /dev/null and b/experiments/results/verification_performance_by_setting.png differ diff --git a/experiments/scripts/__init__.py b/experiments/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/scripts/run_benign.py 
import torch
import copy
from src.utils.config import parse_args
import src.datasets.datareader
import src.models.gnn
import torch.nn.functional as F
from src.datasets.graph_operator import split_subgraph
from pathlib import Path
import os
import random


def transductive_train(args, model_save_path, graph_data, process):
    """Train (or resume) a benign GNN in the transductive setting.

    Args:
        args: parsed namespace; reads benign_model, benign_hidden_dim,
            benign_lr and benign_train_epochs.
        model_save_path: checkpoint path; an existing file is loaded and
            training continues from it.
        graph_data: pre-built GraphData, or None to load it from disk.
        process: when 'test', trains on a random node subset of the same
            size as the target node set instead of the target nodes.

    Returns:
        (graph_data, trained_model, test_accuracy_percent) where accuracy
        is rounded to 3 decimals.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build the dataset unless the caller already supplied one.
    if graph_data is None:
        data = src.datasets.datareader.get_data(args)
        gdata = src.datasets.datareader.GraphData(data, args)
    else:
        gdata = graph_data

    path = Path(model_save_path)
    os.makedirs(path.parent, exist_ok=True)

    loss_fn = torch.nn.CrossEntropyLoss()
    # Index of the max logit per node, shape (N, 1).
    predict_fn = lambda output: output.max(1, keepdim=True)[1]

    # Load existing model or create new
    # NOTE(review): torch.load on a fully pickled model executes arbitrary
    # code from the checkpoint; only load trusted files.
    if path.is_file():
        gnn_model = torch.load(model_save_path)
    else:
        if args.benign_model == 'gcn':
            gnn_model = src.models.gnn.GCN(gdata.feat_dim, gdata.class_num, hidden_dim=args.benign_hidden_dim)
        elif args.benign_model == 'sage':
            gnn_model = src.models.gnn.GraphSage(gdata.feat_dim, gdata.class_num, hidden_dim=args.benign_hidden_dim)
        elif args.benign_model == 'gat':
            gnn_model = src.models.gnn.GAT(gdata.feat_dim, gdata.class_num, hidden_dim=args.benign_hidden_dim)
        elif args.benign_model == 'gin':
            gnn_model = src.models.gnn.GIN(gdata.feat_dim, gdata.class_num, hidden_dim=args.benign_hidden_dim)
        elif args.benign_model == 'sgc':
            gnn_model = src.models.gnn.SGC(gdata.feat_dim, gdata.class_num, hidden_dim=args.benign_hidden_dim)

    gnn_model.to(device)
    optimizer = torch.optim.Adam(gnn_model.parameters(), lr=args.benign_lr)
    last_train_acc = 0.0

    # In 'test' mode, train on a random subset the same size as the target set.
    if process == 'test':
        train_nodes_index = [i for i in range(gdata.node_num)]
        random.shuffle(train_nodes_index)
        train_nodes_index = train_nodes_index[:len(gdata.target_nodes_index)]

    for epoch in range(args.benign_train_epochs):
        gnn_model.train()
        optimizer.zero_grad()
        input_data = gdata.features.to(device), gdata.adjacency.to(device)
        labels = gdata.labels.to(device)
        _, output = gnn_model(input_data)
        loss = loss_fn(output[train_nodes_index] if process=='test' else output[gdata.target_nodes_index],
                       labels[train_nodes_index] if process=='test' else labels[gdata.target_nodes_index])
        loss.backward()
        optimizer.step()

        # Early stopping check every 50 epochs: stop once the relative
        # change in training accuracy drops to 0.5% or less.
        if (epoch + 1) % 50 == 0:
            gnn_model.eval()
            _, output = gnn_model(input_data)
            pred = predict_fn(output)
            train_pred = pred[train_nodes_index] if process=='test' else pred[gdata.target_nodes_index]
            train_labels = labels[train_nodes_index] if process=='test' else labels[gdata.target_nodes_index]
            correct = (train_pred.squeeze() == train_labels).sum().item()
            train_acc = correct / train_pred.shape[0] * 100

            if last_train_acc == 0.0:
                last_train_acc = train_acc
            else:
                if abs(train_acc - last_train_acc) / last_train_acc * 100 <= 0.5:
                    break
                last_train_acc = train_acc

    torch.save(gnn_model, model_save_path)

    # Test accuracy
    gnn_model.eval()
    input_data = gdata.features.to(device), gdata.adjacency.to(device)
    _, output = gnn_model(input_data)
    pred = predict_fn(output)
    test_pred = pred[gdata.test_nodes_index]
    # NOTE(review): gdata.labels here is the un-moved tensor; if CUDA is in
    # use this compares a GPU tensor against a CPU one - confirm devices.
    test_labels = gdata.labels[gdata.test_nodes_index]
    test_acc = (test_pred.squeeze() == test_labels).sum().item() / test_pred.shape[0] * 100

    return gdata, gnn_model, round(test_acc, 3)


def inductive_train(args, model_save_path, graph_data, process):
    """Train (or resume) a benign GNN in the inductive setting.

    Args:
        args: parsed namespace; reads benign_model, benign_hidden_dim,
            benign_lr and benign_train_epochs.
        model_save_path: checkpoint path; an existing file is loaded and
            training continues from it.
        graph_data: 4-tuple of (target, shadow, attacker, test) subgraphs,
            or None to load and split the dataset from disk.
        process: unused here; kept for signature parity with
            transductive_train.

    Returns:
        ([target, shadow, attacker, test] subgraphs, trained_model,
        test_accuracy_percent rounded to 3 decimals).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build and split the dataset unless the caller already supplied it.
    if graph_data is None:
        data = src.datasets.datareader.get_data(args)
        gdata = src.datasets.datareader.GraphData(data, args)
        target_graph_data, shadow_graph_data, attacker_graph_data, test_graph_data = split_subgraph(gdata)
    else:
        target_graph_data, shadow_graph_data, attacker_graph_data, test_graph_data = graph_data

    path = Path(model_save_path)
    os.makedirs(path.parent, exist_ok=True)
    loss_fn = torch.nn.CrossEntropyLoss()
    # Index of the max logit per node, shape (N, 1).
    predict_fn = lambda output: output.max(1, keepdim=True)[1]

    # NOTE(review): torch.load on a fully pickled model executes arbitrary
    # code from the checkpoint; only load trusted files.
    if path.is_file():
        gnn_model = torch.load(model_save_path)
    else:
        if args.benign_model == 'gcn':
            gnn_model = src.models.gnn.GCN(target_graph_data.feat_dim, target_graph_data.class_num, hidden_dim=args.benign_hidden_dim)
        elif args.benign_model == 'sage':
            gnn_model = src.models.gnn.GraphSage(target_graph_data.feat_dim, target_graph_data.class_num, hidden_dim=args.benign_hidden_dim)
        elif args.benign_model == 'gat':
            gnn_model = src.models.gnn.GAT(target_graph_data.feat_dim, target_graph_data.class_num, hidden_dim=args.benign_hidden_dim)
        elif args.benign_model == 'gin':
            gnn_model = src.models.gnn.GIN(target_graph_data.feat_dim, target_graph_data.class_num, hidden_dim=args.benign_hidden_dim)
        elif args.benign_model == 'sgc':
            gnn_model = src.models.gnn.SGC(target_graph_data.feat_dim, target_graph_data.class_num, hidden_dim=args.benign_hidden_dim)

    gnn_model.to(device)
    optimizer = torch.optim.Adam(gnn_model.parameters(), lr=args.benign_lr)
    last_train_acc = 0.0

    for epoch in range(args.benign_train_epochs):
        gnn_model.train()
        optimizer.zero_grad()
        input_data = target_graph_data.features.to(device), target_graph_data.adjacency.to(device)
        labels = target_graph_data.labels.to(device)
        _, output = gnn_model(input_data)
        loss = loss_fn(output, labels)
        loss.backward()
        optimizer.step()

        # Early stopping check every 50 epochs: stop once the relative
        # change in training accuracy drops to 0.5% or less.
        if (epoch + 1) % 50 == 0:
            gnn_model.eval()
            _, output = gnn_model(input_data)
            pred = predict_fn(output)
            correct = (pred.squeeze() == labels).sum().item()
            train_acc = correct / labels.shape[0] * 100

            if last_train_acc == 0.0:
                last_train_acc = train_acc
            else:
                if abs(train_acc - last_train_acc) / last_train_acc * 100 <= 0.5:
                    break
                last_train_acc = train_acc

    torch.save(gnn_model, model_save_path)

    # Evaluate on the held-out test subgraph.
    gnn_model.eval()
    input_data = test_graph_data.features.to(device), test_graph_data.adjacency.to(device)
    _, output = gnn_model(input_data)
    pred = predict_fn(output)
    # NOTE(review): test_graph_data.labels is not moved to `device`; if CUDA
    # is in use this compares across devices - confirm.
    correct = (pred.squeeze() == test_graph_data.labels).sum().item()
    test_acc = correct / test_graph_data.labels.shape[0] * 100

    return [target_graph_data, shadow_graph_data, attacker_graph_data, test_graph_data], gnn_model, round(test_acc, 3)


def run(args, model_save_path, given_graph_data=None, process=None):
    """Dispatch to transductive or inductive training based on args.task_type."""
    if args.task_type == 'transductive':
        return transductive_train(args, model_save_path, given_graph_data, process)
    elif args.task_type == 'inductive':
        return inductive_train(args, model_save_path, given_graph_data, process)
import sys
import os
import yaml

# Make the src/ package importable when this script is run directly.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../src")))

from utils.config import parse_args
from verifier.verification_cfg import multiple_experiments

if __name__ == '__main__':
    cli_args = parse_args()

    # Resolve the config directory relative to this script's location.
    cfg_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../config"))

    # Load the global experiment configuration, failing fast if absent.
    cfg_file = os.path.join(cfg_dir, "global_cfg.yaml")
    if not os.path.exists(cfg_file):
        raise FileNotFoundError(f"Global config file not found: {cfg_file}")

    with open(cfg_file, 'r') as handle:
        global_cfg = yaml.safe_load(handle)

    # Hand the absolute config path through so nested configs resolve.
    multiple_experiments(cli_args, global_cfg, config_path=cfg_dir)
list() + save_root = '../robustness_results/fine_tune' + for i in substring_path: + save_root = os.path.join(save_root, i) + with os.scandir(load_root) as itr_0: + for target_model_folder in itr_0: + sub_load_root = os.path.join(load_root, target_model_folder.name) + sub_save_root = os.path.join(save_root, target_model_folder.name) + with os.scandir(sub_load_root) as itr_1: + for mask_mag in itr_1: + if mask_mag.name != specific_mask_mag: + continue + final_load_root = os.path.join(sub_load_root, mask_mag.name) + final_save_root = os.path.join(sub_save_root, mask_mag.name) + load_folder_root.append(final_load_root) + save_folder_root.append(final_save_root) + + if not os.path.exists(final_save_root): + os.makedirs(final_save_root) + + data = src.datasets.datareader.get_data(args) + graph_data = src.datasets.datareader.GraphData(data, args) + if args.task_type == 'inductive': + _, _, graph_data, _ = src.datasets.graph_operator.split_subgraph(graph_data) + loss_fn = torch.nn.CrossEntropyLoss() + predict_fn = lambda output: output.max(1, keepdim=True)[1] + + for folder_index in range(len(load_folder_root)): + models_folder_path = load_folder_root[folder_index] + with os.scandir(models_folder_path) as itr: + for entry in itr: + if 'train' in entry.name: + continue + + original_model_load_path = os.path.join(models_folder_path, entry.name) + fine_tune_model_save_path = os.path.join(save_folder_root[folder_index], entry.name) + + gnn_model = torch.load(original_model_load_path) + gnn_model.to(device) + gnn_model.train() + optimizer = torch.optim.Adam(gnn_model.parameters(), lr=args.benign_lr) + + # use testing dataset to fine-tune extraction models + last_train_acc = 0.0 + if args.task_type == 'transductive': + for epoch in range(args.benign_train_epochs): + optimizer.zero_grad() + input_data = graph_data.features.to(device), graph_data.adjacency.to(device) + labels = graph_data.labels.to(device) + _, output = gnn_model(input_data) + loss = 
def prune(args, load_root, specific_mask_mag):
    """Randomly zero out a fraction of each eligible weight tensor.

    Walks every extraction-model checkpoint under `load_root`, zeroes
    args.prune_weight_ratio of the entries in every parameter whose name
    contains neither 'fc' nor 'bias', and saves the pruned model under
    ../robustness_results/prune/<ratio>/... mirroring the source layout.

    Args:
        args: namespace providing prune_weight_ratio in [0, 1].
        load_root: root directory holding the extraction model folders.
        specific_mask_mag: only subfolders with exactly this name are pruned.
    """
    # Mirror the load path under the prune results root, dropping the
    # '..', '' and 'temp_results' components of the source path.
    relative_parts = load_root.split('/')
    relative_parts.remove('..')
    relative_parts.remove('')
    relative_parts.remove('temp_results')

    save_root = os.path.join('../robustness_results/prune', str(args.prune_weight_ratio))
    for part in relative_parts:
        save_root = os.path.join(save_root, part)

    # Collect the (load, save) directory pairs for the requested magnitude.
    load_dirs, save_dirs = [], []
    with os.scandir(load_root) as model_folders:
        for model_folder in model_folders:
            sub_load = os.path.join(load_root, model_folder.name)
            sub_save = os.path.join(save_root, model_folder.name)
            with os.scandir(sub_load) as magnitudes:
                for magnitude in magnitudes:
                    if magnitude.name != specific_mask_mag:
                        continue
                    load_dirs.append(os.path.join(sub_load, magnitude.name))
                    save_dirs.append(os.path.join(sub_save, magnitude.name))

                    if not os.path.exists(save_dirs[-1]):
                        os.makedirs(save_dirs[-1])

    for load_dir, save_dir in zip(load_dirs, save_dirs):
        with os.scandir(load_dir) as entries:
            for entry in entries:
                # Skip training artifacts; only model files are pruned.
                if 'train' in entry.name:
                    continue

                gnn_model = torch.load(os.path.join(load_dir, entry.name))

                for name, param in gnn_model.named_parameters():
                    # Leave the classifier head and all biases untouched.
                    if 'fc' in name or 'bias' in name:
                        continue

                    shape = param.data.shape
                    flat = torch.flatten(param.data)
                    count = math.floor(flat.shape[0] * args.prune_weight_ratio)
                    # Pick `count` random positions to zero.
                    indices = list(range(flat.shape[0]))
                    random.shuffle(indices)
                    flat[indices[:count]] = 0
                    param.data = flat.reshape(shape)

                torch.save(gnn_model, os.path.join(save_dir, entry.name))
if __name__ == '__main__':
    args = parse_args()

    # Get absolute path to the config folder
    config_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../config"))

    # Build the path to global_cfg.yaml
    global_cfg_file = os.path.join(config_path, "global_cfg.yaml")

    if not os.path.exists(global_cfg_file):
        raise FileNotFoundError(f"Global config file not found: {global_cfg_file}")

    with open(global_cfg_file, "r") as file:
        global_cfg = yaml.safe_load(file)

    # ---- transductive robustness experiments ----
    # NOTE: the fine_tune and double_extraction phases are intentionally
    # commented out; only pruning currently runs end-to-end.
    args.task_type = 'transductive'
    args.mask_feat_ratio = 0.1
    path = '../temp_results/diff/model_states/{}/transductive/extraction_models/random_mask/'.format(global_cfg['dataset'])
    transductive_mask_mag = '1.0_{}'.format(args.mask_feat_ratio)
    # fine_tune(args, path, transductive_mask_mag)

    global_cfg["test_save_root"] = "../robustness_results/fine_tune/diff/model_states/"
    global_cfg["res_path"] = "../robustness_results/res/fine_tune"
    # multiple_experiments(args, global_cfg)

    # Prune at each ratio, then verify the pruned models.
    for prune_ratio in [0.6, 0.7]:
        args.prune_weight_ratio = prune_ratio
        prune(args, path, transductive_mask_mag)
        global_cfg["test_save_root"] = "../robustness_results/prune/{}/diff/model_states/".format(prune_ratio)
        global_cfg["res_path"] = "../robustness_results/res/prune/{}".format(prune_ratio)
        multiple_experiments(args, global_cfg)

    # double_extraction(args, path, transductive_mask_mag)
    global_cfg["test_save_root"] = "../robustness_results/double_extraction/diff/model_states/"
    global_cfg["res_path"] = "../robustness_results/res/double_extraction"
    # multiple_experiments(args, global_cfg)

    # ---- inductive robustness experiments (same phases, smaller ratio) ----
    args.task_type = 'inductive'
    args.mask_feat_ratio = 0.05
    path = '../temp_results/diff/model_states/{}/inductive/extraction_models/random_mask/'.format(global_cfg['dataset'])
    inductive_mask_mag = '1.0_{}'.format(args.mask_feat_ratio)
    # fine_tune(args, path, inductive_mask_mag)

    global_cfg["test_save_root"] = "../robustness_results/fine_tune/diff/model_states/"
    global_cfg["res_path"] = "../robustness_results/res/fine_tune"
    # multiple_experiments(args, global_cfg)

    # Prune at each ratio, then verify the pruned models.
    for prune_ratio in [0.6, 0.7]:
        args.prune_weight_ratio = prune_ratio
        prune(args, path, inductive_mask_mag)
        global_cfg["test_save_root"] = "../robustness_results/prune/{}/diff/model_states/".format(prune_ratio)
        global_cfg["res_path"] = "../robustness_results/res/prune/{}".format(prune_ratio)
        multiple_experiments(args, global_cfg)

    # double_extraction(args, path, inductive_mask_mag)
    global_cfg["test_save_root"] = "../robustness_results/double_extraction/diff/model_states/"
    global_cfg["res_path"] = "../robustness_results/res/double_extraction"
    # multiple_experiments(args, global_cfg)
+jupyter_server_terminals==0.5.3 +jupyterlab==4.4.7 +jupyterlab_pygments==0.3.0 +jupyterlab_server==2.27.3 +jupyterlab_widgets==3.0.15 +kiwisolver==1.4.9 +lark==1.2.2 +MarkupSafe==3.0.2 +matplotlib==3.10.6 +matplotlib-inline==0.1.7 +mistune==3.1.4 +mpmath==1.3.0 +multidict==6.6.4 +nbclient==0.10.2 +nbconvert==7.16.6 +nbformat==5.10.4 +nest-asyncio==1.6.0 +networkx==3.5 +notebook==7.4.5 +notebook_shim==0.2.4 +numpy==2.3.2 +nvidia-cublas-cu12==12.4.5.8 +nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.2.1.3 +nvidia-cufile-cu12==1.13.1.3 +nvidia-curand-cu12==10.3.5.147 +nvidia-cusolver-cu12==11.6.1.9 +nvidia-cusparse-cu12==12.3.1.170 +nvidia-cusparselt-cu12==0.7.1 +nvidia-nccl-cu12==2.21.5 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.4.127 +packaging==25.0 +pandas==2.3.2 +pandocfilters==1.5.1 +parso==0.8.5 +pexpect==4.9.0 +pillow==11.3.0 +platformdirs==4.4.0 +prometheus_client==0.22.1 +prompt_toolkit==3.0.52 +propcache==0.3.2 +psutil==7.0.0 +ptyprocess==0.7.0 +pure_eval==0.2.3 +pycparser==2.22 +Pygments==2.19.2 +pyparsing==3.2.3 +python-dateutil==2.9.0.post0 +python-json-logger==3.3.0 +pytz==2025.2 +PyYAML==6.0.2 +pyzmq==27.0.2 +referencing==0.36.2 +requests==2.32.5 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +rfc3987-syntax==1.1.0 +rpds-py==0.27.1 +scikit-learn==1.7.1 +scipy==1.16.1 +seaborn==0.13.2 +Send2Trash==1.8.3 +setuptools==80.9.0 +six==1.17.0 +sniffio==1.3.1 +soupsieve==2.8 +stack-data==0.6.3 +sympy==1.13.1 +terminado==0.18.1 +threadpoolctl==3.6.0 +tinycss2==1.4.0 +torch==2.5.1 +torch-geometric==2.6.1 +torchaudio==2.5.1 +torchvision==0.20.1 +tornado==6.5.2 +tqdm==4.67.1 +traitlets==5.14.3 +triton==3.1.0 +types-python-dateutil==2.9.0.20250822 +typing_extensions==4.15.0 +tzdata==2025.2 +uri-template==1.3.0 +urllib3==2.5.0 +wcwidth==0.2.13 +webcolors==24.11.1 +webencodings==0.5.1 +websocket-client==1.8.0 +widgetsnbextension==4.0.14 
def get_data(args):
    """Load the processed node-classification dataset named by ``args.dataset``.

    Instantiating the torch_geometric dataset class triggers the download and
    preprocessing step when needed; the processed tensors are then read back
    directly from the ``processed/data.pt`` file it creates.

    Args:
        args: namespace with ``dataset`` (name) and ``data_path`` (root dir).

    Returns:
        The object stored in ``processed/data.pt`` (a PyG collated data blob).

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    if args.dataset in ('Citeseer', 'DBLP', 'PubMed'):
        dt.CitationFull(args.data_path, args.dataset)
        subdir = args.dataset.lower()
    elif args.dataset == 'Coauthor':
        dt.Coauthor(args.data_path, 'Physics')
        subdir = 'Physics'
    elif args.dataset == 'Amazon':
        dt.Amazon(args.data_path, 'Photo')
        subdir = 'Photo'
    elif args.dataset == 'Cora':
        dt.Planetoid(args.data_path, args.dataset)
        subdir = args.dataset
    else:
        # Previously an unknown name fell through and crashed later with an
        # UnboundLocalError on data_path; fail fast with a clear message.
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")

    data_path = os.path.join(args.data_path, subdir, 'processed', 'data.pt')
    # NOTE(review): torch.load on a pickled file executes arbitrary code if the
    # file is untrusted — acceptable here because the file is produced locally.
    data = torch.load(data_path)

    return data
set_adj_mat(self): + self.adj_matrix = torch.zeros([self.node_num, self.node_num]) + for i in range(self.node_num): + source_node = self.adjacency[0][i] + target_node = self.adjacency[1][i] + self.adj_matrix[source_node, target_node] = 1 + self.adj_matrix[target_node, source_node] = 1 + + def get_class_num(self): + labels = self.labels.tolist() + labels = set(labels) + self.class_num = len(labels) + + def split_dataset(self, args): + all_nodes_index = list(i for i in range(self.node_num)) + self.target_nodes_index, self.shadow_nodes_index, self.attacker_nodes_index, self.test_nodes_index = list(), list(), list(), list() + + each_class_nodes_index = [list() for _ in range(self.class_num)] + for i in range(self.node_num): + each_class_nodes_index[self.labels[i]].append(i) + + node_index = [i for i in range(self.node_num)] + random.seed(args.dataset_random_seed) + random.shuffle(node_index) + target_nodes_size = math.floor(self.node_num * args.split_dataset_ratio[0]) + shadow_nodes_size = math.floor(self.node_num * args.split_dataset_ratio[1]) + attacker_nodes_size = math.floor(self.node_num * args.split_dataset_ratio[2]) + test_nodes_size = self.node_num - target_nodes_size - shadow_nodes_size - attacker_nodes_size + self.target_nodes_index += node_index[:target_nodes_size] + self.shadow_nodes_index += node_index[target_nodes_size:(target_nodes_size + shadow_nodes_size)] + self.attacker_nodes_index += node_index[(target_nodes_size + shadow_nodes_size):(target_nodes_size + shadow_nodes_size + attacker_nodes_size)] + self.test_nodes_index += node_index[(target_nodes_size + shadow_nodes_size + attacker_nodes_size):] + + self.target_nodes_index.sort() + self.shadow_nodes_index.sort() + self.attacker_nodes_index.sort() + self.test_nodes_index.sort() + + def __len__(self): + return self.node_num + + def __getitem__(self, index): + return 0 + + +class VarianceData(torch.utils.data.Dataset): + def __init__(self, label0_data_list, label1_data_list): + self.label0_data_list = 
copy.deepcopy(label0_data_list) + self.label1_data_list = copy.deepcopy(label1_data_list) + self.concat_data() + label0_data_label = [0 for _ in range(self.label0_data.shape[0])] + label1_data_label = [1 for _ in range(self.label1_data.shape[0])] + self.label = label0_data_label + label1_data_label + self.label = torch.as_tensor(self.label) + + def concat_data(self): + self.label0_data = None + self.label1_data = None + for data_index in range(len(self.label0_data_list)): + if data_index == 0: + self.label0_data = self.label0_data_list[data_index] + self.label1_data = self.label1_data_list[data_index] + else: + self.label0_data = torch.cat((self.label0_data, self.label0_data_list[data_index]), 0) + self.label1_data = torch.cat((self.label1_data, self.label1_data_list[data_index]), 0) + + self.data = torch.cat((self.label0_data, self.label1_data), 0) + + def __len__(self): + return self.data.shape[0] + + def __getitem__(self, index): + return self.data[index, :], self.label[index] + + +class DistanceData(torch.utils.data.Dataset): + def __init__(self, label0_data_list, label1_data_list): + self.label0_data_list = copy.deepcopy(label0_data_list) + self.label1_data_list = copy.deepcopy(label1_data_list) + self.concat_data() + label0_data_label = [0 for _ in range(self.label0_data.shape[0])] + label1_data_label = [1 for _ in range(self.label1_data.shape[0])] + self.label = label0_data_label + label1_data_label + self.label = torch.as_tensor(self.label) + + def concat_data(self): + self.label0_data = None + self.label1_data = None + for data_index in range(len(self.label0_data_list)): + if data_index == 0: + self.label0_data = self.label0_data_list[data_index] + self.label1_data = self.label1_data_list[data_index] + else: + self.label0_data = torch.cat((self.label0_data, self.label0_data_list[data_index]), 0) + self.label1_data = torch.cat((self.label1_data, self.label1_data_list[data_index]), 0) + + self.data = torch.cat((self.label0_data, self.label1_data), 0) + + def 
class Graph_self():
    """Lightweight graph container holding its own deep copies of node
    features, an edge index, and node labels, plus a dense adjacency matrix.

    Note: unlike ``GraphData.set_adj_mat``, this adjacency matrix is built
    directed (only source -> target is set), preserved from the original.
    """

    def __init__(self, features, edge_index, labels):
        self.features = copy.deepcopy(features)
        self.adjacency = copy.deepcopy(edge_index)
        self.labels = copy.deepcopy(labels)
        self.node_num = len(self.labels)
        self.feat_dim = len(self.features[0])
        self.get_class_num()
        self.set_adj_mat()

    def get_class_num(self):
        """Count the distinct label values present in the graph."""
        labels = self.labels.tolist()
        labels = set(labels)
        self.class_num = len(labels)

    def set_adj_mat(self):
        """Build the dense (directed) adjacency matrix from the edge index.

        Fix: iterate over the number of *edges* (``len(self.adjacency[0])``),
        not the number of nodes — the original loop dropped edges whenever
        there were more edges than nodes and indexed out of range otherwise.
        """
        self.adj_matrix = torch.zeros([self.node_num, self.node_num])
        edge_num = len(self.adjacency[0])
        for edge in range(edge_num):
            source_node = self.adjacency[0][edge]
            target_node = self.adjacency[1][edge]
            self.adj_matrix[source_node, target_node] = 1

    def __len__(self):
        return self.node_num

    def __getitem__(self, index):
        return [self.features[index], self.adj_matrix[index], self.labels[index]]
list() + candidate_feat = copy.deepcopy(graph_data.features) + + for iter in range(feat_num): + + feat_fidelity = dict() + for feat_index in range(graph_data.feat_dim): + if feat_index in chosen_feat: + continue + print(feat_index) + selection_model = None + if args.benign_model == 'gcn': + selection_model = gnn.GCN(iter+1, graph_data.class_num, hidden_dim=args.benign_hidden_dim) + elif args.benign_model == 'sage': + selection_model = gnn.GraphSAGE(iter+1, graph_data.class_num, hidden_dim=args.benign_hidden_dim) + elif args.benign_model == 'gat': + selection_model = gnn.GAT(iter+1, graph_data.class_num, hidden_dim=args.benign_hidden_dim) + selection_model.to(device) + optimizer = torch.optim.Adam(selection_model.parameters(), lr=args.benign_lr, weight_decay=args.benign_weight_decay, betas=(0.5, 0.999)) + scheduler = lr_scheduler.MultiStepLR(optimizer, args.benign_lr_decay_steps, gamma=0.1) + + this_loop_feat = copy.deepcopy(chosen_feat) + this_loop_feat.append(feat_index) + selected_feat = candidate_feat[:, this_loop_feat] + this_loop_input_data = selected_feat.to(device), graph_data.adjacency.to(device) + this_loop_labels = graph_data.labels.to(device) + selection_model.train() + for epoch in range(args.benign_train_epochs): + optimizer.zero_grad() + _, output = selection_model(this_loop_input_data) + loss = loss_fn(output[graph_data.benign_train_mask], this_loop_labels[graph_data.benign_train_mask]) + loss.backward() + optimizer.step() + scheduler.step() + + selection_model.eval() + _, output = selection_model(this_loop_input_data) + pred = predict_fn(output) + final_pred = pred[graph_data.benign_train_mask] + original_pred = original_predictions[graph_data.benign_train_mask] + correct_num = 0 + for i in range(final_pred.shape[0]): + if final_pred[i, 0] == original_pred[i, 0]: + correct_num += 1 + test_acc = correct_num / final_pred.shape[0] * 100 + feat_fidelity.update({feat_index: test_acc}) + + feat_fidelity = sorted(feat_fidelity.items(), key=lambda x:x[1], 
def split_subgraph(graph):
    """Partition *graph* into four induced subgraphs (target, shadow, attacker,
    test) using the node-index splits already stored on the graph object.

    Self-loops are added first so isolated nodes keep at least one edge after
    the split; each subgraph is relabeled to a compact 0..k-1 node range.
    """
    edge_index_with_loops = add_self_loops(graph.adjacency)[0]

    parts = []
    for node_index in (graph.target_nodes_index, graph.shadow_nodes_index,
                       graph.attacker_nodes_index, graph.test_nodes_index):
        part_edge_index = subgraph(torch.as_tensor(node_index),
                                   edge_index_with_loops, relabel_nodes=True)[0]
        parts.append(Graph_self(graph.features[node_index],
                                part_edge_index,
                                graph.labels[node_index]))

    # Same (target, shadow, attacker, test) order as the original.
    return tuple(parts)
def extract_outputs(graph_data, specific_nodes, independent_model, surrogate_model):
    """Run an independent and a surrogate model on the same graph and collect
    their responses side by side.

    Both models are put in eval mode and fed the full ``(features, adjacency)``
    pair; if ``specific_nodes`` is given, every returned tensor is restricted
    to those node indices and moved back to the CPU.

    Returns:
        (probability, logits, embedding): three dicts, each keyed by
        'independent' and 'surrogate'.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    independent_model.eval()
    surrogate_model.eval()

    input_data = graph_data.features.to(device), graph_data.adjacency.to(device)
    independent_embedding, independent_logits = independent_model(input_data)
    surrogate_embedding, surrogate_logits = surrogate_model(input_data)

    softmax = torch.nn.Softmax(dim=1)
    independent_prob = softmax(independent_logits)
    surrogate_prob = softmax(surrogate_logits)

    # Fix: `is not None` instead of `!= None` — identity is the correct idiom
    # and avoids element-wise comparison surprises with array-like arguments.
    if specific_nodes is not None:
        independent_prob = independent_prob[specific_nodes].cpu()
        surrogate_prob = surrogate_prob[specific_nodes].cpu()
        independent_embedding = independent_embedding[specific_nodes].cpu()
        surrogate_embedding = surrogate_embedding[specific_nodes].cpu()
        independent_logits = independent_logits[specific_nodes].cpu()
        surrogate_logits = surrogate_logits[specific_nodes].cpu()

    probability = {'independent': independent_prob, 'surrogate': surrogate_prob}
    embedding = {'independent': independent_embedding, 'surrogate': surrogate_embedding}
    logits = {'independent': independent_logits, 'surrogate': surrogate_logits}

    return probability, logits, embedding
def evaluate_target_response(args, graph_data, model, response, process):
    """Query *model* and return the embeddings or raw outputs the extraction
    pipeline needs.

    Transductive: ``graph_data`` is one graph; the full graph is fed once and
    rows are selected by split (shadow nodes for 'train', attacker nodes for
    'test', and the test split for the 'test_*' responses).  Inductive:
    ``graph_data`` is the 4-tuple produced by ``split_subgraph`` (target,
    shadow, attacker, test) and the relevant subgraph is fed whole.

    Args:
        response: one of 'train_embeddings', 'train_outputs',
            'test_embeddings', 'test_outputs'.
        process: 'train' or 'test'.

    Raises:
        ValueError: on an unknown task type / process / response name (the
        original code fell through to an UnboundLocalError instead).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.eval()
    model = model.to(device)

    if args.task_type == 'transductive':
        input_data = graph_data.features.to(device), graph_data.adjacency.to(device)
        embedding, output = model(input_data)
        embedding = embedding.detach()
        output = output.detach()

        if process == 'train':
            search_nodes_index = graph_data.shadow_nodes_index
        elif process == 'test':
            search_nodes_index = graph_data.attacker_nodes_index
        else:
            raise ValueError(f"unknown process: {process!r}")

        if response == 'train_embeddings':
            target_response = embedding[search_nodes_index]
        elif response == 'train_outputs':
            target_response = output[search_nodes_index]
        elif response == 'test_embeddings':
            target_response = embedding[graph_data.test_nodes_index]
        elif response == 'test_outputs':
            target_response = output[graph_data.test_nodes_index]
        else:
            raise ValueError(f"unknown response: {response!r}")
    elif args.task_type == 'inductive':
        if process == 'train':
            part_graph = graph_data[1]  # shadow subgraph (split_subgraph order)
        elif process == 'test':
            part_graph = graph_data[2]  # attacker subgraph
        else:
            raise ValueError(f"unknown process: {process!r}")

        extraction_input_data = part_graph.features.to(device), part_graph.adjacency.to(device)
        extraction_embedding, extraction_output = model(extraction_input_data)
        extraction_embedding = extraction_embedding.detach()
        extraction_output = extraction_output.detach()

        test_input_data = graph_data[3].features.to(device), graph_data[3].adjacency.to(device)
        test_embedding, test_output = model(test_input_data)
        test_embedding = test_embedding.detach()
        test_output = test_output.detach()

        if response == 'train_embeddings':
            target_response = extraction_embedding
        elif response == 'train_outputs':
            target_response = extraction_output
        elif response == 'test_embeddings':
            target_response = test_embedding
        elif response == 'test_outputs':
            target_response = test_output
        else:
            raise ValueError(f"unknown response: {response!r}")
    else:
        raise ValueError(f"unknown task type: {args.task_type!r}")

    return target_response
Classification(out_dim, graph_data[0].class_num) + else: + clf = Classification(out_dim, graph_data.class_num) + clf = clf.to(device) + optimizer_classification = torch.optim.SGD(clf.parameters(), lr=args.extraction_lr) + elif args.extraction_method == 'black_box': + clf = None + predict_fn = lambda output: output.max(1, keepdim=True)[1] + + # train extraction model + last_train_acc, last_train_fide = 0.0, 0.0 + if args.task_type == 'transductive': + path = Path(model_save_path) + if path.is_file(): + extraction_model = torch.load(model_save_path) + if args.extraction_method == 'white_box': + clf = torch.load(clf_save_path) + else: + if process == 'train': + search_nodes_index = graph_data.shadow_nodes_index + elif process == 'test': + search_nodes_index = graph_data.attacker_nodes_index + + for epoch in range(args.extraction_train_epochs): + extraction_model.train() + if args.extraction_method == 'white_box': + clf.train() + train_emb = train_emb.to(device) + train_outputs = train_outputs.to(device) + input_data = graph_data.features.to(device), graph_data.adjacency.to(device) + extraction_embeddings, extraction_outputs = extraction_model(input_data) + part_embeddings = extraction_embeddings[search_nodes_index] + part_outputs = extraction_outputs[search_nodes_index] + + if args.extraction_method == 'white_box': + optimizer_medium.zero_grad() + optimizer_classification.zero_grad() + loss_emb = torch.sqrt(loss_fn(part_embeddings, train_emb)) + loss_emb.backward() + optimizer_medium.step() + + outputs = clf(part_embeddings.detach()) + train_labels = predict_fn(train_outputs) + train_labels = torch.flatten(train_labels) + loss_out = loss_fn(outputs, train_labels) + loss_out.backward() + optimizer_classification.step() + elif args.extraction_method == 'black_box': + optimizer_medium.zero_grad() + loss = loss_fn(part_outputs, train_outputs) + if process == 'test' and classifier != None: + surrogate_outputs, _, _ = extract_outputs(graph_data, 
graph_data.target_nodes_index, extraction_model, extraction_model) + classify_logits = verify(surrogate_outputs["surrogate"], classifier) + classify_logits = torch.flatten(classify_logits) + evade_loss = loss_fn(classify_logits, torch.tensor(0).to(device)) + loss += 10 * evade_loss + + loss.backward() + optimizer_medium.step() + + if (epoch + 1) % 50 == 0: + extraction_model.eval() + if args.extraction_method == 'white_box': + clf.eval() + + acc_correct = 0 + fide_correct = 0 + + embeddings, outputs = extraction_model(input_data) + if args.extraction_method == 'white_box': + outputs = clf(embeddings.detach()) + pred = predict_fn(outputs) + train_labels = predict_fn(train_outputs) + + for i in range(len(search_nodes_index)): + if pred[search_nodes_index[i]] == graph_data.labels[search_nodes_index[i]]: + acc_correct += 1 + if pred[search_nodes_index[i]] == train_labels[i]: + fide_correct += 1 + + accuracy = acc_correct * 100.0 / len(search_nodes_index) + fidelity = fide_correct * 100.0 / train_outputs.shape[0] + if last_train_acc == 0.0 or last_train_fide == 0.0: + last_train_acc = accuracy + last_train_fide = fidelity + else: + train_acc_diff = (accuracy - last_train_acc) / last_train_acc * 100 + train_fide_diff = (fidelity - last_train_fide) / last_train_fide * 100 + if train_acc_diff <= 0.5 and train_fide_diff <= 0.5: # 0.5% + break + else: + last_train_acc = accuracy + last_train_fide = fidelity + + torch.save(extraction_model, model_save_path) + if args.extraction_method == 'white_box': + torch.save(clf, clf_save_path) + + extraction_model.eval() + if args.extraction_method == 'white_box': + clf.eval() + acc_correct, fide_correct = 0, 0 + input_data = graph_data.features.to(device), graph_data.adjacency.to(device) + embeddings, outputs = extraction_model(input_data) + if args.extraction_method == 'white_box': + outputs = clf(embeddings.detach()) + pred = predict_fn(outputs) + test_labels = predict_fn(test_outputs) + for i in 
range(len(graph_data.test_nodes_index)): + if pred[graph_data.test_nodes_index[i]] == graph_data.labels[graph_data.test_nodes_index[i]]: + acc_correct += 1 + if pred[graph_data.test_nodes_index[i]] == test_labels[i]: + fide_correct += 1 + accuracy = acc_correct * 100.0 / len(graph_data.test_nodes_index) + fidelity = fide_correct * 100.0 / test_outputs.shape[0] + save_acc = round(accuracy, 3) + save_fide = round(fidelity, 3) + elif args.task_type == 'inductive': + path = Path(model_save_path) + if path.is_file(): + extraction_model = torch.load(model_save_path) + if args.extraction_method == 'white_box': + clf = torch.load(clf_save_path) + else: + if process == 'train': + using_graph_data = graph_data[1] + elif process == 'test': + using_graph_data = graph_data[2] + + for epoch in range(args.extraction_train_epochs): + extraction_model.train() + if args.extraction_method == 'white_box': + clf.train() + train_emb = train_emb.to(device) + train_outputs = train_outputs.to(device) + + input_data = using_graph_data.features.to(device), using_graph_data.adjacency.to(device) + extraction_embeddings, extraction_outputs = extraction_model(input_data) + + if args.extraction_method == 'white_box': + optimizer_medium.zero_grad() + optimizer_classification.zero_grad() + loss_emb = torch.sqrt(loss_fn(extraction_embeddings, train_emb)) + loss_emb.backward() + optimizer_medium.step() + + outputs = clf(extraction_embeddings.detach()) + train_labels = predict_fn(train_outputs) + train_labels = torch.flatten(train_labels) + loss_out = loss_fn(outputs, train_labels) + loss_out.backward() + optimizer_classification.step() + elif args.extraction_method == 'black_box': + optimizer_medium.zero_grad() + loss = loss_fn(extraction_outputs, train_outputs) + loss.backward() + optimizer_medium.step() + + if (epoch + 1) % 50 == 0: + extraction_model.eval() + if args.extraction_method == 'white_box': + clf.eval() + + acc_correct = 0 + fide_correct = 0 + + embeddings, outputs = 
extraction_model(input_data) + if args.extraction_method == 'white_box': + outputs = clf(embeddings.detach()) + pred = predict_fn(outputs) + train_labels = predict_fn(train_outputs) + + for i in range(using_graph_data.node_num): + if pred[i] == using_graph_data.labels[i]: + acc_correct += 1 + if pred[i] == train_labels[i]: + fide_correct += 1 + + accuracy = acc_correct * 100.0 / using_graph_data.node_num + fidelity = fide_correct * 100.0 / train_outputs.shape[0] + if last_train_acc == 0.0 or last_train_fide == 0.0: + last_train_acc = accuracy + last_train_fide = fidelity + else: + train_acc_diff = (accuracy - last_train_acc) / last_train_acc * 100 + train_fide_diff = (fidelity - last_train_fide) / last_train_fide * 100 + if train_acc_diff <= 0.5 and train_fide_diff <= 0.5: # 0.5% + break + else: + last_train_acc = accuracy + last_train_fide = fidelity + + torch.save(extraction_model, model_save_path) + if args.extraction_method == 'white_box': + torch.save(clf, clf_save_path) + + extraction_model.eval() + if args.extraction_method == 'white_box': + clf.eval() + acc_correct, fide_correct = 0, 0 + input_data = graph_data[3].features.to(device), graph_data[3].adjacency.to(device) + embeddings, outputs = extraction_model(input_data) + if args.extraction_method == 'white_box': + outputs = clf(embeddings.detach()) + pred = predict_fn(outputs) + test_labels = predict_fn(test_outputs) + for i in range(graph_data[3].node_num): + if pred[i] == graph_data[3].labels[i]: + acc_correct += 1 + if pred[i] == test_labels[i]: + fide_correct += 1 + accuracy = acc_correct * 100.0 / graph_data[3].node_num + fidelity = fide_correct * 100.0 / test_outputs.shape[0] + save_acc = round(accuracy, 3) + save_fide = round(fidelity, 3) + + return extraction_model, clf, save_acc, save_fide + + +def run(args, model_save_path, graph_data, original_model, process, classifier): + train_emb = evaluate_target_response(args, graph_data, original_model, 'train_embeddings', process) # we do not use this 
def mask_graph_data(args, graph_data, model):
    """Return a masked copy of *graph_data* plus the per-class lists of nodes
    whose features were perturbed.

    Nodes are chosen by ``find_mask_nodes``; feature columns are chosen either
    uniformly at random ('random_mask') or by dataset-level importance
    ('mask_by_dataset').  Each chosen (node, feature) entry is either flipped
    mod 2 ('flip', assumes binary features) or zeroed ('fix').
    """
    mask_nodes = find_mask_nodes(args, graph_data, model)
    mask_feat_num = math.floor(graph_data.x.size(1) * args.mask_feat_ratio)

    new_graph_data = copy.deepcopy(graph_data)
    # With either ratio at zero there is nothing to mask: return the plain copy.
    if args.mask_node_ratio != 0 and args.mask_feat_ratio != 0:
        if args.mask_feat_type == 'random_mask':
            feat_pool = list(range(graph_data.x.size(1)))
            random.seed(args.feature_random_seed)
            random.shuffle(feat_pool)
            mask_features = feat_pool[:mask_feat_num]
        elif args.mask_feat_type == 'mask_by_dataset':
            mask_features = find_mask_features_overall(args, graph_data, mask_feat_num)
        else:
            raise ValueError('Invalid mask method')

        for node_class in mask_nodes:
            for node_index in node_class:
                for feat_index in mask_features[:mask_feat_num]:
                    if args.mask_method == "flip":
                        new_graph_data.x[node_index][feat_index] = (
                            new_graph_data.x[node_index][feat_index] + 1
                        ) % 2
                    elif args.mask_method == "fix":
                        new_graph_data.x[node_index][feat_index] = 0

    return new_graph_data, mask_nodes
for num in each_class_num] + + for node_index in graph_data.train_mask.nonzero(as_tuple=True)[0]: + node_poss = possibility[node_index].detach() + sorted_node_poss, _ = torch.sort(node_poss, descending=True) + node_class_distance = sorted_node_poss[0] - sorted_node_poss[1] + node_possibilities[graph_data.y[node_index].item()].update({node_index.item(): node_class_distance.item()}) + + elif args.task_type == 'inductive': + each_class_num = [0 for _ in range(graph_data.y.max().item() + 1)] + for node_index in range(graph_data.num_nodes): + each_class_num[graph_data.y[node_index]] += 1 + each_class_mask_node_num = [math.floor(num * args.mask_node_ratio) for num in each_class_num] + + for node_index in range(graph_data.num_nodes): + node_poss = possibility[node_index].detach() + sorted_node_poss, _ = torch.sort(node_poss, descending=True) + node_class_distance = sorted_node_poss[0] - sorted_node_poss[1] + node_possibilities[graph_data.y[node_index].item()].update({node_index: node_class_distance.item()}) + + new_node_possibilities = [ + dict(sorted(class_node_possibility.items(), key=lambda x: x[1], reverse=False)) + for class_node_possibility in node_possibilities + ] + + topk_nodes = [list(new_node_possibilities[i].keys())[:each_class_mask_node_num[i]] for i in range(len(new_node_possibilities))] + + elif args.mask_node_type == 'overall': + mask_node_num = math.floor(graph_data.num_nodes * args.mask_node_ratio) + + node_possibilities = dict() + if args.task_type == 'transductive': + for node_index in graph_data.train_mask.nonzero(as_tuple=True)[0]: + node_poss = possibility[node_index].detach() + sorted_node_poss, _ = torch.sort(node_poss, descending=True) + node_class_distance = sorted_node_poss[0] - sorted_node_poss[1] + node_possibilities.update({node_index.item(): node_class_distance.item()}) + elif args.task_type == 'inductive': + for node_index in range(graph_data.num_nodes): + node_poss = possibility[node_index].detach() + sorted_node_poss, _ = 
torch.sort(node_poss, descending=True) + node_class_distance = sorted_node_poss[0] - sorted_node_poss[1] + node_possibilities.update({node_index: node_class_distance.item()}) + + node_possibilities = dict(sorted(node_possibilities.items(), key=lambda x: x[1], reverse=False)) + topk_nodes = [list(node_possibilities.keys())[:mask_node_num]] + + return topk_nodes + + +def find_mask_features_overall(args, graph_data, feat_num): + if args.task_type == 'transductive': + X = graph_data.x[graph_data.train_mask].cpu().numpy() + Y = graph_data.y[graph_data.train_mask].cpu().numpy() + elif args.task_type == 'inductive': + X = graph_data.x.cpu().numpy() + Y = graph_data.y.cpu().numpy() + + dt_model = RandomForestClassifier(random_state=args.feature_random_seed) + dt_model.fit(X, Y) + feat_importance = dt_model.feature_importances_ + + importance_dict = {index: value for index, value in enumerate(feat_importance)} + importance_dict = dict(sorted(importance_dict.items(), key=lambda x: x[1], reverse=True)) + topk_features = list(importance_dict.keys())[:feat_num] + + return topk_features + + +def find_mask_features_individual(args, graph_data, gnn_model): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + softmax = torch.nn.Softmax(dim=1) + gnn_model.eval() + gnn_model.to(device) + + input_data = graph_data.x.to(device), graph_data.edge_index.to(device) + _, output = gnn_model(input_data) + possibility = softmax(output).detach() + var = torch.var(possibility, axis=1) + + if args.task_type == 'transductive': + search_node_list = graph_data.train_mask.nonzero(as_tuple=True)[0].tolist() + elif args.task_type == 'inductive': + search_node_list = list(range(graph_data.num_nodes)) + + original_variances = {node_index: var[node_index] for node_index in search_node_list} + + node_feat_importance = dict() + for node_index in search_node_list: + feat_var_diff = dict() + for feat_index in range(graph_data.x.size(1)): + temp_features = copy.deepcopy(graph_data.x) + 
# --- src/src/models/extraction.py: surrogate ("extraction") GNNs. Each
# forward returns (last-hidden embedding, logits); `nn` in that file is
# torch_geometric.nn. ---

class SageExtract(torch.nn.Module):
    """GraphSAGE extraction model."""

    def __init__(self, in_dim, out_dim, hidden_dim=[64, 32]):
        super(SageExtract, self).__init__()
        dims = [in_dim] + list(hidden_dim)
        self.layers = torch.nn.ModuleList(
            nn.SAGEConv(src, dst) for src, dst in zip(dims[:-1], dims[1:])
        )
        self.fc = nn.Linear(hidden_dim[-1], out_dim)

    def forward(self, data):
        x, edge_index = data
        for conv in self.layers:
            x = F.relu(conv(x, edge_index))
        embedding = x  # pre-classifier representation
        return embedding, self.fc(x)


class GcnExtract(torch.nn.Module):
    """GCN extraction model."""

    def __init__(self, in_dim, out_dim, hidden_dim=[64, 32]):
        super(GcnExtract, self).__init__()
        dims = [in_dim] + list(hidden_dim)
        self.layers = torch.nn.ModuleList(
            nn.GCNConv(src, dst) for src, dst in zip(dims[:-1], dims[1:])
        )
        self.fc = nn.Linear(hidden_dim[-1], out_dim)

    def forward(self, data):
        x, edge_index = data
        for conv in self.layers:
            x = F.relu(conv(x, edge_index))
        embedding = x
        return embedding, self.fc(x)


class GatExtract(torch.nn.Module):
    """GAT extraction model (single-head attention per layer)."""

    def __init__(self, in_dim, out_dim, hidden_dim=[64, 32]):
        super(GatExtract, self).__init__()
        dims = [in_dim] + list(hidden_dim)
        self.layers = torch.nn.ModuleList(
            nn.GATConv(src, dst) for src, dst in zip(dims[:-1], dims[1:])
        )
        self.fc = nn.Linear(hidden_dim[-1], out_dim)

    def forward(self, data):
        x, edge_index = data
        for conv in self.layers:
            x = F.relu(conv(x, edge_index))
        embedding = x
        return embedding, self.fc(x)


class GinExtract(torch.nn.Module):
    """GIN extraction model; the non-linearity lives inside each GINConv MLP,
    so the forward pass applies no extra activation."""

    def __init__(self, in_dim, out_dim, hidden_dim=[64, 32]):
        super(GinExtract, self).__init__()

        def gin_mlp(d_in, d_out):
            # Linear -> BN -> ReLU -> Linear -> ReLU
            return Sequential(Linear(d_in, d_in), BatchNorm1d(d_in), ReLU(),
                              Linear(d_in, d_out), ReLU())

        dims = [in_dim] + list(hidden_dim)
        self.layers = torch.nn.ModuleList(
            nn.GINConv(gin_mlp(src, dst)) for src, dst in zip(dims[:-1], dims[1:])
        )
        self.fc = nn.Linear(hidden_dim[-1], out_dim)

    def forward(self, data):
        x, edge_index = data
        for conv in self.layers:
            x = conv(x, edge_index)
        embedding = x
        return embedding, self.fc(x)


class SGCExtract(torch.nn.Module):
    """SGC extraction model (K=2 hops per SGConv layer)."""

    def __init__(self, in_dim, out_dim, hidden_dim=[64, 32]):
        super(SGCExtract, self).__init__()
        dims = [in_dim] + list(hidden_dim)
        self.layers = torch.nn.ModuleList(
            nn.SGConv(src, dst, K=2) for src, dst in zip(dims[:-1], dims[1:])
        )
        self.fc = nn.Linear(hidden_dim[-1], out_dim)

    def forward(self, data):
        x, edge_index = data
        for conv in self.layers:
            x = F.relu(conv(x, edge_index))
        embedding = x
        return embedding, self.fc(x)


# --- src/src/models/gnn.py: benign GNNs (there `nn` is torch.nn and the conv
# layers are imported directly from torch_geometric.nn). ---

def get_hidden_dims(hidden_dim):
    """Normalize ``hidden_dim`` to a list of ints."""
    if isinstance(hidden_dim, int):
        return [hidden_dim]
    if isinstance(hidden_dim, list):
        return hidden_dim
    raise ValueError("hidden_dim must be int or list of ints")


class GCN(nn.Module):
    """Benign GCN; forward returns (embedding, logits)."""

    def __init__(self, in_dim, out_dim, hidden_dim=[64, 32]):
        super(GCN, self).__init__()
        hidden_dim = get_hidden_dims(hidden_dim)
        dims = [in_dim] + hidden_dim
        self.layers = nn.ModuleList(
            GCNConv(src, dst) for src, dst in zip(dims[:-1], dims[1:])
        )
        self.fc = Linear(hidden_dim[-1], out_dim)

    def forward(self, data):
        x, edge_index = data
        for conv in self.layers:
            x = F.relu(conv(x, edge_index))
        embedding = x
        return embedding, self.fc(x)
# --- src/src/models/gnn.py (continued) ---

class GAT(nn.Module):
    """Benign GAT: multi-head first layer (concat), single-head afterwards."""

    def __init__(self, in_dim, out_dim, hidden_dim=[64, 32], heads=8):
        super(GAT, self).__init__()
        hidden_dim = get_hidden_dims(hidden_dim)
        self.layers = nn.ModuleList()

        self.layers.append(GATConv(in_dim, hidden_dim[0], heads=heads, concat=True))
        for i in range(len(hidden_dim) - 1):
            # the first hidden layer emitted heads * hidden_dim[0] channels
            in_channels = hidden_dim[i] * heads if i == 0 else hidden_dim[i]
            self.layers.append(GATConv(in_channels, hidden_dim[i + 1], heads=1, concat=False))

        # NOTE(review): with a single hidden layer the concat output width is
        # hidden_dim[0]*heads, which would not match fc's in_features —
        # confirm configs always use >= 2 hidden dims.
        self.fc = Linear(hidden_dim[-1], out_dim)
        self.heads = heads

    def forward(self, data):
        x, edge_index = data
        for conv in self.layers:
            x = F.relu(conv(x, edge_index))
        embedding = x
        return embedding, self.fc(x)


class GIN(nn.Module):
    """Benign GIN; no extra activation between convs (each MLP ends in ReLU)."""

    def __init__(self, in_dim, out_dim, hidden_dim=[64, 32]):
        super(GIN, self).__init__()
        hidden_dim = get_hidden_dims(hidden_dim)

        def gin_mlp(d_in, d_out):
            return Sequential(Linear(d_in, d_in), BatchNorm1d(d_in), ReLU(),
                              Linear(d_in, d_out), ReLU())

        dims = [in_dim] + hidden_dim
        self.layers = nn.ModuleList(
            GINConv(gin_mlp(src, dst)) for src, dst in zip(dims[:-1], dims[1:])
        )
        self.fc = Linear(hidden_dim[-1], out_dim)

    def forward(self, data):
        x, edge_index = data
        for conv in self.layers:
            x = conv(x, edge_index)
        embedding = x
        return embedding, self.fc(x)


class SGC(nn.Module):
    """Benign SGC with K-hop propagation per layer."""

    def __init__(self, in_dim, out_dim, hidden_dim=[64, 32], K=2):
        super(SGC, self).__init__()
        hidden_dim = get_hidden_dims(hidden_dim)
        dims = [in_dim] + hidden_dim
        self.layers = nn.ModuleList(
            SGConv(src, dst, K=K) for src, dst in zip(dims[:-1], dims[1:])
        )
        self.fc = Linear(hidden_dim[-1], out_dim)

    def forward(self, data):
        x, edge_index = data
        for conv in self.layers:
            x = F.relu(conv(x, edge_index))
        embedding = x
        return embedding, self.fc(x)


def get_gnn_model(model_name, in_dim, out_dim, hidden_dim=[64, 32], **kwargs):
    """Factory: build a benign GNN by (case-insensitive) name.

    Extra ``kwargs`` are forwarded to GAT only, matching the original.
    """
    model_name = model_name.lower()
    builders = {
        "gcn": lambda: GCN(in_dim, out_dim, hidden_dim),
        "graphsage": lambda: GraphSage(in_dim, out_dim, hidden_dim),
        "gat": lambda: GAT(in_dim, out_dim, hidden_dim, **kwargs),
        "gin": lambda: GIN(in_dim, out_dim, hidden_dim),
        "sgc": lambda: SGC(in_dim, out_dim, hidden_dim),
    }
    if model_name not in builders:
        raise ValueError(f"Unknown GNN model: {model_name}")
    return builders[model_name]()


# --- src/src/models/gnn2.py: regularized variants (BatchNorm + dropout). ---

def get_hidden_dims(hidden_dim):
    """Normalize ``hidden_dim`` to a list of ints."""
    if isinstance(hidden_dim, int):
        return [hidden_dim]
    if isinstance(hidden_dim, list):
        return hidden_dim
    raise ValueError("hidden_dim must be int or list of ints")


def get_activation(name="relu"):
    """Instantiate an activation module by (case-insensitive) name."""
    name = name.lower()
    factories = {
        "relu": nn.ReLU,
        "leakyrelu": lambda: nn.LeakyReLU(0.2),
        "elu": nn.ELU,
        "gelu": nn.GELU,
        "sigmoid": nn.Sigmoid,
        "tanh": nn.Tanh,
    }
    if name not in factories:
        raise ValueError(f"Unsupported activation: {name}")
    return factories[name]()


class GCN(nn.Module):
    """GCN with optional BatchNorm and dropout after every conv layer."""

    def __init__(self, in_dim, out_dim, hidden_dim=[64, 32],
                 dropout=0.5, activation="relu", use_bn=True):
        super(GCN, self).__init__()
        hidden_dim = get_hidden_dims(hidden_dim)
        self.layers = nn.ModuleList()
        self.bns = nn.ModuleList() if use_bn else None
        self.act = get_activation(activation)
        self.dropout = dropout
        self.use_bn = use_bn

        dims = [in_dim] + hidden_dim
        for src, dst in zip(dims[:-1], dims[1:]):
            self.layers.append(GCNConv(src, dst))
        if use_bn:
            for width in hidden_dim:
                self.bns.append(BatchNorm1d(width))

        self.fc = Linear(hidden_dim[-1], out_dim)

    def forward(self, data):
        x, edge_index = data
        for i, conv in enumerate(self.layers):
            x = conv(x, edge_index)
            if self.use_bn:
                x = self.bns[i](x)
            x = self.act(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        embedding = x
        return embedding, self.fc(x)
# --- src/src/models/gnn2.py (continued) ---

class GAT(nn.Module):
    """GAT with optional BatchNorm/dropout after every conv layer.

    Fix: the first GATConv uses ``concat=True`` and therefore emits
    ``hidden_dim[0] * heads`` channels, but the original sized the first
    ``BatchNorm1d`` with ``hidden_dim[0]`` (and, for single-hidden-layer
    configs, the final ``Linear`` too) — crashing at the first forward pass
    whenever ``use_bn=True`` (the default). Per-layer output widths are now
    tracked explicitly; multi-layer configs keep the same fc shape as before.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=[64, 32], heads=8,
                 dropout=0.5, activation="relu", use_bn=True):
        super(GAT, self).__init__()
        hidden_dim = get_hidden_dims(hidden_dim)
        self.layers = nn.ModuleList()
        self.bns = nn.ModuleList() if use_bn else None
        self.act = get_activation(activation)
        self.dropout = dropout
        self.use_bn = use_bn
        self.heads = heads

        # actual output width of each conv layer
        out_dims = [hidden_dim[0] * heads] + hidden_dim[1:]

        self.layers.append(GATConv(in_dim, hidden_dim[0], heads=heads, concat=True))
        for i in range(len(hidden_dim) - 1):
            self.layers.append(GATConv(out_dims[i], hidden_dim[i + 1], heads=1, concat=False))
        if use_bn:
            for width in out_dims:
                self.bns.append(BatchNorm1d(width))

        # out_dims[-1] == hidden_dim[-1] for multi-layer configs (unchanged)
        self.fc = Linear(out_dims[-1], out_dim)

    def forward(self, data):
        x, edge_index = data
        for i, conv in enumerate(self.layers):
            x = conv(x, edge_index)
            if self.use_bn:
                x = self.bns[i](x)
            x = self.act(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        embedding = x
        return embedding, self.fc(x)


class GIN(nn.Module):
    """GIN variant whose per-layer MLP is Linear followed by the activation.

    Note: every GINConv MLP shares the single ``self.act`` module instance,
    as in the original. The forward pass applies no extra activation outside
    the MLPs.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=[64, 32],
                 dropout=0.5, activation="relu", use_bn=True):
        super(GIN, self).__init__()
        hidden_dim = get_hidden_dims(hidden_dim)
        self.layers = nn.ModuleList()
        self.bns = nn.ModuleList() if use_bn else None
        self.act = get_activation(activation)
        self.dropout = dropout
        self.use_bn = use_bn

        def mlp(d_in, d_out):
            # shared self.act instance across all layer MLPs
            return Sequential(Linear(d_in, d_out), self.act)

        dims = [in_dim] + hidden_dim
        for src, dst in zip(dims[:-1], dims[1:]):
            self.layers.append(GINConv(mlp(src, dst)))
        if use_bn:
            for width in hidden_dim:
                self.bns.append(BatchNorm1d(width))

        self.fc = Linear(hidden_dim[-1], out_dim)

    def forward(self, data):
        x, edge_index = data
        for i, conv in enumerate(self.layers):
            x = conv(x, edge_index)
            if self.use_bn:
                x = self.bns[i](x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        embedding = x
        return embedding, self.fc(x)


class SGC(nn.Module):
    """SGC with optional BatchNorm/dropout; K-hop propagation per layer."""

    def __init__(self, in_dim, out_dim, hidden_dim=[64, 32], K=2,
                 dropout=0.5, activation="relu", use_bn=True):
        super(SGC, self).__init__()
        hidden_dim = get_hidden_dims(hidden_dim)
        self.layers = nn.ModuleList()
        self.bns = nn.ModuleList() if use_bn else None
        self.act = get_activation(activation)
        self.dropout = dropout
        self.use_bn = use_bn

        dims = [in_dim] + hidden_dim
        for src, dst in zip(dims[:-1], dims[1:]):
            self.layers.append(SGConv(src, dst, K=K))
        if use_bn:
            for width in hidden_dim:
                self.bns.append(BatchNorm1d(width))

        self.fc = Linear(hidden_dim[-1], out_dim)

    def forward(self, data):
        x, edge_index = data
        for i, conv in enumerate(self.layers):
            x = conv(x, edge_index)
            if self.use_bn:
                x = self.bns[i](x)
            x = self.act(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        embedding = x
        return embedding, self.fc(x)
# --- src/src/models/ownership_classifier.py ---

class mlp_nn(nn.Module):
    """MLP ownership verifier: input -> hidden layers (ReLU each) -> 2 logits.

    When ``dropout`` > 0, a Dropout layer is inserted after every second
    hidden layer except the last one (unchanged behaviour).
    """

    def __init__(self, input_dim, hidden_layers, dropout=0.0):
        super().__init__()
        layers = []
        widths = [input_dim] + list(hidden_layers)
        last = len(hidden_layers) - 1
        for i, (src, dst) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Linear(src, dst))
            layers.append(nn.ReLU())
            if dropout != 0.0 and (i + 1) % 2 == 0 and i != last:
                layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(hidden_layers[-1], 2))
        self.predict_layers = nn.Sequential(*layers)

    def forward(self, input):
        return self.predict_layers(input)


# --- src/src/utils/config.py ---

def add_data_group(group):
    """Dataset, split and masking options."""
    group.add_argument('--dataset', type=str, default='Cora', help="used dataset")
    group.add_argument('--data_path', type=str, default='../dataset', help="the directory used to save dataset")
    group.add_argument('--task_type', type=str, default='transductive')
    group.add_argument('--dataset_random_seed', type=int, default=999)
    group.add_argument('--feature_random_seed', type=int, default=999)
    # Fix: ``type=list`` made argparse split the raw string into characters
    # (e.g. "0.3" -> ['0', '.', '3']); parse a proper float list instead.
    # The default value is unchanged.
    group.add_argument('--split_dataset_ratio', nargs='+', type=float, default=[0.3, 0.3, 0.3, 0.1])
    group.add_argument('--mask_node_ratio', type=float, default=1.0)
    group.add_argument('--mask_feat_ratio', type=float, default=0.0)
    group.add_argument('--mask_node_type', type=str, default='overall')
    group.add_argument('--mask_feat_type', type=str, default='random_mask')
    group.add_argument('--mask_method', type=str, default='flip')
    group.add_argument('--prune_weight_ratio', type=float, default=0.1)


def add_benign_model_group(group):
    """Options for training the benign (original) model."""
    group.add_argument('--benign_model', type=str, default='gcn', help="used model")
    group.add_argument('--benign_hidden_dim', nargs='+', default=[128, 64], type=int, help='hidden layers of the model')
    group.add_argument('--benign_train_epochs', type=int, default=1000)
    group.add_argument('--benign_lr', type=float, default=0.001)
    group.add_argument('--antidistill_train_ratio', type=float, default=0.1)
    group.add_argument('--benign_model_situation', type=str, default='load_if_exists')


def add_backdoor_model_group(group):
    """Options for the backdoor attack experiments."""
    group.add_argument('--backdoor_train_node_ratio', type=float, default=0.15)
    group.add_argument('--backdoor_test_node_ratio', type=float, default=0.1)
    # NOTE(review): float although it names a count — confirm downstream
    # usage before tightening to int.
    group.add_argument('--backdoor_feature_num', type=float, default=500)
    group.add_argument('--backdoor_target_label', type=int, default=6)
    group.add_argument('--backdoor_train_epochs', type=int, default=1000)
    group.add_argument('--backdoor_lr', type=float, default=0.001)
    group.add_argument('--backdoor_lr_decay_steps', nargs='+', default=[500, 800], type=int)
    group.add_argument('--backdoor_weight_decay', type=float, default=5e-4)


def add_extraction_model_group(group):
    """Options for training the extraction (surrogate) model."""
    group.add_argument('--extraction_model', type=str, default='gcn', help="used model")
    group.add_argument('--extraction_hidden_dim', nargs='+', default=[64, 32], type=int, help='hidden layers of the model')
    group.add_argument('--extraction_train_epochs', type=int, default=1000)
    group.add_argument('--extraction_lr', type=float, default=0.001)
    group.add_argument('--extraction_method', type=str, default='black_box')
    group.add_argument('--extraction_ratio', type=float, default=0.5)
    group.add_argument('--extraction_model_situation', type=str, default='load_if_exists')
    group.add_argument('--double_extraction_model_situation', type=str, default='write_anyway')


def parse_args():
    """Build the full CLI parser (data / benign / backdoor / extraction
    groups) and parse ``sys.argv``."""
    parser = argparse.ArgumentParser()
    data_group = parser.add_argument_group(title="Data-related configuration")
    benign_model_group = parser.add_argument_group(title="Benign-model-related configuration")
    backdoor_model_group = parser.add_argument_group(title="Attack-related configuration")
    extraction_model_group = parser.add_argument_group(title="Extraction-model-related configuration")

    add_data_group(data_group)
    add_benign_model_group(benign_model_group)
    add_backdoor_model_group(backdoor_model_group)
    add_extraction_model_group(extraction_model_group)

    return parser.parse_args()
# --- src/src/verifier/verification_cfg.py: module-level helpers ---

def preprocess_data_flatten(distance_pairs: list):
    """Flatten each independent/surrogate tensor pair into 1 x D row vectors.

    Returns a dict with keys 'independent' (label 0) and 'surrogate'
    (label 1), each a list of flattened tensors.
    """
    total_label0, total_label1 = [], []
    for pair in distance_pairs:
        total_label0.append(torch.flatten(pair['independent']).view(1, -1))
        total_label1.append(torch.flatten(pair['surrogate']).view(1, -1))
    return {'independent': total_label0, 'surrogate': total_label1}


def pair_to_dataloader(distance_pairs, batch_size=5):
    """Wrap the flattened pairs in a shuffling DataLoader."""
    processed = preprocess_data_flatten(distance_pairs)
    dataset = src.datasets.datareader.VarianceData(processed['independent'], processed['surrogate'])
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def train_original_classifier(distance_pairs: list):
    """Fit the MLP ownership verifier on independent-vs-surrogate posteriors.

    Trains one model for up to 10 rounds of 1000 epochs (fresh shuffling
    DataLoader per round), checking accuracy on the training pairs every 100
    epochs, keeping the best-scoring model and stopping early at 100%.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    processed = preprocess_data_flatten(distance_pairs)
    dataset = src.datasets.datareader.VarianceData(processed['independent'], processed['surrogate'])

    # model and optimizer are created once and keep training across rounds
    model = mlp_nn(dataset.data.shape[1], [128, 64])
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    best_model, best_acc = None, 0
    for _ in range(10):
        dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
        model.to(device)

        acc = 0
        for epoch_index in range(1000):
            model.train()
            for inputs, labels in dataloader:
                optimizer.zero_grad()
                loss = loss_fn(model(inputs.to(device)), labels.to(device))
                loss.backward()
                optimizer.step()

            if (epoch_index + 1) % 100 == 0:
                # accuracy is measured on the (training) pairs themselves
                model.eval()
                correct = 0
                for inputs, labels in dataloader:
                    outputs = model(inputs.to(device))
                    _, predictions = torch.max(outputs.data, 1)
                    correct += (predictions == labels.to(device)).sum().item()
                acc = correct / len(dataset) * 100

            if acc == 100:
                break

        if acc > best_acc:
            best_model = model
            best_acc = acc
        if best_acc == 100:
            break

    print("best acc:{}".format(best_acc))
    return best_model


def owner_verify(suspicious_logits, verifier_model):
    """Classify flattened posteriors: prediction 0 = independent, 1 = surrogate."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    flattened = torch.flatten(suspicious_logits).view(1, -1)
    verifier_model.to(device)
    verifier_model.eval()

    outputs = verifier_model(flattened.to(device))
    _, predictions = torch.max(outputs.data, 1)
    return predictions


def join_path(*save_path):
    """os.path.join the parts and create the directory if it does not exist."""
    root = os.path.join(*save_path)
    if not os.path.exists(root):
        os.makedirs(root)
    return root


def join_name(hidden_dims):
    """Render dims for checkpoint names, e.g. [128, 64] -> "128_64"."""
    return "_".join(str(dim) for dim in hidden_dims)


def random_generate_arch(layer_dims, num_hidden_layers, seed):
    """Sample hidden-layer architectures, interleaved across depths.

    For each depth in ``num_hidden_layers``, all combinations of
    ``layer_dims`` are generated (each arch sorted descending), shuffled with
    ``seed`` (the seed is reset per depth), then merged round-robin.
    """

    def _arches_for_depth(depth):
        combos = itertools.product(layer_dims, repeat=depth)
        return [sorted(combo, reverse=True) for combo in combos]

    per_depth = []
    for depth in num_hidden_layers:
        arches = _arches_for_depth(depth)
        random.seed(seed)
        random.shuffle(arches)
        per_depth.append(arches)

    # TODO: results are not deduplicated
    interleaved = []
    for i in range(len(per_depth[-1])):
        for arches in per_depth:
            if i < len(arches):
                interleaved.append(arches[i])
    return interleaved
# Methods of GNNVerification (class defined earlier in the patch); they live
# at class scope in src/src/verifier/verification_cfg.py.

def geneate_mask_model(self):
    """Build the masked ("fingerprinted") model from the trained original.

    (The method-name typo is preserved: callers invoke ``geneate_mask_model``.)

    Returns:
        (mask_model, mask_model_acc, measure_nodes) where ``measure_nodes`` is
        the flat list of node ids whose posteriors are later fingerprinted.
    """
    if self.args.task_type == "inductive":
        extract_logits_data = self.original_graph_data[0]
    else:
        extract_logits_data = self.original_graph_data

    mask_graph_data, mask_nodes = boundary.mask_graph_data(self.args, extract_logits_data, self.original_model)
    save_root = join_path(self.train_save_root, "mask_models", self.args.mask_feat_type,
                          "{}_{}".format(self.args.mask_node_ratio, self.args.mask_feat_ratio))
    save_path = os.path.join(save_root, "{}.pt".format(self.mask_model_save_name))

    if self.args.task_type == "inductive":
        mask_graph_data = [mask_graph_data, self.original_graph_data[1],
                           self.original_graph_data[2], self.original_graph_data[3]]

    if self.args.mask_feat_ratio == 0.0:
        # nothing masked: the "mask" model is just a copy of the original
        mask_model = copy.deepcopy(self.original_model)
        torch.save(mask_model, save_path)
        mask_model_acc = self.original_model_acc
    else:
        _, mask_model, mask_model_acc = benign.run(self.args, save_path, mask_graph_data)

    measure_nodes = []
    for class_nodes in mask_nodes:
        measure_nodes += class_nodes

    return mask_model, mask_model_acc, measure_nodes


def train_models_by_arch(self, setting_cfg, model_arch, model_save_root, seed,
                         mask_model_save_name=None, mask_model=None, stage="train",
                         process="train", classifier=None):
    """Train ``num_model_per_arch`` models of one architecture family.

    With ``mask_model is None`` independent (benign) models are trained;
    otherwise surrogate models are extracted from ``mask_model``. Every
    sampled hidden-dim list gets the fixed embedding dim appended so all
    models end in the same final-layer width.

    Note: rewrites fields of ``self.args`` in place for the runner calls.
    """
    candidates = random_generate_arch(setting_cfg["layer_dims"], setting_cfg["num_hidden_layers"],
                                      seed=seed)
    if len(candidates) < setting_cfg["num_model_per_arch"]:
        raise Exception("Can not generate enough unique model hidden dims, please reduce num_model_per_arch")

    model_list, acc_list, fidelity_list = [], [], []
    for hidden_dims, _ in zip(candidates, range(setting_cfg["num_model_per_arch"])):
        # Important! every model gets the same fixed embedding layer appended
        hidden_dims.append(self.global_cfg["embedding_dim"])

        if mask_model is None:
            self.args.benign_hidden_dim = hidden_dims
            self.args.benign_model = model_arch
            save_root = join_path(model_save_root, 'independent_models')
            save_path = os.path.join(save_root,
                                     "{}_{}_{}.pt".format(stage, self.args.benign_model,
                                                          join_name(hidden_dims)))
            _, model, model_acc = benign.run(self.args, save_path, self.original_graph_data, process)
        else:
            self.args.extraction_hidden_dim = hidden_dims
            self.args.extraction_model = model_arch
            save_root = join_path(model_save_root, 'extraction_models', self.args.mask_feat_type,
                                  mask_model_save_name,
                                  "{}_{}".format(self.args.mask_node_ratio, self.args.mask_feat_ratio))
            save_path = os.path.join(save_root,
                                     "{}_{}_{}.pt".format(stage, self.args.extraction_model,
                                                          join_name(hidden_dims)))
            model, model_acc, fidelity = extraction_runner.run(self.args, save_path,
                                                               self.original_graph_data, mask_model,
                                                               process, classifier)
            fidelity_list.append(fidelity)

        model_list.append(model)
        acc_list.append(model_acc)

    return model_list, acc_list, fidelity_list


def train_models_by_setting(self, setting_cfg, model_save_root, mask_model_save_name=None,
                            mask_model=None, stage="train", process="train", classifier=None):
    """Train models for every architecture listed in the setting config.

    Accuracies are grouped per architecture; fidelity lists only exist for
    surrogate (``mask_model is not None``) runs.
    """
    all_model_list, all_acc_list, all_fidelity_list = [], [], []
    for seed, model_arch in enumerate(setting_cfg["model_arches"]):
        models, accs, fids = self.train_models_by_arch(setting_cfg, model_arch, model_save_root, seed,
                                                       mask_model_save_name, mask_model=mask_model,
                                                       stage=stage, process=process, classifier=classifier)
        all_model_list += models
        all_acc_list.append(accs)
        if mask_model is not None:
            all_fidelity_list.append(fids)

    return all_model_list, all_acc_list, all_fidelity_list
time.time()-train_clf_start + + # train independent model + test_inde_model_list, test_inde_acc_list, _ = self.train_models_by_setting(self.test_setting_cfg, self.train_save_root, + mask_model=None, stage="test", process='train') + # train surrogate model + test_surr_model_list, test_surr_acc_list, test_surr_fidelity_list = self.train_models_by_setting( + self.test_setting_cfg, self.test_save_root, + self.mask_model_save_name, self.mask_model, stage="test", process=self.global_cfg["test_process"]) # classifier=classifier_model + test_logits_list, test_embedding_list = [], [] + TN, FP, FN, TP = 0, 0, 0, 0 + for test_independent_model, test_extraction_model in zip(test_inde_model_list, test_surr_model_list): + independent__outputs, test_inde_logits, test_inde_embedding = extract_outputs(extract_logits_data, self.measure_nodes, test_independent_model, test_independent_model) + surrogate_outputs, test_surr_logits, test_surr_embedding = extract_outputs(extract_logits_data, self.measure_nodes, test_extraction_model, test_extraction_model) + + ind_pred = owner_verify(independent__outputs["independent"], classifier_model) + ext_pred = owner_verify(surrogate_outputs["surrogate"], classifier_model) + + test_embedding_list.append([mask_embedding["independent"], test_inde_embedding["independent"], test_surr_embedding["surrogate"]]) + test_logits_list.append([mask_logits["independent"], test_inde_logits["independent"], test_surr_logits["surrogate"]]) + + if ind_pred == 0: + TN += 1 + else: + FP += 1 + + if ext_pred == 0: + FN += 1 + else: + TP += 1 + + + FPR = FP / (FP + TN) + FNR = FN / (FN + TP) + Accuracy = (TP + TN) / (TN + FP + TP + FN) + + # save to a + save_json["TN"], save_json["TP"] = TN, TP + save_json["FN"], save_json["FP"] = FN, FP + save_json["FPR"], save_json["FNR"] = FPR, FNR + + save_json["Accuracy"] = Accuracy + save_json["original_model_acc"] = self.original_model_acc + save_json["mask_model_acc"] = self.mask_model_acc + save_json["train_inde_acc_list"] = 
train_inde_acc_list + save_json["train_surr_acc_list"] = train_surr_acc_list + save_json["train_surr_fidelity_list"] = train_surr_fidelity_list + + save_json["test_inde_acc_list"] = test_inde_acc_list + save_json["test_surr_acc_list"] = test_surr_acc_list + save_json["test_surr_fidelity_list"] = test_surr_fidelity_list + + save_json["total_time"] = time.time()-start + save_json["mask_run_time"] = mask_run_time + + json_save_root = join_path(self.global_cfg["res_path"], self.args.dataset, self.args.task_type, self.args.mask_feat_type, + "{}_{}".format(self.args.mask_node_ratio, self.args.mask_feat_ratio)) + json_save_root = join_path(json_save_root,"train_setting{}".format(self.global_cfg["train_setting"]), "test_setting{}".format(self.global_cfg["test_setting"])) + + with open("{}/{}_{}.json".format(json_save_root, self.mask_model_save_name, n_run), "w") as f: + f.write(json.dumps(save_json)) + + with open("{}/train_setting.yaml".format(json_save_root), "w") as f: + yaml.dump(self.train_setting_cfg, f, default_flow_style=False) + with open("{}/test_setting.yaml".format(json_save_root), "w") as f: + yaml.dump(self.test_setting_cfg, f, default_flow_style=False) + + # if n_run == 0: + # torch.save(train_embedding_list, + # os.path.join(json_save_root, "{}_train_embedding.pkl".format(self.mask_model_save_name))) + # torch.save(test_embedding_list, + # os.path.join(json_save_root, "{}_test_embedding.pkl".format(self.mask_model_save_name))) + + # torch.save(train_logits_list, + # os.path.join(json_save_root, "{}_train_logits.pkl".format(self.mask_model_save_name))) + # torch.save(test_logits_list, + # os.path.join(json_save_root, "{}_test_logits.pkl".format(self.mask_model_save_name))) + + # print("Total Time:{}",save_json["total_time"]) + # print("Train classifier time:{}, Total time:{}, ratio:{}".format(train_clf_time, save_json["total_time"], train_clf_time/save_json["total_time"])) + + return TP, FN, TN, FP + + +def multiple_experiments(args, global_cfg, 
config_path=None): + if config_path is None: + config_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../config")) + target_arch_list = ["gat", "gcn", "sage"] + # target_arch_list = ["gat"] + target_hidden_dim_list = [[352, 128],[288, 128],[224, 128]] + # target_hidden_dim_list = [[224, 128]] + attack_setting_list = [1, 2, 3, 4] + + # load setting + with open(os.path.join(config_path,'train_setting{}.yaml'.format(global_cfg["train_setting"])), 'r') as file: + train_setting_cfg = yaml.safe_load(file) + + # obtain experimental parameters + grid_params = [] + for dataset in [global_cfg["dataset"]]: + for test_setting in attack_setting_list: + for target_arch in target_arch_list: + for target_hidden_dims in target_hidden_dim_list: + grid_params.append([dataset, test_setting, target_arch, target_hidden_dims]) + + + for dataset, test_setting, target_arch, target_hidden_dims in grid_params: + + # load test setting + with open(os.path.join(config_path, 'test_setting{}.yaml'.format(test_setting)), 'r') as file: + test_setting_cfg = yaml.safe_load(file) + for n_run in range(global_cfg["n_run"]): + args.dataset = dataset + args.benign_hidden_dim = target_hidden_dims + args.benign_model = target_arch + global_cfg['test_setting'] = test_setting + global_cfg['target_model'] = target_arch + global_cfg['target_hidden_dims'] = target_hidden_dims + + gnn_verification = GNNVerification(args, global_cfg, train_setting_cfg, test_setting_cfg) + gnn_verification.run_single_experiment(n_run) + + +if __name__ == '__main__': + from utils.config import parse_args + # from verification_cfg import multiple_experiments + + args = parse_args() + # ownver(args) + + with open(os.path.join("../config", "global_cfg.yaml"), 'r') as file: + global_cfg = yaml.safe_load(file) + + multiple_experiments(args, global_cfg) \ No newline at end of file diff --git a/src/src/watermarking/__init__.py b/src/src/watermarking/__init__.py new file mode 100644 index 0000000..e69de29