jonkahana/ProbeGen

Deep Linear Probe Generators for Weight Space Learning

Official implementation for

Deep Linear Probe Generators for Weight Space Learning
Jonathan Kahana, Eliahu Horwitz, Imri Shuval, Yedid Hoshen
https://arxiv.org/abs/2410.10811 

Setup environment

To run the experiments, first create a clean conda environment and install the requirements:

conda create -n probegen python=3.9
conda activate probegen
pip install numpy pandas scikit-learn scipy tqdm
conda install pytorch==2.0.1 torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia

Then clone the repository:

git clone https://github.com/jonkahana/ProbeGen.git
cd ProbeGen
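Before running anything, it can help to confirm that every dependency imports cleanly. The sketch below uses a hypothetical helper, `check_python_pkgs`, which is not part of the repository. It assumes `python3` resolves to the interpreter of the activated `probegen` environment. Pass import names rather than pip package names: the pip package `scikit-learn` is imported as `sklearn`.

```shell
# Hypothetical helper (not part of ProbeGen): report which Python modules fail to import.
# Arguments are import names, not pip package names (e.g. sklearn for scikit-learn).
# Returns 0 if every module imports, 1 otherwise.
check_python_pkgs() {
  local missing=0 mod
  for mod in "$@"; do
    if python3 -c "import ${mod}" >/dev/null 2>&1; then
      echo "ok       ${mod}"
    else
      echo "MISSING  ${mod}"
      missing=1
    fi
  done
  return "${missing}"
}
```

For example, `check_python_pkgs numpy pandas sklearn scipy tqdm torch torchvision torchaudio` checks the packages installed above. Running `python3 -c "import torch; print(torch.cuda.is_available())"` additionally shows whether the CUDA build of PyTorch can see a GPU.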

Datasets

INR classification

For INR classification, we use MNIST and Fashion MNIST. The INR datasets are hosted on Dropbox and can be downloaded with the commands below.

This repository includes the train / val / test splits, which follow Neural Graphs.

To download the data, run the following commands:

MNIST

cd experiments/inr_classification/dataset
wget "https://www.dropbox.com/sh/56pakaxe58z29mq/AABrctdu2U65jGYr2WQRzmMna/mnist-inrs.zip?dl=0" -O mnist-inrs.zip &&
  mkdir -p dataset/mnist-inrs &&
  unzip -q mnist-inrs.zip -d dataset &&
  rm mnist-inrs.zip
cd ../../..

Fashion MNIST

cd experiments/inr_classification/dataset
wget "https://www.dropbox.com/sh/56pakaxe58z29mq/AAAssoHq719OmSHSKKTiKKHGa/fmnist_inrs.zip?dl=0" -O fmnist_inrs.zip &&
  mkdir -p dataset/fmnist_inrs &&
  unzip -q fmnist_inrs.zip -d dataset &&
  rm fmnist_inrs.zip
cd ../../..
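To confirm that an archive was extracted, you can check that its folder exists and is not empty. The sketch below uses a hypothetical helper, `check_inr_dataset`, which is not part of the repository. It assumes each archive unpacks into `dataset/<name>` (the folder the `mkdir -p` commands above create), relative to `experiments/inr_classification/dataset`.

```shell
# Hypothetical helper (not part of ProbeGen): check that an INR dataset was extracted.
# Run from experiments/inr_classification/dataset. Assumes the archive's contents
# end up in ./dataset/<name>, matching the mkdir -p commands above.
check_inr_dataset() {
  local dir="dataset/$1" count
  if [ ! -d "${dir}" ]; then
    echo "${dir}: not found" >&2
    return 1
  fi
  # Count regular files anywhere below the dataset folder.
  count=$(find "${dir}" -type f | wc -l | tr -d ' ')
  if [ "${count}" -eq 0 ]; then
    echo "${dir}: empty" >&2
    return 1
  fi
  echo "${dir}: ${count} files"
}
```

From the repository root, you could run `(cd experiments/inr_classification/dataset && check_inr_dataset mnist-inrs && check_inr_dataset fmnist_inrs)`. The subshell leaves your working directory unchanged.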

CNN generalization

For CNN generalization, we use the grayscale CIFAR-10 (CIFAR10-GS) subset of the Small CNN Zoo and the CNN Wild Park. As for INR classification, this repository includes the train / val / test splits, which follow Neural Graphs.

NFN CNN Zoo data

This experiment follows NFN. Download the CIFAR10 data (originally from Unterthiner et al., 2020) into experiments/cnn_generalization/dataset and extract it.

CNN Wild Park

Download the dataset from Zenodo and extract it into experiments/cnn_generalization/dataset.

Dead-Leaves dataset

To download the Dead-Leaves dataset, please run the following commands:

cd experiments/cnn_generalization/dataset
bash download_dead_leaves.sh

Running the experiments

To run a specific experiment, use the corresponding script in the scripts directory.

Main Results

The folder scripts/main_results contains scripts that reproduce the ProbeGen results on all 4 datasets, with separate scripts for 64 and 128 probes. For example, to run ProbeGen with 128 probes, use:

  • MNIST INR classification: scripts/main_results/mnist_inr__ProbeGen_128.sh
  • FMNIST INR classification: scripts/main_results/fmnist_inr__ProbeGen_128.sh
  • CIFAR10-GS Accuracy Prediction: scripts/main_results/cifar10_gs__ProbeGen_128.sh
  • CIFAR10 Wild Park Accuracy Prediction: scripts/main_results/cifar10_wild_park__ProbeGen_128.sh
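The script names above share the pattern `<dataset>__ProbeGen_<probes>.sh`, so the main results can be launched in a loop. This is a minimal sketch. `probegen_script`, `run_main_results` and `PROBEGEN_DATASETS` are hypothetical names, not part of the repository. The sketch assumes the scripts are run with `bash` from the repository root, and that the 64-probe scripts follow the same naming pattern. Confirm the latter with `ls scripts/main_results` before relying on it.

```shell
# Hypothetical helpers (not part of ProbeGen). Run from the repository root.
# Names follow the pattern listed above: <dataset>__ProbeGen_<probes>.sh.
# ASSUMPTION: the 64-probe scripts use the same pattern with _64.
PROBEGEN_DATASETS="mnist_inr fmnist_inr cifar10_gs cifar10_wild_park"

# Print the main-results script path for a dataset and probe count (default 128).
probegen_script() {
  local dataset="$1" probes="${2:-128}"
  echo "scripts/main_results/${dataset}__ProbeGen_${probes}.sh"
}

# Run every main-results script for one probe count.
# Missing scripts are skipped with a warning; a failing script stops the loop.
run_main_results() {
  local probes="${1:-128}" dataset script
  for dataset in ${PROBEGEN_DATASETS}; do
    script=$(probegen_script "${dataset}" "${probes}")
    if [ ! -f "${script}" ]; then
      echo "skipping missing ${script}" >&2
      continue
    fi
    echo "running ${script}"
    bash "${script}" || return 1
  done
}
```

For example, `run_main_results 128` runs the four scripts listed above in order.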

Vanilla Probing

To run Vanilla Probing on a specific dataset, you can use the provided scripts in the scripts/vanilla_probing directory:

  • MNIST INR classification: scripts/vanilla_probing/mnist_inr__Vanilla_Probing_128.sh
  • FMNIST INR classification: scripts/vanilla_probing/fmnist_inr__Vanilla_Probing_128.sh
  • CIFAR10-GS Accuracy Prediction: scripts/vanilla_probing/cifar10_gs__Vanilla_Probing_128.sh
  • CIFAR10 Wild Park Accuracy Prediction: scripts/vanilla_probing/cifar10_wild_park__Vanilla_Probing_128.sh
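When comparing Vanilla Probing against ProbeGen on the same dataset, it helps to keep each run's output. This is a minimal sketch. `run_and_log` and the `logs/` folder are hypothetical and not part of the repository, and the sketch assumes the scripts print their results to the terminal. It relies on bash's `PIPESTATUS` so that the function returns the script's exit status rather than that of `tee`.

```shell
# Hypothetical helper (not part of ProbeGen): run a script, show its output,
# and save a copy to logs/<script name>.log. Returns the script's exit status.
run_and_log() {
  local script="$1" log
  mkdir -p logs
  log="logs/$(basename "${script}" .sh).log"
  bash "${script}" 2>&1 | tee "${log}"
  # PIPESTATUS[0] is the exit status of the first command in the pipeline (bash only).
  return "${PIPESTATUS[0]}"
}
```

For example, running `run_and_log scripts/vanilla_probing/mnist_inr__Vanilla_Probing_128.sh` and then `run_and_log scripts/main_results/mnist_inr__ProbeGen_128.sh` leaves both outputs side by side in `logs/`.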

Synthetic Data

To run the synthetic data experiments, you can use the provided scripts in the scripts/synthetic_data directory:

  • MNIST INR classification: scripts/synthetic_data/mnist__uniform_probes_128.sh
  • FMNIST INR classification: scripts/synthetic_data/fmnist__uniform_probes_128.sh
  • CIFAR10-GS Accuracy Prediction: scripts/synthetic_data/cifar10_gs__Dead_Leaves_128.sh
  • CIFAR10 Wild Park Accuracy Prediction: scripts/synthetic_data/cifar10_wild_park__Dead_Leaves_128.sh
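Unlike the other folders, the synthetic-data script names do not share a single pattern. The INR datasets use uniform probes and the CNN datasets use Dead-Leaves. For that reason, a lookup table is clearer than string templating here. This is a minimal sketch. `synthetic_script` and its dataset keys, which are borrowed from the main-results file names, are hypothetical. The Dead-Leaves scripts presumably need the Dead-Leaves dataset downloaded as described above.

```shell
# Hypothetical helper (not part of ProbeGen): map a dataset key to its
# synthetic-data script, using the file names listed above.
synthetic_script() {
  case "$1" in
    mnist_inr)         echo "scripts/synthetic_data/mnist__uniform_probes_128.sh" ;;
    fmnist_inr)        echo "scripts/synthetic_data/fmnist__uniform_probes_128.sh" ;;
    cifar10_gs)        echo "scripts/synthetic_data/cifar10_gs__Dead_Leaves_128.sh" ;;
    cifar10_wild_park) echo "scripts/synthetic_data/cifar10_wild_park__Dead_Leaves_128.sh" ;;
    *) echo "unknown dataset: $1" >&2; return 1 ;;
  esac
}
```

For example, `bash "$(synthetic_script cifar10_gs)"` runs the Dead-Leaves experiment for CIFAR10-GS from the repository root.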

Citation

If you find our work or this code to be useful in your own research, please consider citing the following paper:

@article{kahana2024deep,
  title={Deep Linear Probe Generators for Weight Space Learning},
  author={Kahana, Jonathan and Horwitz, Eliahu and Shuval, Imri and Hoshen, Yedid},
  journal={arXiv preprint arXiv:2410.10811},
  year={2024}
}

Acknowledgments