Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
.*.swp
*.swp
*.pyc
*.csv
log_k700/
log_mgh/
data_csv/
2 changes: 2 additions & 0 deletions app/vjepa/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,8 @@ def save_checkpoint(epoch, path):
gpu_time_meter = AverageMeter()
wall_time_meter = AverageMeter()

### Air test
print(ipe)
for itr in range(ipe):
itr_start_time = time.time()

Expand Down
91 changes: 91 additions & 0 deletions build/lib/datasets/data_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from logging import getLogger


_GLOBAL_SEED = 0
logger = getLogger()


def init_data(
    batch_size,
    transform=None,
    shared_transform=None,
    data='ImageNet',
    collator=None,
    pin_mem=True,
    num_workers=8,
    world_size=1,
    rank=0,
    root_path=None,
    image_folder=None,
    training=True,
    copy_data=False,
    drop_last=True,
    tokenize_txt=True,
    subset_file=None,
    clip_len=8,
    frame_sample_rate=2,
    duration=None,
    num_clips=1,
    random_clip_sampling=True,
    allow_clip_overlap=False,
    filter_short_videos=False,
    filter_long_videos=int(1e9),
    decode_one_clip=True,
    datasets_weights=None,
    persistent_workers=False,
    repeat_wds=False,
    ipe=300,
    log_dir=None,
):
    """
    Build a data loader and distributed sampler for the requested dataset.

    :param batch_size: per-rank batch size.
    :param data: dataset family (case-insensitive). 'imagenet', 'inat21'
        and 'places205' dispatch to the image pipeline; 'videodataset'
        dispatches to the video pipeline.
    :param world_size: number of distributed replicas for the sampler.
    :param rank: this process' rank within ``world_size``.
    :return: tuple ``(data_loader, dist_sampler)``.
    :raises ValueError: if ``data`` names an unknown dataset type.

    NOTE(review): several parameters (tokenize_txt, decode_one_clip,
    repeat_wds, ipe, ...) are accepted but not forwarded by either branch
    here -- presumably kept for interface compatibility; confirm against
    callers.
    """
    dataset_type = data.lower()

    if dataset_type in ('imagenet', 'inat21', 'places205'):
        # Imported lazily so video-only runs do not pay the image-path cost.
        from src.datasets.image_dataset import make_imagedataset
        dataset, data_loader, dist_sampler = make_imagedataset(
            transform=transform,
            batch_size=batch_size,
            collator=collator,
            pin_mem=pin_mem,
            training=training,
            num_workers=num_workers,
            world_size=world_size,
            rank=rank,
            root_path=root_path,
            image_folder=image_folder,
            persistent_workers=persistent_workers,
            copy_data=copy_data,
            drop_last=drop_last,
            subset_file=subset_file)

    elif dataset_type == 'videodataset':
        from src.datasets.video_dataset import make_videodataset
        dataset, data_loader, dist_sampler = make_videodataset(
            data_paths=root_path,
            batch_size=batch_size,
            frames_per_clip=clip_len,
            frame_step=frame_sample_rate,
            duration=duration,
            num_clips=num_clips,
            random_clip_sampling=random_clip_sampling,
            allow_clip_overlap=allow_clip_overlap,
            filter_short_videos=filter_short_videos,
            filter_long_videos=filter_long_videos,
            shared_transform=shared_transform,
            transform=transform,
            datasets_weights=datasets_weights,
            collator=collator,
            num_workers=num_workers,
            world_size=world_size,
            rank=rank,
            drop_last=drop_last,
            log_dir=log_dir)

    else:
        # Previously an unknown value fell through to the return statement
        # and crashed with an UnboundLocalError; fail loudly instead.
        raise ValueError(f'Unsupported dataset type: {data!r}')

    return (data_loader, dist_sampler)
79 changes: 79 additions & 0 deletions build/lib/datasets/image_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import os

from logging import getLogger

import torch
import torchvision

_GLOBAL_SEED = 0
logger = getLogger()


class ImageFolder(torchvision.datasets.ImageFolder):
    """Torchvision ImageFolder rooted at ``root/image_folder/{train,val}/``."""

    def __init__(
        self,
        root,
        image_folder='imagenet_full_size/061417/',
        transform=None,
        train=True,
    ):
        """
        ImageFolder
        :param root: root network directory for ImageFolder data
        :param image_folder: path to images inside root network directory
        :param transform: per-sample transform forwarded to torchvision
        :param train: whether to load train data (or validation)
        """
        split = 'train/' if train else 'val/'
        full_path = os.path.join(root, image_folder, split)
        logger.info(f'data-path {full_path}')
        super().__init__(root=full_path, transform=transform)
        logger.info('Initialized ImageFolder')


def make_imagedataset(
    transform,
    batch_size,
    collator=None,
    pin_mem=True,
    num_workers=8,
    world_size=1,
    rank=0,
    root_path=None,
    image_folder=None,
    training=True,
    copy_data=False,
    drop_last=True,
    persistent_workers=False,
    subset_file=None
):
    """
    Build an ImageFolder dataset plus the distributed sampler and loader
    that consume it.

    :return: tuple ``(dataset, data_loader, dist_sampler)``.

    NOTE(review): copy_data and subset_file are accepted but unused in
    this body -- presumably kept for interface parity; confirm callers.
    """
    img_dataset = ImageFolder(
        root=root_path,
        image_folder=image_folder,
        transform=transform,
        train=training)
    logger.info('ImageFolder dataset created')

    # Shard the dataset across distributed replicas.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset=img_dataset,
        num_replicas=world_size,
        rank=rank)

    loader = torch.utils.data.DataLoader(
        img_dataset,
        collate_fn=collator,
        sampler=sampler,
        batch_size=batch_size,
        drop_last=drop_last,
        pin_memory=pin_mem,
        num_workers=num_workers,
        persistent_workers=persistent_workers)
    logger.info('ImageFolder unsupervised data loader created')

    return img_dataset, loader, sampler
96 changes: 96 additions & 0 deletions build/lib/datasets/utils/video/functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numbers
import cv2
import numpy as np
import PIL
import torch


def _is_tensor_clip(clip):
return torch.is_tensor(clip) and clip.ndimension() == 4


def crop_clip(clip, min_h, min_w, h, w):
    """
    Crop every frame of a clip to the window starting at (min_h, min_w)
    with height ``h`` and width ``w``.

    :param clip: list of frames, all ``numpy.ndarray`` (H, W, C) or all
        ``PIL.Image``.
    :return: list of cropped frames, same frame type as the input.
    :raises TypeError: if frames are neither numpy arrays nor PIL images.
    """
    if isinstance(clip[0], np.ndarray):
        cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]

    elif isinstance(clip[0], PIL.Image.Image):
        # PIL.crop takes a (left, upper, right, lower) box.
        cropped = [
            img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
        ]
    else:
        # Fixed: the original concatenation was missing a space and
        # rendered as 'PIL.Imagebut got ...'.
        raise TypeError('Expected numpy.ndarray or PIL.Image '
                        'but got list of {0}'.format(type(clip[0])))
    return cropped


def resize_clip(clip, size, interpolation='bilinear'):
    """
    Resize every frame of a clip.

    :param clip: list of frames, all ``numpy.ndarray`` (H, W, C) or all
        ``PIL.Image``.
    :param size: a single number (the shorter spatial side is scaled to
        it, preserving aspect ratio) or an explicit size pair forwarded
        to the backend resize call.
    :param interpolation: 'bilinear'; anything else selects
        nearest-neighbour.
    :return: list of resized frames (the input list is returned unchanged
        when the shorter side already equals ``size``).
    :raises TypeError: if frames are neither numpy arrays nor PIL images.

    NOTE(review): for an explicit pair, the numpy branch forwards
    (size[0], size[1]) while the PIL branch swaps to (size[1], size[0]) --
    looks intentional (cv2 vs PIL argument order) but confirm against
    callers before relying on it.
    """
    if isinstance(clip[0], np.ndarray):
        if isinstance(size, numbers.Number):
            im_h, im_w, im_c = clip[0].shape
            # Min spatial dim already matches minimal size
            if (im_w <= im_h and im_w == size) or (im_h <= im_w
                                                   and im_h == size):
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[0], size[1]
        if interpolation == 'bilinear':
            np_inter = cv2.INTER_LINEAR
        else:
            np_inter = cv2.INTER_NEAREST
        scaled = [
            cv2.resize(img, size, interpolation=np_inter) for img in clip
        ]
    elif isinstance(clip[0], PIL.Image.Image):
        if isinstance(size, numbers.Number):
            im_w, im_h = clip[0].size
            # Min spatial dim already matches minimal size
            if (im_w <= im_h and im_w == size) or (im_h <= im_w
                                                   and im_h == size):
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[1], size[0]
        if interpolation == 'bilinear':
            pil_inter = PIL.Image.BILINEAR
        else:
            pil_inter = PIL.Image.NEAREST
        scaled = [img.resize(size, pil_inter) for img in clip]
    else:
        # Fixed: the original concatenation was missing a space and
        # rendered as 'PIL.Imagebut got ...'.
        raise TypeError('Expected numpy.ndarray or PIL.Image '
                        'but got list of {0}'.format(type(clip[0])))
    return scaled


def get_resize_sizes(im_h, im_w, size):
    """
    Return the (height, width) obtained by scaling the shorter side of an
    ``im_h`` x ``im_w`` image to ``size`` while preserving aspect ratio.
    """
    if im_h <= im_w:
        # Landscape (or square): height is the shorter side.
        return size, int(size * im_w / im_h)
    # Portrait: width is the shorter side.
    return int(size * im_h / im_w), size


def normalize(clip, mean, std, inplace=False):
    """
    Channel-wise normalize a 4-D video tensor: ``(clip - mean) / std``.

    :param clip: 4-D torch tensor; the leading dim is the channel dim
        (mean/std are broadcast over the remaining three dims).
    :param mean: per-channel means (1-D sequence or tensor).
    :param std: per-channel standard deviations (1-D sequence or tensor).
    :param inplace: mutate ``clip`` directly instead of cloning it first.
    :return: the normalized tensor (the clone, or ``clip`` when inplace).
    :raises TypeError: if ``clip`` is not a 4-D torch tensor.
    """
    # Inlined _is_tensor_clip predicate: must be a 4-D torch tensor.
    if not (torch.is_tensor(clip) and clip.ndimension() == 4):
        raise TypeError('tensor is not a torch clip.')

    if not inplace:
        clip = clip.clone()

    mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
    std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
    return clip
Loading