# ProbeGen

Official implementation of the paper:

**Deep Linear Probe Generators for Weight Space Learning**
Jonathan Kahana, Eliahu Horwitz, Imri Shuval, Yedid Hoshen
https://arxiv.org/abs/2410.10811
## Setup

To run the experiments, first create a clean virtual environment and install the requirements:

```shell
conda create -n probegen python=3.9
conda activate probegen
pip install numpy pandas scikit-learn scipy tqdm
conda install pytorch==2.0.1 torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
```

Install the repo:

```shell
git clone https://github.com/jonkahana/ProbeGen.git
```
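Optionally, you can sanity-check that the required packages are importable before running any experiment. This snippet is not part of the repository; it simply probes the environment with the standard-library `importlib`:

```python
# Optional environment check (not part of the repo): verify that the
# packages installed above can be found before running experiments.
import importlib.util

def missing_packages(names):
    """Return the subset of `names` that cannot be imported."""
    return [n for n in names if importlib.util.find_spec(n) is None]

# Note: `sklearn` is the import name of the scikit-learn package.
required = ["numpy", "pandas", "sklearn", "scipy", "tqdm", "torch", "torchvision"]
missing = missing_packages(required)
print("missing:", missing if missing else "none")
```

An empty result means the environment is ready; otherwise, re-run the install commands above for the listed packages.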
```shell
cd ProbeGen
```

## Datasets

### INR Classification

For INR classification, we use MNIST and Fashion MNIST.
We provide the train / val / test splits as in Neural Graphs, inside this repository.
To download the data please run the following commands:
#### MNIST

```shell
cd experiments/inr_classification/dataset
wget "https://www.dropbox.com/sh/56pakaxe58z29mq/AABrctdu2U65jGYr2WQRzmMna/mnist-inrs.zip?dl=0" -O mnist-inrs.zip
mkdir -p dataset/mnist-inrs
unzip -q mnist-inrs.zip -d dataset
rm mnist-inrs.zip
cd ../../..
```

#### Fashion MNIST

```shell
cd experiments/inr_classification/dataset
wget "https://www.dropbox.com/sh/56pakaxe58z29mq/AAAssoHq719OmSHSKKTiKKHGa/fmnist_inrs.zip?dl=0" -O fmnist_inrs.zip
mkdir -p dataset/fmnist_inrs
unzip -q fmnist_inrs.zip -d dataset
rm fmnist_inrs.zip
cd ../../..
```

### CNN Generalization

For CNN generalization, we use the grayscale CIFAR-10 (CIFAR10-GS) split from the Small CNN Zoo dataset, and the CNN Wild Park dataset. We provide the train / val / test splits as in Neural Graphs, inside this repository.
#### CIFAR10-GS

This experiment follows NFN. Download the CIFAR10 data (originally from Unterthiner et al., 2020) into experiments/cnn_generalization/dataset, and extract it there.

#### CNN Wild Park

Download the dataset from Zenodo and extract it into experiments/cnn_generalization/dataset.
#### Dead-Leaves

To download the Dead-Leaves dataset, please run the following commands:

```shell
cd experiments/cnn_generalization/dataset
bash download_dead_leaves.sh
```

## Running Experiments

To run a specific experiment, you can use the provided scripts in the scripts directory.
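Before launching experiments, you can optionally verify that the datasets were extracted where the download steps above place them. The expected paths below are inferred from those commands and may differ if you changed the layout:

```python
# Optional dataset-layout check (paths inferred from the download steps
# above; adjust if you extracted the data elsewhere).
from pathlib import Path

def missing_paths(paths, root="."):
    """Return the subset of `paths` that do not exist under `root`."""
    return [p for p in paths if not (Path(root) / p).exists()]

expected = [
    "experiments/inr_classification/dataset",
    "experiments/cnn_generalization/dataset",
]
print("missing:", missing_paths(expected) or "none")
```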
The folder scripts/main_results contains the scripts to reproduce the results of ProbeGen on all 4 datasets, with separate scripts for 64 and 128 probes.
For example, to run ProbeGen with 128 probes use the scripts:

- MNIST INR classification: scripts/main_results/mnist_inr__ProbeGen_128.sh
- FMNIST INR classification: scripts/main_results/fmnist_inr__ProbeGen_128.sh
- CIFAR10-GS Accuracy Prediction: scripts/main_results/cifar10_gs__ProbeGen_128.sh
- CIFAR10 Wild Park Accuracy Prediction: scripts/main_results/cifar10_wild_park__ProbeGen_128.sh
To run Vanilla Probing on a specific dataset, you can use the provided scripts in the scripts/vanilla_probing directory:

- MNIST INR classification: scripts/vanilla_probing/mnist_inr__Vanilla_Probing_128.sh
- FMNIST INR classification: scripts/vanilla_probing/fmnist_inr__Vanilla_Probing_128.sh
- CIFAR10-GS Accuracy Prediction: scripts/vanilla_probing/cifar10_gs__Vanilla_Probing_128.sh
- CIFAR10 Wild Park Accuracy Prediction: scripts/vanilla_probing/cifar10_wild_park__Vanilla_Probing_128.sh
To run the synthetic data experiments, you can use the provided scripts in the scripts/synthetic_data directory:

- MNIST INR classification: scripts/synthetic_data/mnist__uniform_probes_128.sh
- FMNIST INR classification: scripts/synthetic_data/fmnist__uniform_probes_128.sh
- CIFAR10-GS Accuracy Prediction: scripts/synthetic_data/cifar10_gs__Dead_Leaves_128.sh
- CIFAR10 Wild Park Accuracy Prediction: scripts/synthetic_data/cifar10_wild_park__Dead_Leaves_128.sh
## Citation

If you find our work or this code useful in your own research, please consider citing the following paper:

```bibtex
@article{kahana2024deep,
  title={Deep Linear Probe Generators for Weight Space Learning},
  author={Kahana, Jonathan and Horwitz, Eliahu and Shuval, Imri and Hoshen, Yedid},
  journal={arXiv preprint arXiv:2410.10811},
  year={2024}
}
```

## Acknowledgments

This codebase is based on https://github.com/mkofinas/neural-graphs, which is originally based on https://github.com/AvivNavon/DWSNets.