This is an unofficial repository housing a PyTorch implementation of the image retrieval network presented in Correlation Verification for Image Retrieval [1], together with the modifications proposed in Global Features are All You Need for Image Retrieval and Reranking [2]. Official repositories exist for both papers, but they lack the code needed to reproduce the training results, citing intellectual property concerns.
This repository aims to bridge that gap with a more complete and coherent codebase. It provides a well-structured, easy-to-follow implementation and a clear training loop covering the networks proposed in the original papers, in the hope of making results easier to reproduce and of simplifying training on new datasets.
Note that this repo is still a work in progress; see the to-do list near the end.
Clone the repository:
git clone https://github.com/edwardguil/SuperCVNet.git
Creating a new conda environment is suggested:
conda create --name supercvnet python=3.12
conda activate supercvnet
Then install the dependencies from requirements.txt:
pip install -r requirements.txt
The training scripts are contained in train_backbone.py and train_rerank.py. You can run these scripts from the command line, which by default starts a training loop on CIFAR-10:
python train_backbone.py
Or import the training loop for more control over the inputs to the training process:
from train_backbone import train_backbone
train_backbone(...)
CVNet is implemented as two distinct classes:
class CVNetGlobal():
    pass
class CVNetRerank():
    pass
These models can be used like normal PyTorch models, e.g.
import torch
from models import CVNetGlobal, CVNetRerank
model = CVNetGlobal()
rerank = CVNetRerank()
x = torch.rand((1, 3, 512, 512))
y = model(x)
y_ranked = rerank(y)
For training, as per the paper, CVNet requires positive sample pairs to be passed through the momentum network. To simplify this process, you can use the PairedDataset class as a wrapper around existing PyTorch datasets. The wrapped dataset can be anything, as long as it can be indexed (i.e. it implements __getitem__), e.g.
from torchvision.datasets import CIFAR10
from datasets import PairedDataset
dataset = CIFAR10(root="./data", download=True)  # root is the directory the data is stored in
dataset[0]  # This dataset is indexable
paired_dataset = PairedDataset(dataset)
for x, x_positive, y in paired_dataset:
    # Here x and x_positive share the same label (y)
    pass
SuperGlobal is also implemented as two distinct classes:
class SuperGlobal():
    pass
class SuperGlobalRerank():
    pass
These models can be used together or independently, like normal PyTorch models.
import torch
from models import SuperGlobal, SuperGlobalRerank
model = SuperGlobal()
rerank = SuperGlobalRerank(...)
x = torch.rand((1, 3, 512, 512))
y = model(x)
y_ranked = rerank(y)
The caveat to the above is that SuperGlobalRerank requires access to a vector database (db) for similarity search. If you simply want to perform similarity search on a tensor of vectors, use the TensorVectorDB class:
import torch
from helpers import TensorVectorDB
from models import SuperGlobalRerank
vectors = torch.rand((10*3, 512))  # num vectors x feature dim
labels = torch.rand((10*3, 1))  # num vectors x label dim
vector_db = TensorVectorDB(vectors, labels)
rerank = SuperGlobalRerank(vector_db)
If you want to use some other form of vector database, simply implement a child of AbstractVectorDB, contained in helpers/base/vector_db.py. A pinecone_index already exists if you want to use a Pinecone database as your vector store.
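To illustrate the general shape of a custom vector store, here is a minimal, dependency-free sketch of a brute-force cosine-similarity database. It does not use the repo's code: `VectorDBSketch`, `InMemoryVectorDB`, and the `query(vector, top_k)` signature are all hypothetical stand-ins, and the real abstract methods in helpers/base/vector_db.py may differ, so check that file before subclassing AbstractVectorDB.

```python
import math
from abc import ABC, abstractmethod


class VectorDBSketch(ABC):
    """Hypothetical stand-in for AbstractVectorDB (the real interface may differ)."""

    @abstractmethod
    def query(self, vector, top_k):
        """Return up to top_k (index, score) pairs, best match first."""


class InMemoryVectorDB(VectorDBSketch):
    """Brute-force cosine-similarity search over an in-memory list of vectors."""

    def __init__(self, vectors):
        self.vectors = [list(v) for v in vectors]

    @staticmethod
    def _cosine(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
        return dot / norm if norm else 0.0

    def query(self, vector, top_k):
        # Score every stored vector against the query, then keep the best top_k.
        scores = [(i, self._cosine(vector, v)) for i, v in enumerate(self.vectors)]
        scores.sort(key=lambda pair: pair[1], reverse=True)
        return scores[:top_k]


db = InMemoryVectorDB([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
results = db.query([1.0, 0.1], top_k=2)  # list of (index, cosine score) pairs
```

A child class for another backend would keep the same query contract and replace the brute-force loop with a call to that backend's own search.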
To do:
- Implement generic vectordb class to allow for easier extensibility
- Implement SuperRerank network
- Complete the train_rerank script.
- Add correct transforms and class count for Google Landmarks
- Add input args for channel norms and resizing customization
[1] Lee, S., Seong, H., Lee, S., & Kim, E. (2022). Correlation Verification for Image Retrieval. 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 5364-5374.
[2] Shao, S., Chen, K., Karpur, A., Cui, Q., Araújo, A.F., & Cao, B. (2023). Global Features are All You Need for Image Retrieval and Reranking. ArXiv, abs/2308.06954.