This is the official implementation for WACV 2024 paper "Label Shift Estimation for Class-Imbalance Problem: A Bayesian Approach".

ChangkunYe/MAPLS

Label Shift Estimation for Class-Imbalance Problem: A Bayesian Approach

This is the official implementation of the WACV 2024 paper "Label Shift Estimation for Class-Imbalance Problem: A Bayesian Approach".

If you find this repository useful or use this code in your research, please cite the following paper:

@InProceedings{Ye_2024_WACV,
   author    = {Ye, Changkun and Tsuchida, Russell and Petersson, Lars and Barnes, Nick},
   title     = {Label Shift Estimation for Class-Imbalance Problem: A Bayesian Approach},
   booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
   month     = {January},
   year      = {2024},
   pages     = {1073-1082}
}

Requirements

The code is written in PyTorch. It is recommended to install the dependencies via conda:

conda install scipy
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
conda install -c conda-forge cvxpy

When training the neural network classifier from scratch, the recommended hardware setup is as follows:

Dataset        #GPU             #CPU
CIFAR10/100    ≥ 2 GB           > 4 + 1 threads
ImageNet       ≥ 4 x 12 GB      > 16 + 1 threads
Places         ≥ 6 x 12 GB      > 24 + 1 threads

Dataset Details

Our code supports the CIFAR10/100, ImageNet 2012 and Places365 datasets.

For the long-tailed versions of ImageNet and Places, please download the splits here. These splits are provided by the paper "Large-Scale Long-Tailed Recognition in an Open World".

data
  |--CIFAR10
    |--cifar-10-batches-py
  |--CIFAR100
    |--cifar-100-python
  |--ImageNet
    |--train
    |--val
    |--ImageNet_LT_train.txt
    |--ImageNet_LT_test.txt
    |--ImageNet_LT_val.txt
  |--Places
    |--data_256
    |--val_256
    |--test_256
    |--Places_LT_train.txt
    |--Places_LT_test.txt
    |--Places_LT_val.txt
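
Before launching a long training run, it can help to confirm that the data folder matches the tree above. The following is a minimal sketch, not part of the repository: the helper name `missing_dataset_paths` and the `EXPECTED_LAYOUT` table are hypothetical, and the entries are simply copied from the layout shown above.

```python
from pathlib import Path

# Expected entries under the data root, copied from the directory tree above.
# This table is an illustration only; the repository itself does not define it.
EXPECTED_LAYOUT = {
    "CIFAR10": ["cifar-10-batches-py"],
    "CIFAR100": ["cifar-100-python"],
    "ImageNet": [
        "train",
        "val",
        "ImageNet_LT_train.txt",
        "ImageNet_LT_test.txt",
        "ImageNet_LT_val.txt",
    ],
    "Places": [
        "data_256",
        "val_256",
        "test_256",
        "Places_LT_train.txt",
        "Places_LT_test.txt",
        "Places_LT_val.txt",
    ],
}


def missing_dataset_paths(data_root, datasets=None):
    """Return the expected entries (relative to data_root) that do not exist.

    `datasets` optionally restricts the check to some keys of EXPECTED_LAYOUT,
    e.g. ["CIFAR10"] when only CIFAR10 has been downloaded.
    """
    root = Path(data_root)
    names = list(EXPECTED_LAYOUT) if datasets is None else list(datasets)
    missing = []
    for name in names:
        for entry in EXPECTED_LAYOUT[name]:
            if not (root / name / entry).exists():
                missing.append((Path(name) / entry).as_posix())
    return missing


if __name__ == "__main__":
    # "./data" is a placeholder; use the same path you set as $data_path.
    for path in missing_dataset_paths("./data"):
        print("missing:", path)
```

An empty result means every expected file and folder was found. The check only tests for existence; it does not verify the contents of the archives or split files.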

P.S. It is recommended to prepare the ImageNet dataset as follows:

Extract the training data:

mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done
cd ..

Extract the validation data and move images to subfolders:

mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar
wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash

This script was originally provided here.

Train the classifier

To train the classifier from scratch, set the GPU IDs "CUDA_VISIBLE_DEVICES", the dataset path "$data_path" and the config path "$cfg_path" in the bash script "./train_script.sh", then run:

./train_script.sh

Config examples are provided in "./config/".

Test Label Shift Estimation Model

To test an existing model's performance under label shift, set the dataset path "$data_path" and the checkpoint path "$ckpt_path" in the bash script "./test_script.sh", then run:

./test_script.sh

The "$cfg_path" in "./test_script.sh" determines the type of label shift, including:

  • "./config/batch_imb_LT": Ordered Long-Tailed Shift
  • "./config/batch_imb_shuffle": Shuffled Long-Tailed Shift
  • "./config/batch_imb_dirichlet": Dirichlet Shift
  • "./config/batch_imb_knockout": Knockout Shift
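
To give a feel for what these four shift types mean, here is a minimal sketch of how test-set class priors of each kind are commonly constructed in the label shift literature. It is not the repository's implementation: the function names and the parameters `imbalance_ratio`, `alpha` and `keep_fraction` are assumptions made for illustration, and the actual definitions and values used by the configs in "./config/" may differ.

```python
import random


def long_tailed_prior(num_classes, imbalance_ratio):
    """Ordered long-tailed prior with exponentially decaying class mass.

    The first class is `imbalance_ratio` times more likely than the last.
    """
    if num_classes == 1:
        return [1.0]
    weights = [
        imbalance_ratio ** (-k / (num_classes - 1)) for k in range(num_classes)
    ]
    total = sum(weights)
    return [w / total for w in weights]


def shuffled_long_tailed_prior(num_classes, imbalance_ratio, rng):
    """Same long-tailed masses, randomly reassigned to the classes."""
    prior = long_tailed_prior(num_classes, imbalance_ratio)
    rng.shuffle(prior)
    return prior


def dirichlet_prior(num_classes, alpha, rng):
    """Draw a prior from a symmetric Dirichlet(alpha) distribution.

    Uses the standard construction: normalize independent Gamma(alpha, 1) draws.
    """
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(num_classes)]
    total = sum(draws)
    return [d / total for d in draws]


def knockout_prior(base_prior, knocked_classes, keep_fraction):
    """Keep only `keep_fraction` of the mass of the knocked classes, then renormalize."""
    scaled = [
        p * keep_fraction if k in knocked_classes else p
        for k, p in enumerate(base_prior)
    ]
    total = sum(scaled)
    return [p / total for p in scaled]


if __name__ == "__main__":
    rng = random.Random(0)  # fixed seed so repeated runs give the same priors
    print(long_tailed_prior(10, 100))
    print(shuffled_long_tailed_prior(10, 100, rng))
    print(dirichlet_prior(10, 1.0, rng))
    print(knockout_prior([0.1] * 10, {0, 1}, 0.1))
```

Each function returns a list of class probabilities summing to one. A shifted test set would then be obtained by resampling the original test set so that its class frequencies follow the returned prior.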

License

Please see LICENSE

Questions?

Please raise an issue or contact the author at changkun.ye@anu.edu.au.
