From b7e931afd97497d7411683fd051ca32204794314 Mon Sep 17 00:00:00 2001 From: Mehdi Cherti Date: Wed, 11 Jun 2025 08:25:59 +0200 Subject: [PATCH] Update imagenetv2 urls (fix #122) --- clip_benchmark/datasets/imagenetv2.py | 53 +++++---------------------- 1 file changed, 10 insertions(+), 43 deletions(-) diff --git a/clip_benchmark/datasets/imagenetv2.py b/clip_benchmark/datasets/imagenetv2.py index 83885f1..5b6f5e7 100644 --- a/clip_benchmark/datasets/imagenetv2.py +++ b/clip_benchmark/datasets/imagenetv2.py @@ -12,54 +12,21 @@ from torch.utils.data import Dataset, DataLoader from torchvision.datasets import ImageFolder -URLS = {"matched-frequency" : "https://imagenetv2public.s3-us-west-2.amazonaws.com/imagenetv2-matched-frequency.tar.gz", - "threshold-0.7" : "https://imagenetv2public.s3-us-west-2.amazonaws.com/imagenetv2-threshold0.7.tar.gz", - "top-images": "https://imagenetv2public.s3-us-west-2.amazonaws.com/imagenetv2-top-images.tar.gz", - "val": "https://imagenetv2public.s3-us-west-2.amazonaws.com/imagenet_validation.tar.gz"} +URLS = { + "matched-frequency" : "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/imagenetv2-matched-frequency.tar.gz", + "threshold-0.7" : "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/imagenetv2-threshold0.7.tar.gz", + "top-images": "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/imagenetv2-top-images.tar.gz", +} -FNAMES = {"matched-frequency" : "imagenetv2-matched-frequency-format-val", - "threshold-0.7" : "imagenetv2-threshold0.7-format-val", - "top-images": "imagenetv2-top-images-format-val", - "val": "imagenet_validation"} +FNAMES = { + "matched-frequency" : "imagenetv2-matched-frequency-format-val", + "threshold-0.7" : "imagenetv2-threshold0.7-format-val", + "top-images": "imagenetv2-top-images-format-val", +} V2_DATASET_SIZE = 10000 -VAL_DATASET_SIZE = 50000 -class ImageNetValDataset(Dataset): - def __init__(self, transform=None, location="."): - self.dataset_root = pathlib.Path(f"{location}/imagenet_validation/") - self.tar_root = pathlib.Path(f"{location}/imagenet_validation.tar.gz") - self.fnames = list(self.dataset_root.glob("**/*.JPEG")) - self.transform = transform - if not self.dataset_root.exists() or len(self.fnames) != VAL_DATASET_SIZE: - if not self.tar_root.exists(): - print(f"Dataset imagenet-val not found on disk, downloading....") - response = requests.get(URLS["val"], stream=True) - total_size_in_bytes= int(response.headers.get('content-length', 0)) - block_size = 1024 #1 Kibibyte - progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) - with open(self.tar_root, 'wb') as f: - for data in response.iter_content(block_size): - progress_bar.update(len(data)) - f.write(data) - progress_bar.close() - if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: - assert False, f"Downloading from {URLS[variant]} failed" - print("Extracting....") - tarfile.open(self.tar_root).extractall(f"{location}") - shutil.move(f"{location}/{FNAMES['val']}", self.dataset_root) - - self.dataset = ImageFolder(self.dataset_root) - - def __len__(self): - return len(self.dataset) - - def __getitem__(self, i): - img, label = self.dataset[i] - if self.transform is not None: - img = self.transform(img) - return img, label class ImageNetV2Dataset(Dataset): def __init__(self, variant="matched-frequency", transform=None, location="."):