Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 10 additions & 43 deletions clip_benchmark/datasets/imagenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,54 +12,21 @@
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder

# Download URL of each dataset archive, keyed by variant name.
# NOTE(review): URLS is bound twice back to back in this text. Read as Python,
# the second literal replaces this one, so the S3 links and the "val" entry are
# never visible to later code. This looks like the before/after sides of a diff
# rendered without +/- markers -- TODO confirm which side is meant to survive.
URLS = {"matched-frequency" : "https://imagenetv2public.s3-us-west-2.amazonaws.com/imagenetv2-matched-frequency.tar.gz",
"threshold-0.7" : "https://imagenetv2public.s3-us-west-2.amazonaws.com/imagenetv2-threshold0.7.tar.gz",
"top-images": "https://imagenetv2public.s3-us-west-2.amazonaws.com/imagenetv2-top-images.tar.gz",
"val": "https://imagenetv2public.s3-us-west-2.amazonaws.com/imagenet_validation.tar.gz"}
# Second binding: the same three ImageNetV2 variants served from huggingface.co.
# NOTE(review): there is no "val" key here, yet ImageNetValDataset below reads
# URLS["val"]; with this binding in effect that lookup raises KeyError.
URLS = {
"matched-frequency" : "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/imagenetv2-matched-frequency.tar.gz",
"threshold-0.7" : "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/imagenetv2-threshold0.7.tar.gz",
"top-images": "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/imagenetv2-top-images.tar.gz",
}

# Directory name looked up under the extraction location after an archive is
# unpacked (ImageNetValDataset moves `{location}/{FNAMES['val']}` into place).
# NOTE(review): like URLS, FNAMES is bound twice and the second literal wins.
FNAMES = {"matched-frequency" : "imagenetv2-matched-frequency-format-val",
"threshold-0.7" : "imagenetv2-threshold0.7-format-val",
"top-images": "imagenetv2-top-images-format-val",
"val": "imagenet_validation"}
# Second binding; it also drops the "val" key that ImageNetValDataset reads.
FNAMES = {
"matched-frequency" : "imagenetv2-matched-frequency-format-val",
"threshold-0.7" : "imagenetv2-threshold0.7-format-val",
"top-images": "imagenetv2-top-images-format-val",
}


# Presumably the expected image count of one ImageNetV2 variant; its use is
# outside the visible part of the file -- TODO confirm.
V2_DATASET_SIZE = 10000
# Number of .JPEG files ImageNetValDataset expects under its dataset root; any
# other count triggers a fresh extraction (and a download if the archive is
# missing).
VAL_DATASET_SIZE = 50000

class ImageNetValDataset(Dataset):
    """ImageNet validation images served through a torchvision ``ImageFolder``.

    On construction the class looks for ``{location}/imagenet_validation/``.
    When that directory is missing, or does not hold exactly
    ``VAL_DATASET_SIZE`` ``.JPEG`` files, the archive
    ``{location}/imagenet_validation.tar.gz`` is downloaded from
    ``URLS["val"]`` (unless it is already on disk) and extracted into
    ``location``.

    Args:
        transform: Optional callable applied to every image returned by
            ``__getitem__``.
        location: Directory that holds (or will receive) the archive and the
            extracted image tree.

    Raises:
        requests.HTTPError: If the server answers the download request with an
            error status.
        AssertionError: If the number of bytes received differs from the
            ``content-length`` announced by the server.
    """

    def __init__(self, transform=None, location="."):
        self.dataset_root = pathlib.Path(f"{location}/imagenet_validation/")
        self.tar_root = pathlib.Path(f"{location}/imagenet_validation.tar.gz")
        self.transform = transform
        self.fnames = self._list_images()
        if not self.dataset_root.exists() or len(self.fnames) != VAL_DATASET_SIZE:
            if not self.tar_root.exists():
                self._download_archive()
            print("Extracting....")
            # Context manager so the archive handle is closed after extraction.
            with tarfile.open(self.tar_root) as archive:
                archive.extractall(f"{location}")
            shutil.move(f"{location}/{FNAMES['val']}", self.dataset_root)
            # The listing taken above predates the extraction; refresh it so
            # ``fnames`` describes what is actually on disk.
            self.fnames = self._list_images()

        self.dataset = ImageFolder(self.dataset_root)

    def _list_images(self):
        """Return every ``.JPEG`` path found below ``self.dataset_root``."""
        return list(self.dataset_root.glob("**/*.JPEG"))

    def _download_archive(self):
        """Stream the archive at ``URLS["val"]`` into ``self.tar_root``."""
        print("Dataset imagenet-val not found on disk, downloading....")
        url = URLS["val"]
        response = requests.get(url, stream=True)
        # Fail before anything is written, so an HTTP error page is never
        # saved as if it were the archive.
        response.raise_for_status()
        total_size_in_bytes = int(response.headers.get("content-length", 0))
        block_size = 1024  # 1 Kibibyte
        received = 0
        progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
        try:
            with open(self.tar_root, "wb") as f:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    f.write(data)
                    received += len(data)
        finally:
            progress_bar.close()
            response.close()
        if total_size_in_bytes != 0 and received != total_size_in_bytes:
            # Remove the truncated file; otherwise the next construction would
            # see an archive on disk and skip the download.
            self.tar_root.unlink()
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``. The message previously referenced the undefined
            # name ``variant``, which turned this path into a NameError.
            raise AssertionError(f"Downloading from {url} failed")

    def __len__(self):
        """Return the number of samples found by the underlying ``ImageFolder``."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return ``(image, label)`` for index ``i``, applying ``transform`` if set."""
        img, label = self.dataset[i]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

class ImageNetV2Dataset(Dataset):
def __init__(self, variant="matched-frequency", transform=None, location="."):
Expand Down