diff --git a/.gitignore b/.gitignore index 3bb2efd7..ccc7c76d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,6 @@ -.*.swp +*.swp *.pyc +*.csv +log_k700/ +log_mgh/ +data_csv/ \ No newline at end of file diff --git a/app/vjepa/train.py b/app/vjepa/train.py index 2b556168..1fefcf53 100644 --- a/app/vjepa/train.py +++ b/app/vjepa/train.py @@ -376,6 +376,8 @@ def save_checkpoint(epoch, path): gpu_time_meter = AverageMeter() wall_time_meter = AverageMeter() + ### Air test + print(ipe) for itr in range(ipe): itr_start_time = time.time() diff --git a/build/lib/datasets/data_manager.py b/build/lib/datasets/data_manager.py new file mode 100644 index 00000000..cdb7ade4 --- /dev/null +++ b/build/lib/datasets/data_manager.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from logging import getLogger + + +_GLOBAL_SEED = 0 +logger = getLogger() + + +def init_data( + batch_size, + transform=None, + shared_transform=None, + data='ImageNet', + collator=None, + pin_mem=True, + num_workers=8, + world_size=1, + rank=0, + root_path=None, + image_folder=None, + training=True, + copy_data=False, + drop_last=True, + tokenize_txt=True, + subset_file=None, + clip_len=8, + frame_sample_rate=2, + duration=None, + num_clips=1, + random_clip_sampling=True, + allow_clip_overlap=False, + filter_short_videos=False, + filter_long_videos=int(1e9), + decode_one_clip=True, + datasets_weights=None, + persistent_workers=False, + repeat_wds=False, + ipe=300, + log_dir=None, +): + + if (data.lower() == 'imagenet') \ + or (data.lower() == 'inat21') \ + or (data.lower() == 'places205'): + from src.datasets.image_dataset import make_imagedataset + dataset, data_loader, dist_sampler = make_imagedataset( + transform=transform, + batch_size=batch_size, + collator=collator, + pin_mem=pin_mem, + training=training, + num_workers=num_workers, + world_size=world_size, + rank=rank, + root_path=root_path, + image_folder=image_folder, + persistent_workers=persistent_workers, + copy_data=copy_data, + drop_last=drop_last, + subset_file=subset_file) + + elif data.lower() == 'videodataset': + from src.datasets.video_dataset import make_videodataset + dataset, data_loader, dist_sampler = make_videodataset( + data_paths=root_path, + batch_size=batch_size, + frames_per_clip=clip_len, + frame_step=frame_sample_rate, + duration=duration, + num_clips=num_clips, + random_clip_sampling=random_clip_sampling, + allow_clip_overlap=allow_clip_overlap, + filter_short_videos=filter_short_videos, + filter_long_videos=filter_long_videos, + shared_transform=shared_transform, + transform=transform, + datasets_weights=datasets_weights, + collator=collator, + num_workers=num_workers, + world_size=world_size, + rank=rank, + drop_last=drop_last, + log_dir=log_dir) + + return (data_loader, dist_sampler) diff --git a/build/lib/datasets/image_dataset.py b/build/lib/datasets/image_dataset.py new file mode 100644 index 00000000..84e9b082 --- /dev/null +++ b/build/lib/datasets/image_dataset.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
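# A minimal usage sketch for the init_data() entry point above, as a pretraining
# script might call it for the 'videodataset' branch. The CSV filename, batch
# size and clip settings are illustrative placeholders, and the module is assumed
# importable as src.datasets.data_manager (the build/lib copy added here mirrors
# that package path, which these files themselves import from).
from src.datasets.data_manager import init_data

data_loader, dist_sampler = init_data(
    data='videodataset',
    root_path=['data_csv/train_paths.csv'],  # hypothetical CSV listing video files
    batch_size=8,
    clip_len=16,
    frame_sample_rate=4,
    num_clips=1,
    num_workers=8,
    world_size=1,
    rank=0,
    log_dir=None)

# init_data returns only (data_loader, dist_sampler); transform, shared_transform
# and collator are left at their None defaults here, so the batch layout depends
# on what the underlying VideoDataset yields.
for batch in data_loader:
    break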
+# + +import os + +from logging import getLogger + +import torch +import torchvision + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class ImageFolder(torchvision.datasets.ImageFolder): + + def __init__( + self, + root, + image_folder='imagenet_full_size/061417/', + transform=None, + train=True, + ): + """ + ImageFolder + :param root: root network directory for ImageFolder data + :param image_folder: path to images inside root network directory + :param train: whether to load train data (or validation) + """ + + suffix = 'train/' if train else 'val/' + data_path = os.path.join(root, image_folder, suffix) + logger.info(f'data-path {data_path}') + super(ImageFolder, self).__init__(root=data_path, transform=transform) + logger.info('Initialized ImageFolder') + + +def make_imagedataset( + transform, + batch_size, + collator=None, + pin_mem=True, + num_workers=8, + world_size=1, + rank=0, + root_path=None, + image_folder=None, + training=True, + copy_data=False, + drop_last=True, + persistent_workers=False, + subset_file=None +): + dataset = ImageFolder( + root=root_path, + image_folder=image_folder, + transform=transform, + train=training) + logger.info('ImageFolder dataset created') + dist_sampler = torch.utils.data.distributed.DistributedSampler( + dataset=dataset, + num_replicas=world_size, + rank=rank) + data_loader = torch.utils.data.DataLoader( + dataset, + collate_fn=collator, + sampler=dist_sampler, + batch_size=batch_size, + drop_last=drop_last, + pin_memory=pin_mem, + num_workers=num_workers, + persistent_workers=persistent_workers) + logger.info('ImageFolder unsupervised data loader created') + + return dataset, data_loader, dist_sampler diff --git a/build/lib/datasets/utils/video/functional.py b/build/lib/datasets/utils/video/functional.py new file mode 100644 index 00000000..a91d15d2 --- /dev/null +++ b/build/lib/datasets/utils/video/functional.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
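# A rough sketch of driving the ImageNet branch above directly through
# make_imagedataset(). The root path is a placeholder, the torchvision transform
# is arbitrary, and single-process values are used for world_size/rank; the
# returned DistributedSampler still handles per-epoch shuffling via set_epoch().
import torchvision.transforms as T
from src.datasets.image_dataset import make_imagedataset

transform = T.Compose([T.RandomResizedCrop(224), T.ToTensor()])
dataset, loader, sampler = make_imagedataset(
    transform=transform,
    batch_size=64,
    root_path='/path/to/datasets',              # placeholder root directory
    image_folder='imagenet_full_size/061417/',  # default layout used by ImageFolder above
    training=True,
    world_size=1,
    rank=0)

for epoch in range(2):
    sampler.set_epoch(epoch)        # reshuffle the shard assignment each epoch
    for images, labels in loader:   # default collate, since collator=None
        pass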
+# + +import numbers +import cv2 +import numpy as np +import PIL +import torch + + +def _is_tensor_clip(clip): + return torch.is_tensor(clip) and clip.ndimension() == 4 + + +def crop_clip(clip, min_h, min_w, h, w): + if isinstance(clip[0], np.ndarray): + cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] + + elif isinstance(clip[0], PIL.Image.Image): + cropped = [ + img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip + ] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return cropped + + +def resize_clip(clip, size, interpolation='bilinear'): + if isinstance(clip[0], np.ndarray): + if isinstance(size, numbers.Number): + im_h, im_w, im_c = clip[0].shape + # Min spatial dim already matches minimal size + if (im_w <= im_h and im_w == size) or (im_h <= im_w + and im_h == size): + return clip + new_h, new_w = get_resize_sizes(im_h, im_w, size) + size = (new_w, new_h) + else: + size = size[0], size[1] + if interpolation == 'bilinear': + np_inter = cv2.INTER_LINEAR + else: + np_inter = cv2.INTER_NEAREST + scaled = [ + cv2.resize(img, size, interpolation=np_inter) for img in clip + ] + elif isinstance(clip[0], PIL.Image.Image): + if isinstance(size, numbers.Number): + im_w, im_h = clip[0].size + # Min spatial dim already matches minimal size + if (im_w <= im_h and im_w == size) or (im_h <= im_w + and im_h == size): + return clip + new_h, new_w = get_resize_sizes(im_h, im_w, size) + size = (new_w, new_h) + else: + size = size[1], size[0] + if interpolation == 'bilinear': + pil_inter = PIL.Image.BILINEAR + else: + pil_inter = PIL.Image.NEAREST + scaled = [img.resize(size, pil_inter) for img in clip] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return scaled + + +def get_resize_sizes(im_h, im_w, size): + if im_w < im_h: + ow = size + oh = int(size * im_h / im_w) + else: + oh = size + ow = int(size * im_w / im_h) + return oh, ow + + +def normalize(clip, mean, std, inplace=False): + if not _is_tensor_clip(clip): + raise TypeError('tensor is not a torch clip.') + + if not inplace: + clip = clip.clone() + + dtype = clip.dtype + mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) + std = torch.as_tensor(std, dtype=dtype, device=clip.device) + clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) + + return clip diff --git a/build/lib/datasets/utils/video/randaugment.py b/build/lib/datasets/utils/video/randaugment.py new file mode 100644 index 00000000..4c80a990 --- /dev/null +++ b/build/lib/datasets/utils/video/randaugment.py @@ -0,0 +1,518 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +This implementation is based on +https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py +pulished under an Apache License 2.0. +""" + +import math +import numpy as np +import random +import re +import PIL +from PIL import Image, ImageEnhance, ImageOps + +_PIL_VER = tuple([int(x) for x in PIL.__version__.split(".")[:2]]) + +_FILL = (128, 128, 128) + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. 
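# A small sketch exercising the clip-level helpers from functional.py above on a
# dummy clip of 8 PIL frames; frame sizes and normalization statistics are
# placeholders, and the import path assumes the canonical src.datasets package
# these files reference.
import numpy as np
import torch
from PIL import Image
import src.datasets.utils.video.functional as FF

clip = [Image.fromarray(np.zeros((240, 320, 3), dtype=np.uint8)) for _ in range(8)]
clip = FF.resize_clip(clip, 256, interpolation='bilinear')  # short side -> 256, aspect kept
clip = FF.crop_clip(clip, min_h=0, min_w=0, h=224, w=224)   # fixed top-left 224x224 crop

# normalize() expects a 4-D float tensor and subtracts mean/std along the first
# dimension, so the frames are stacked channel-first as (C, T, H, W) here.
t = torch.stack([torch.from_numpy(np.asarray(f)).permute(2, 0, 1) for f in clip])
t = t.float().div_(255.0).permute(1, 0, 2, 3)
t = FF.normalize(t, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])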
+_MAX_LEVEL = 10.0 + +_HPARAMS_DEFAULT = { + "translate_const": 250, + "img_mean": _FILL, +} + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _interpolation(kwargs): + interpolation = kwargs.pop("resample", Image.BILINEAR) + if isinstance(interpolation, (list, tuple)): + return random.choice(interpolation) + else: + return interpolation + + +def _check_args_tf(kwargs): + if "fillcolor" in kwargs and _PIL_VER < (5, 0): + kwargs.pop("fillcolor") + kwargs["resample"] = _interpolation(kwargs) + + +def shear_x(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs + ) + + +def shear_y(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs + ) + + +def translate_x_rel(img, pct, **kwargs): + pixels = pct * img.size[0] + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs + ) + + +def translate_y_rel(img, pct, **kwargs): + pixels = pct * img.size[1] + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs + ) + + +def translate_x_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs + ) + + +def translate_y_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform( + img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs + ) + + +def rotate(img, degrees, **kwargs): + _check_args_tf(kwargs) + if _PIL_VER >= (5, 2): + return img.rotate(degrees, **kwargs) + elif _PIL_VER >= (5, 0): + w, h = img.size + post_trans = (0, 0) + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), + round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + def transform(x, y, matrix): + (a, b, c, d, e, f) = matrix + return a * x + b * y + c, d * x + e * y + f + + matrix[2], matrix[5] = transform( + -rotn_center[0] - post_trans[0], + -rotn_center[1] - post_trans[1], + matrix, + ) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.transform(img.size, Image.AFFINE, matrix, **kwargs) + else: + return img.rotate(degrees, resample=kwargs["resample"]) + + +def auto_contrast(img, **__): + return ImageOps.autocontrast(img) + + +def invert(img, **__): + return ImageOps.invert(img) + + +def equalize(img, **__): + return ImageOps.equalize(img) + + +def solarize(img, thresh, **__): + return ImageOps.solarize(img, thresh) + + +def solarize_add(img, add, thresh=128, **__): + lut = [] + for i in range(256): + if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) + else: + return img + + +def posterize(img, bits_to_keep, **__): + if bits_to_keep >= 8: + return img + return ImageOps.posterize(img, bits_to_keep) + + +def contrast(img, factor, **__): + return ImageEnhance.Contrast(img).enhance(factor) + + +def color(img, factor, **__): + return ImageEnhance.Color(img).enhance(factor) + + +def brightness(img, factor, **__): + return ImageEnhance.Brightness(img).enhance(factor) + + +def sharpness(img, factor, **__): + return ImageEnhance.Sharpness(img).enhance(factor) + + +def _randomly_negate(v): + """With 50% prob, negate the value""" + return -v if random.random() > 0.5 else v + + +def 
_rotate_level_to_arg(level, _hparams): + # range [-30, 30] + level = (level / _MAX_LEVEL) * 30.0 + level = _randomly_negate(level) + return (level,) + + +def _enhance_level_to_arg(level, _hparams): + # range [0.1, 1.9] + return ((level / _MAX_LEVEL) * 1.8 + 0.1,) + + +def _enhance_increasing_level_to_arg(level, _hparams): + # the 'no change' level is 1.0, moving away from that towards 0. or 2.0 increases the enhancement blend + # range [0.1, 1.9] + level = (level / _MAX_LEVEL) * 0.9 + level = 1.0 + _randomly_negate(level) + return (level,) + + +def _shear_level_to_arg(level, _hparams): + # range [-0.3, 0.3] + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return (level,) + + +def _translate_abs_level_to_arg(level, hparams): + translate_const = hparams["translate_const"] + level = (level / _MAX_LEVEL) * float(translate_const) + level = _randomly_negate(level) + return (level,) + + +def _translate_rel_level_to_arg(level, hparams): + # default range [-0.45, 0.45] + translate_pct = hparams.get("translate_pct", 0.45) + level = (level / _MAX_LEVEL) * translate_pct + level = _randomly_negate(level) + return (level,) + + +def _posterize_level_to_arg(level, _hparams): + # As per Tensorflow TPU EfficientNet impl + # range [0, 4], 'keep 0 up to 4 MSB of original image' + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 4),) + + +def _posterize_increasing_level_to_arg(level, hparams): + # As per Tensorflow models research and UDA impl + # range [4, 0], 'keep 4 down to 0 MSB of original image', + # intensity/severity of augmentation increases with level + return (4 - _posterize_level_to_arg(level, hparams)[0],) + + +def _posterize_original_level_to_arg(level, _hparams): + # As per original AutoAugment paper description + # range [4, 8], 'keep 4 up to 8 MSB of image' + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 4) + 4,) + + +def _solarize_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation decreases with level + return (int((level / _MAX_LEVEL) * 256),) + + +def _solarize_increasing_level_to_arg(level, _hparams): + # range [0, 256] + # intensity/severity of augmentation increases with level + return (256 - _solarize_level_to_arg(level, _hparams)[0],) + + +def _solarize_add_level_to_arg(level, _hparams): + # range [0, 110] + return (int((level / _MAX_LEVEL) * 110),) + + +LEVEL_TO_ARG = { + "AutoContrast": None, + "Equalize": None, + "Invert": None, + "Rotate": _rotate_level_to_arg, + # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers + "Posterize": _posterize_level_to_arg, + "PosterizeIncreasing": _posterize_increasing_level_to_arg, + "PosterizeOriginal": _posterize_original_level_to_arg, + "Solarize": _solarize_level_to_arg, + "SolarizeIncreasing": _solarize_increasing_level_to_arg, + "SolarizeAdd": _solarize_add_level_to_arg, + "Color": _enhance_level_to_arg, + "ColorIncreasing": _enhance_increasing_level_to_arg, + "Contrast": _enhance_level_to_arg, + "ContrastIncreasing": _enhance_increasing_level_to_arg, + "Brightness": _enhance_level_to_arg, + "BrightnessIncreasing": _enhance_increasing_level_to_arg, + "Sharpness": _enhance_level_to_arg, + "SharpnessIncreasing": _enhance_increasing_level_to_arg, + "ShearX": _shear_level_to_arg, + "ShearY": _shear_level_to_arg, + "TranslateX": _translate_abs_level_to_arg, + "TranslateY": _translate_abs_level_to_arg, + "TranslateXRel": 
_translate_rel_level_to_arg, + "TranslateYRel": _translate_rel_level_to_arg, +} + + +NAME_TO_OP = { + "AutoContrast": auto_contrast, + "Equalize": equalize, + "Invert": invert, + "Rotate": rotate, + "Posterize": posterize, + "PosterizeIncreasing": posterize, + "PosterizeOriginal": posterize, + "Solarize": solarize, + "SolarizeIncreasing": solarize, + "SolarizeAdd": solarize_add, + "Color": color, + "ColorIncreasing": color, + "Contrast": contrast, + "ContrastIncreasing": contrast, + "Brightness": brightness, + "BrightnessIncreasing": brightness, + "Sharpness": sharpness, + "SharpnessIncreasing": sharpness, + "ShearX": shear_x, + "ShearY": shear_y, + "TranslateX": translate_x_abs, + "TranslateY": translate_y_abs, + "TranslateXRel": translate_x_rel, + "TranslateYRel": translate_y_rel, +} + + +class AugmentOp: + """ + Apply for video. + """ + + def __init__(self, name, prob=0.5, magnitude=10, hparams=None): + hparams = hparams or _HPARAMS_DEFAULT + self.aug_fn = NAME_TO_OP[name] + self.level_fn = LEVEL_TO_ARG[name] + self.prob = prob + self.magnitude = magnitude + self.hparams = hparams.copy() + self.kwargs = { + "fillcolor": hparams["img_mean"] + if "img_mean" in hparams + else _FILL, + "resample": hparams["interpolation"] + if "interpolation" in hparams + else _RANDOM_INTERPOLATION, + } + + # If magnitude_std is > 0, we introduce some randomness + # in the usually fixed policy and sample magnitude from a normal distribution + # with mean `magnitude` and std-dev of `magnitude_std`. + # NOTE This is my own hack, being tested, not in papers or reference impls. + self.magnitude_std = self.hparams.get("magnitude_std", 0) + + def __call__(self, img_list): + if self.prob < 1.0 and random.random() > self.prob: + return img_list + magnitude = self.magnitude + if self.magnitude_std and self.magnitude_std > 0: + magnitude = random.gauss(magnitude, self.magnitude_std) + magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range + level_args = ( + self.level_fn(magnitude, self.hparams) + if self.level_fn is not None + else () + ) + + if isinstance(img_list, list): + return [ + self.aug_fn(img, *level_args, **self.kwargs) for img in img_list + ] + else: + return self.aug_fn(img_list, *level_args, **self.kwargs) + + +_RAND_TRANSFORMS = [ + "AutoContrast", + "Equalize", + "Invert", + "Rotate", + "Posterize", + "Solarize", + "SolarizeAdd", + "Color", + "Contrast", + "Brightness", + "Sharpness", + "ShearX", + "ShearY", + "TranslateXRel", + "TranslateYRel", +] + + +_RAND_INCREASING_TRANSFORMS = [ + "AutoContrast", + "Equalize", + "Invert", + "Rotate", + "PosterizeIncreasing", + "SolarizeIncreasing", + "SolarizeAdd", + "ColorIncreasing", + "ContrastIncreasing", + "BrightnessIncreasing", + "SharpnessIncreasing", + "ShearX", + "ShearY", + "TranslateXRel", + "TranslateYRel", +] + + +# These experimental weights are based loosely on the relative improvements mentioned in paper. +# They may not result in increased performance, but could likely be tuned to so. 
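# A quick sketch of applying a single AugmentOp from the op tables above to a
# short list of PIL frames (dummy image content); probability and magnitude are
# illustrative. One set of op arguments is sampled per call and shared by every
# frame in the list, which keeps the augmentation temporally consistent.
import numpy as np
from PIL import Image
from src.datasets.utils.video.randaugment import AugmentOp

frames = [Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
          for _ in range(4)]
rotate_op = AugmentOp('Rotate', prob=1.0, magnitude=9)  # 9/10 of _MAX_LEVEL -> up to +/-27 degrees
frames = rotate_op(frames)                              # every frame gets the same rotation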
+_RAND_CHOICE_WEIGHTS_0 = { + "Rotate": 0.3, + "ShearX": 0.2, + "ShearY": 0.2, + "TranslateXRel": 0.1, + "TranslateYRel": 0.1, + "Color": 0.025, + "Sharpness": 0.025, + "AutoContrast": 0.025, + "Solarize": 0.005, + "SolarizeAdd": 0.005, + "Contrast": 0.005, + "Brightness": 0.005, + "Equalize": 0.005, + "Posterize": 0, + "Invert": 0, +} + + +def _select_rand_weights(weight_idx=0, transforms=None): + transforms = transforms or _RAND_TRANSFORMS + assert weight_idx == 0 # only one set of weights currently + rand_weights = _RAND_CHOICE_WEIGHTS_0 + probs = [rand_weights[k] for k in transforms] + probs /= np.sum(probs) + return probs + + +def rand_augment_ops(magnitude=10, hparams=None, transforms=None): + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _RAND_TRANSFORMS + return [ + AugmentOp(name, prob=0.5, magnitude=magnitude, hparams=hparams) + for name in transforms + ] + + +class RandAugment: + def __init__(self, ops, num_layers=2, choice_weights=None): + self.ops = ops + self.num_layers = num_layers + self.choice_weights = choice_weights + + def __call__(self, img): + # no replacement when using weighted choice + ops = np.random.choice( + self.ops, + self.num_layers, + replace=self.choice_weights is None, + p=self.choice_weights, + ) + for op in ops: + img = op(img) + return img + + +def rand_augment_transform(config_str, hparams): + """ + RandAugment: Practical automated data augmentation... - https://arxiv.org/abs/1909.13719 + + Create a RandAugment transform + :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). The remaining + sections, not order sepecific determine + 'm' - integer magnitude of rand augment + 'n' - integer num layers (number of transform ops selected per image) + 'w' - integer probabiliy weight index (index of a set of weights to influence choice of op) + 'mstd' - float std deviation of magnitude noise applied + 'inc' - integer (bool), use augmentations that increase in severity with magnitude (default: 0) + Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5 + 'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2 + :param hparams: Other hparams (kwargs) for the RandAugmentation scheme + :return: A PyTorch compatible Transform + """ + magnitude = _MAX_LEVEL # default to _MAX_LEVEL for magnitude (currently 10) + num_layers = 2 # default to 2 ops per image + weight_idx = None # default to no probability weights for op choice + transforms = _RAND_TRANSFORMS + config = config_str.split("-") + assert config[0] == "rand" + config = config[1:] + for c in config: + cs = re.split(r"(\d.*)", c) + if len(cs) < 2: + continue + key, val = cs[:2] + if key == "mstd": + # noise param injected via hparams for now + hparams.setdefault("magnitude_std", float(val)) + elif key == "inc": + if bool(val): + transforms = _RAND_INCREASING_TRANSFORMS + elif key == "m": + magnitude = int(val) + elif key == "n": + num_layers = int(val) + elif key == "w": + weight_idx = int(val) + else: + assert NotImplementedError + ra_ops = rand_augment_ops( + magnitude=magnitude, hparams=hparams, transforms=transforms + ) + choice_weights = ( + None if weight_idx is None else _select_rand_weights(weight_idx) + ) + return RandAugment(ra_ops, num_layers, choice_weights=choice_weights) diff --git a/build/lib/datasets/utils/video/randerase.py 
b/build/lib/datasets/utils/video/randerase.py new file mode 100644 index 00000000..d1f185c8 --- /dev/null +++ b/build/lib/datasets/utils/video/randerase.py @@ -0,0 +1,180 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +This implementation is based on +https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py +pulished under an Apache License 2.0. +""" +import math +import random +import torch + + +def _get_pixels( + per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda" +): + # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() + # paths, flip the order so normal is run on CPU if this becomes a problem + # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 + if per_pixel: + return torch.empty(patch_size, dtype=dtype, device=device).normal_() + elif rand_color: + return torch.empty( + (patch_size[0], 1, 1), dtype=dtype, device=device + ).normal_() + else: + return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) + + +class RandomErasing: + """Randomly selects a rectangle region in an image and erases its pixels. + 'Random Erasing Data Augmentation' by Zhong et al. + See https://arxiv.org/pdf/1708.04896.pdf + This variant of RandomErasing is intended to be applied to either a batch + or single image tensor after it has been normalized by dataset mean and std. + Args: + probability: Probability that the Random Erasing operation will be performed. + min_area: Minimum percentage of erased area wrt input image area. + max_area: Maximum percentage of erased area wrt input image area. + min_aspect: Minimum aspect ratio of erased area. + mode: pixel color mode, one of 'const', 'rand', or 'pixel' + 'const' - erase block is constant color of 0 for all channels + 'rand' - erase block is same per-channel random (normal) color + 'pixel' - erase block is per-pixel random (normal) color + max_count: maximum number of erasing blocks per image, area per box is scaled by count. + per-image count is randomly chosen between 1 and this value. 
+ """ + + def __init__( + self, + probability=0.5, + min_area=0.02, + max_area=1 / 3, + min_aspect=0.3, + max_aspect=None, + mode="const", + min_count=1, + max_count=None, + num_splits=0, + device="cuda", + cube=True, + ): + self.probability = probability + self.min_area = min_area + self.max_area = max_area + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + self.min_count = min_count + self.max_count = max_count or min_count + self.num_splits = num_splits + mode = mode.lower() + self.rand_color = False + self.per_pixel = False + self.cube = cube + if mode == "rand": + self.rand_color = True # per block random normal + elif mode == "pixel": + self.per_pixel = True # per pixel random normal + else: + assert not mode or mode == "const" + self.device = device + + def _erase(self, img, chan, img_h, img_w, dtype): + if random.random() > self.probability: + return + area = img_h * img_w + count = ( + self.min_count + if self.min_count == self.max_count + else random.randint(self.min_count, self.max_count) + ) + for _ in range(count): + for _ in range(10): + target_area = ( + random.uniform(self.min_area, self.max_area) * area / count + ) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + img[:, top:top + h, left:left + w] = _get_pixels( + self.per_pixel, + self.rand_color, + (chan, h, w), + dtype=dtype, + device=self.device, + ) + break + + def _erase_cube( + self, + img, + batch_start, + batch_size, + chan, + img_h, + img_w, + dtype, + ): + if random.random() > self.probability: + return + area = img_h * img_w + count = ( + self.min_count + if self.min_count == self.max_count + else random.randint(self.min_count, self.max_count) + ) + for _ in range(count): + for _ in range(100): + target_area = ( + random.uniform(self.min_area, self.max_area) * area / count + ) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + for i in range(batch_start, batch_size): + img_instance = img[i] + img_instance[ + :, top:top + h, left:left + w + ] = _get_pixels( + self.per_pixel, + self.rand_color, + (chan, h, w), + dtype=dtype, + device=self.device, + ) + break + + def __call__(self, input): + if len(input.size()) == 3: + self._erase(input, *input.size(), input.dtype) + else: + batch_size, chan, img_h, img_w = input.size() + # skip first slice of batch if num_splits is set (for clean portion of samples) + batch_start = ( + batch_size // self.num_splits if self.num_splits > 1 else 0 + ) + if self.cube: + self._erase_cube( + input, + batch_start, + batch_size, + chan, + img_h, + img_w, + input.dtype, + ) + else: + for i in range(batch_start, batch_size): + self._erase(input[i], chan, img_h, img_w, input.dtype) + return input diff --git a/build/lib/datasets/utils/video/transforms.py b/build/lib/datasets/utils/video/transforms.py new file mode 100644 index 00000000..ffa8e61d --- /dev/null +++ b/build/lib/datasets/utils/video/transforms.py @@ -0,0 +1,1184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math +import numpy as np +import random +import numbers +import PIL +from PIL import Image + +import torch +import torchvision +import torchvision.transforms.functional as F +from torchvision import transforms + +import src.datasets.utils.video.functional as FF +from src.datasets.utils.video.randaugment import rand_augment_transform + + +_pil_interpolation_to_str = { + Image.NEAREST: 'PIL.Image.NEAREST', + Image.BILINEAR: 'PIL.Image.BILINEAR', + Image.BICUBIC: 'PIL.Image.BICUBIC', + Image.LANCZOS: 'PIL.Image.LANCZOS', + Image.HAMMING: 'PIL.Image.HAMMING', + Image.BOX: 'PIL.Image.BOX', +} + + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _pil_interp(method): + if method == 'bicubic': + return Image.BICUBIC + elif method == 'lanczos': + return Image.LANCZOS + elif method == 'hamming': + return Image.HAMMING + else: + return Image.BILINEAR + + +def random_short_side_scale_jitter( + images, min_size, max_size, boxes=None, inverse_uniform_sampling=False +): + """ + Perform a spatial short scale jittering on the given images and + corresponding boxes. + Args: + images (tensor): images to perform scale jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + min_size (int): the minimal size to scale the frames. + max_size (int): the maximal size to scale the frames. + boxes (ndarray): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + inverse_uniform_sampling (bool): if True, sample uniformly in + [1 / max_scale, 1 / min_scale] and take a reciprocal to get the + scale. If False, take a uniform sample from [min_scale, max_scale]. + Returns: + (tensor): the scaled images with dimension of + `num frames` x `channel` x `new height` x `new width`. + (ndarray or None): the scaled boxes with dimension of + `num boxes` x 4. + """ + if inverse_uniform_sampling: + size = int( + round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size)) + ) + else: + size = int(round(np.random.uniform(min_size, max_size))) + + height = images.shape[2] + width = images.shape[3] + if (width <= height and width == size) or ( + height <= width and height == size + ): + return images, boxes + new_width = size + new_height = size + if width < height: + new_height = int(math.floor((float(height) / width) * size)) + if boxes is not None: + boxes = boxes * float(new_height) / height + else: + new_width = int(math.floor((float(width) / height) * size)) + if boxes is not None: + boxes = boxes * float(new_width) / width + + return ( + torch.nn.functional.interpolate( + images, + size=(new_height, new_width), + mode='bilinear', + align_corners=False, + ), + boxes, + ) + + +def crop_boxes(boxes, x_offset, y_offset): + """ + Peform crop on the bounding boxes given the offsets. + Args: + boxes (ndarray or None): bounding boxes to peform crop. The dimension + is `num boxes` x 4. + x_offset (int): cropping offset in the x axis. + y_offset (int): cropping offset in the y axis. + Returns: + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + cropped_boxes = boxes.copy() + cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset + cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset + + return cropped_boxes + + +def random_crop(images, size, boxes=None): + """ + Perform random spatial crop on the given images and corresponding boxes. + Args: + images (tensor): images to perform random crop. 
The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): the size of height and width to crop on the image. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + Returns: + cropped (tensor): cropped images with dimension of + `num frames` x `channel` x `size` x `size`. + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. + """ + if images.shape[2] == size and images.shape[3] == size: + return images + height = images.shape[2] + width = images.shape[3] + y_offset = 0 + if height > size: + y_offset = int(np.random.randint(0, height - size)) + x_offset = 0 + if width > size: + x_offset = int(np.random.randint(0, width - size)) + cropped = images[ + :, :, y_offset:y_offset + size, x_offset:x_offset + size + ] + + cropped_boxes = ( + crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None + ) + + return cropped, cropped_boxes + + +def horizontal_flip(prob, images, boxes=None): + """ + Perform horizontal flip on the given images and corresponding boxes. + Args: + prob (float): probility to flip the images. + images (tensor): images to perform horizontal flip, the dimension is + `num frames` x `channel` x `height` x `width`. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + Returns: + images (tensor): images with dimension of + `num frames` x `channel` x `height` x `width`. + flipped_boxes (ndarray or None): the flipped boxes with dimension of + `num boxes` x 4. + """ + if boxes is None: + flipped_boxes = None + else: + flipped_boxes = boxes.copy() + + if np.random.uniform() < prob: + images = images.flip((-1)) + + if len(images.shape) == 3: + width = images.shape[2] + elif len(images.shape) == 4: + width = images.shape[3] + else: + raise NotImplementedError("Dimension does not supported") + if boxes is not None: + flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1 + + return images, flipped_boxes + + +def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None): + """ + Perform uniform spatial sampling on the images and corresponding boxes. + Args: + images (tensor): images to perform uniform crop. The dimension is + `num frames` x `channel` x `height` x `width`. + size (int): size of height and weight to crop the images. + spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width + is larger than height. Or 0, 1, or 2 for top, center, and bottom + crop if height is larger than width. + boxes (ndarray or None): optional. Corresponding boxes to images. + Dimension is `num boxes` x 4. + scale_size (int): optinal. If not None, resize the images to scale_size before + performing any crop. + Returns: + cropped (tensor): images with dimension of + `num frames` x `channel` x `size` x `size`. + cropped_boxes (ndarray or None): the cropped boxes with dimension of + `num boxes` x 4. 
+ """ + assert spatial_idx in [0, 1, 2] + ndim = len(images.shape) + if ndim == 3: + images = images.unsqueeze(0) + height = images.shape[2] + width = images.shape[3] + + if scale_size is not None: + if width <= height: + width, height = scale_size, int(height / width * scale_size) + else: + width, height = int(width / height * scale_size), scale_size + images = torch.nn.functional.interpolate( + images, + size=(height, width), + mode='bilinear', + align_corners=False, + ) + + y_offset = int(math.ceil((height - size) / 2)) + x_offset = int(math.ceil((width - size) / 2)) + + if height > width: + if spatial_idx == 0: + y_offset = 0 + elif spatial_idx == 2: + y_offset = height - size + else: + if spatial_idx == 0: + x_offset = 0 + elif spatial_idx == 2: + x_offset = width - size + cropped = images[ + :, :, y_offset:y_offset + size, x_offset:x_offset + size + ] + cropped_boxes = ( + crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None + ) + if ndim == 3: + cropped = cropped.squeeze(0) + return cropped, cropped_boxes + + +def clip_boxes_to_image(boxes, height, width): + """ + Clip an array of boxes to an image with the given height and width. + Args: + boxes (ndarray): bounding boxes to perform clipping. + Dimension is `num boxes` x 4. + height (int): given image height. + width (int): given image width. + Returns: + clipped_boxes (ndarray): the clipped boxes with dimension of + `num boxes` x 4. + """ + clipped_boxes = boxes.copy() + clipped_boxes[:, [0, 2]] = np.minimum( + width - 1.0, np.maximum(0.0, boxes[:, [0, 2]]) + ) + clipped_boxes[:, [1, 3]] = np.minimum( + height - 1.0, np.maximum(0.0, boxes[:, [1, 3]]) + ) + return clipped_boxes + + +def blend(images1, images2, alpha): + """ + Blend two images with a given weight alpha. + Args: + images1 (tensor): the first images to be blended, the dimension is + `num frames` x `channel` x `height` x `width`. + images2 (tensor): the second images to be blended, the dimension is + `num frames` x `channel` x `height` x `width`. + alpha (float): the blending weight. + Returns: + (tensor): blended images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + return images1 * alpha + images2 * (1 - alpha) + + +def grayscale(images): + """ + Get the grayscale for the input images. The channels of images should be + in order BGR. + Args: + images (tensor): the input images for getting grayscale. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + img_gray (tensor): blended images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + # R -> 0.299, G -> 0.587, B -> 0.114. + img_gray = torch.tensor(images) + gray_channel = ( + 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0] + ) + img_gray[:, 0] = gray_channel + img_gray[:, 1] = gray_channel + img_gray[:, 2] = gray_channel + return img_gray + + +def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0): + """ + Perfrom a color jittering on the input images. The channels of images + should be in order BGR. + Args: + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + img_brightness (float): jitter ratio for brightness. + img_contrast (float): jitter ratio for contrast. + img_saturation (float): jitter ratio for saturation. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. 
+ """ + + jitter = [] + if img_brightness != 0: + jitter.append('brightness') + if img_contrast != 0: + jitter.append('contrast') + if img_saturation != 0: + jitter.append('saturation') + + if len(jitter) > 0: + order = np.random.permutation(np.arange(len(jitter))) + for idx in range(0, len(jitter)): + if jitter[order[idx]] == 'brightness': + images = brightness_jitter(img_brightness, images) + elif jitter[order[idx]] == 'contrast': + images = contrast_jitter(img_contrast, images) + elif jitter[order[idx]] == 'saturation': + images = saturation_jitter(img_saturation, images) + return images + + +def brightness_jitter(var, images): + """ + Perfrom brightness jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for brightness. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + + img_bright = torch.zeros(images.shape) + images = blend(images, img_bright, alpha) + return images + + +def contrast_jitter(var, images): + """ + Perfrom contrast jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for contrast. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + + img_gray = grayscale(images) + img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True) + images = blend(images, img_gray, alpha) + return images + + +def saturation_jitter(var, images): + """ + Perfrom saturation jittering on the input images. The channels of images + should be in order BGR. + Args: + var (float): jitter ratio for saturation. + images (tensor): images to perform color jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + Returns: + images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + alpha = 1.0 + np.random.uniform(-var, var) + img_gray = grayscale(images) + images = blend(images, img_gray, alpha) + + return images + + +def lighting_jitter(images, alphastd, eigval, eigvec): + """ + Perform AlexNet-style PCA jitter on the given images. + Args: + images (tensor): images to perform lighting jitter. Dimension is + `num frames` x `channel` x `height` x `width`. + alphastd (float): jitter ratio for PCA jitter. + eigval (list): eigenvalues for PCA jitter. + eigvec (list[list]): eigenvectors for PCA jitter. + Returns: + out_images (tensor): the jittered images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + if alphastd == 0: + return images + # generate alpha1, alpha2, alpha3. 
+ alpha = np.random.normal(0, alphastd, size=(1, 3)) + eig_vec = np.array(eigvec) + eig_val = np.reshape(eigval, (1, 3)) + rgb = np.sum( + eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0), + axis=1, + ) + out_images = torch.zeros_like(images) + if len(images.shape) == 3: + # C H W + channel_dim = 0 + elif len(images.shape) == 4: + # T C H W + channel_dim = 1 + else: + raise NotImplementedError(f'Unsupported dimension {len(images.shape)}') + + for idx in range(images.shape[channel_dim]): + # C H W + if len(images.shape) == 3: + out_images[idx] = images[idx] + rgb[2 - idx] + # T C H W + elif len(images.shape) == 4: + out_images[:, idx] = images[:, idx] + rgb[2 - idx] + else: + raise NotImplementedError( + f'Unsupported dimension {len(images.shape)}' + ) + + return out_images + + +def color_normalization(images, mean, stddev): + """ + Perform color nomration on the given images. + Args: + images (tensor): images to perform color normalization. Dimension is + `num frames` x `channel` x `height` x `width`. + mean (list): mean values for normalization. + stddev (list): standard deviations for normalization. + + Returns: + out_images (tensor): the noramlized images, the dimension is + `num frames` x `channel` x `height` x `width`. + """ + if len(images.shape) == 3: + assert ( + len(mean) == images.shape[0] + ), 'channel mean not computed properly' + assert ( + len(stddev) == images.shape[0] + ), 'channel stddev not computed properly' + elif len(images.shape) == 4: + assert ( + len(mean) == images.shape[1] + ), 'channel mean not computed properly' + assert ( + len(stddev) == images.shape[1] + ), 'channel stddev not computed properly' + else: + raise NotImplementedError(f'Unsupported dimension {len(images.shape)}') + + out_images = torch.zeros_like(images) + for idx in range(len(mean)): + # C H W + if len(images.shape) == 3: + out_images[idx] = (images[idx] - mean[idx]) / stddev[idx] + elif len(images.shape) == 4: + out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx] + else: + raise NotImplementedError( + f'Unsupported dimension {len(images.shape)}' + ) + return out_images + + +def _get_param_spatial_crop( + scale, ratio, height, width, num_repeat=10, log_scale=True, switch_hw=False +): + """ + Given scale, ratio, height and width, return sampled coordinates of the videos. + """ + for _ in range(num_repeat): + area = height * width + target_area = random.uniform(*scale) * area + if log_scale: + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + else: + aspect_ratio = random.uniform(*ratio) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if np.random.uniform() < 0.5 and switch_hw: + w, h = h, w + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(ratio): + w = width + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = height + w = int(round(h * max(ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + +def random_resized_crop( + images, + target_height, + target_width, + scale=(0.8, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), +): + """ + Crop the given images to random size and aspect ratio. 
A crop of random + size (default: of 0.08 to 1.0) of the original size and a random aspect + ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This + crop is finally resized to given size. This is popularly used to train the + Inception networks. + + Args: + images: Images to perform resizing and cropping. + target_height: Desired height after cropping. + target_width: Desired width after cropping. + scale: Scale range of Inception-style area based random resizing. + ratio: Aspect ratio range of Inception-style area based random resizing. + """ + + height = images.shape[2] + width = images.shape[3] + + i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) + cropped = images[:, :, i:i + h, j:j + w] + return torch.nn.functional.interpolate( + cropped, + size=(target_height, target_width), + mode='bilinear', + align_corners=False, + ) + + +def random_resized_crop_with_shift( + images, + target_height, + target_width, + scale=(0.8, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), +): + """ + This is similar to random_resized_crop. However, it samples two different + boxes (for cropping) for the first and last frame. It then linearly + interpolates the two boxes for other frames. + + Args: + images: Images to perform resizing and cropping. + target_height: Desired height after cropping. + target_width: Desired width after cropping. + scale: Scale range of Inception-style area based random resizing. + ratio: Aspect ratio range of Inception-style area based random resizing. + """ + t = images.shape[1] + height = images.shape[2] + width = images.shape[3] + + i, j, h, w = _get_param_spatial_crop(scale, ratio, height, width) + i_, j_, h_, w_ = _get_param_spatial_crop(scale, ratio, height, width) + i_s = [int(i) for i in torch.linspace(i, i_, steps=t).tolist()] + j_s = [int(i) for i in torch.linspace(j, j_, steps=t).tolist()] + h_s = [int(i) for i in torch.linspace(h, h_, steps=t).tolist()] + w_s = [int(i) for i in torch.linspace(w, w_, steps=t).tolist()] + out = torch.zeros((3, t, target_height, target_width)) + for ind in range(t): + out[:, ind:ind + 1, :, :] = torch.nn.functional.interpolate( + images[ + :, + ind:ind + 1, + i_s[ind]:i_s[ind] + h_s[ind], + j_s[ind]:j_s[ind] + w_s[ind], + ], + size=(target_height, target_width), + mode='bilinear', + align_corners=False, + ) + return out + + +def create_random_augment( + input_size, + auto_augment=None, + interpolation='bilinear', +): + """ + Get video randaug transform. + + Args: + input_size: The size of the input video in tuple. + auto_augment: Parameters for randaug. An example: + "rand-m7-n4-mstd0.5-inc1" (m is the magnitude and n is the number + of operations to apply). + interpolation: Interpolation method. + """ + if isinstance(input_size, tuple): + img_size = input_size[-2:] + else: + img_size = input_size + + if auto_augment: + assert isinstance(auto_augment, str) + if isinstance(img_size, tuple): + img_size_min = min(img_size) + else: + img_size_min = img_size + aa_params = {'translate_const': int(img_size_min * 0.45)} + if interpolation and interpolation != 'random': + aa_params['interpolation'] = _pil_interp(interpolation) + if auto_augment.startswith('rand'): + return transforms.Compose( + [rand_augment_transform(auto_augment, aa_params)] + ) + raise NotImplementedError + + +def random_sized_crop_img( + im, + size, + jitter_scale=(0.08, 1.0), + jitter_aspect=(3.0 / 4.0, 4.0 / 3.0), + max_iter=10, +): + """ + Performs Inception-style cropping (used for training). 
+ """ + assert ( + len(im.shape) == 3 + ), 'Currently only support image for random_sized_crop' + h, w = im.shape[1:3] + i, j, h, w = _get_param_spatial_crop( + scale=jitter_scale, + ratio=jitter_aspect, + height=h, + width=w, + num_repeat=max_iter, + log_scale=False, + switch_hw=True, + ) + cropped = im[:, i:i + h, j:j + w] + return torch.nn.functional.interpolate( + cropped.unsqueeze(0), + size=(size, size), + mode='bilinear', + align_corners=False, + ).squeeze(0) + + +# The following code are modified based on timm lib, we will replace the following +# contents with dependency from PyTorchVideo. +# https://github.com/facebookresearch/pytorchvideo +class RandomResizedCropAndInterpolation: + """Crop the given PIL Image to random size and aspect ratio with random interpolation. + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__( + self, + size, + scale=(0.08, 1.0), + ratio=(3.0 / 4.0, 4.0 / 3.0), + interpolation='bilinear', + ): + if isinstance(size, tuple): + self.size = size + else: + self.size = (size, size) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + print('range should be of kind (min, max)') + + if interpolation == 'random': + self.interpolation = _RANDOM_INTERPOLATION + else: + self.interpolation = _pil_interp(interpolation) + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + area = img.size[0] * img.size[1] + + for _ in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if w <= img.size[0] and h <= img.size[1]: + i = random.randint(0, img.size[1] - h) + j = random.randint(0, img.size[0] - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = img.size[0] / img.size[1] + if in_ratio < min(ratio): + w = img.size[0] + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = img.size[1] + w = int(round(h * max(ratio))) + else: # whole image + w = img.size[0] + h = img.size[1] + i = (img.size[1] - h) // 2 + j = (img.size[0] - w) // 2 + return i, j, h, w + + def __call__(self, img): + """ + Args: + img (PIL Image): Image to be cropped and resized. + Returns: + PIL Image: Randomly cropped and resized image. 
+ """ + i, j, h, w = self.get_params(img, self.scale, self.ratio) + if isinstance(self.interpolation, (tuple, list)): + interpolation = random.choice(self.interpolation) + else: + interpolation = self.interpolation + return F.resized_crop(img, i, j, h, w, self.size, interpolation) + + def __repr__(self): + if isinstance(self.interpolation, (tuple, list)): + interpolate_str = ' '.join( + [_pil_interpolation_to_str[x] for x in self.interpolation] + ) + else: + interpolate_str = _pil_interpolation_to_str[self.interpolation] + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += ', scale={0}'.format( + tuple(round(s, 4) for s in self.scale) + ) + format_string += ', ratio={0}'.format( + tuple(round(r, 4) for r in self.ratio) + ) + format_string += ', interpolation={0})'.format(interpolate_str) + return format_string + + +class Compose(object): + """Composes several transforms + Args: + transforms (list of ``Transform`` objects): list of transforms + to compose + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, clip): + for t in self.transforms: + clip = t(clip) + return clip + + +class RandomHorizontalFlip(object): + """Horizontally flip the list of given images randomly + with a probability 0.5 + """ + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Randomly flipped clip + """ + if random.random() < 0.5: + if isinstance(clip[0], np.ndarray): + return [np.fliplr(img) for img in clip] + elif isinstance(clip[0], PIL.Image.Image): + return [ + img.transpose(PIL.Image.FLIP_LEFT_RIGHT) for img in clip + ] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + ' but got list of {0}'.format(type(clip[0]))) + return clip + + +class RandomResize(object): + """Resizes a list of (H x W x C) numpy.ndarray to the final size + The larger the original image is, the more times it takes to + interpolate + Args: + interpolation (str): Can be one of 'nearest', 'bilinear' + defaults to nearest + size (tuple): (widht, height) + """ + + def __init__(self, ratio=(3. / 4., 4. 
/ 3.), interpolation='nearest'): + self.ratio = ratio + self.interpolation = interpolation + + def __call__(self, clip): + scaling_factor = random.uniform(self.ratio[0], self.ratio[1]) + + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + + new_w = int(im_w * scaling_factor) + new_h = int(im_h * scaling_factor) + new_size = (new_w, new_h) + resized = FF.resize_clip( + clip, new_size, interpolation=self.interpolation) + return resized + + +class Resize(object): + """Resizes a list of (H x W x C) numpy.ndarray to the final size + The larger the original image is, the more times it takes to + interpolate + Args: + interpolation (str): Can be one of 'nearest', 'bilinear' + defaults to nearest + size (tuple): (widht, height) + """ + + def __init__(self, size, interpolation='nearest'): + self.size = size + self.interpolation = interpolation + + def __call__(self, clip): + resized = FF.resize_clip( + clip, self.size, interpolation=self.interpolation) + return resized + + +class RandomCrop(object): + """Extract random crop at the same location for a list of images + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + if w > im_w or h > im_h: + error_msg = ( + 'Initial image size should be larger then ' + 'cropped size but got cropped sizes : ({w}, {h}) while ' + 'initial image is ({im_w}, {im_h})'.format( + im_w=im_w, im_h=im_h, w=w, h=h)) + raise ValueError(error_msg) + + x1 = random.randint(0, im_w - w) + y1 = random.randint(0, im_h - h) + cropped = FF.crop_clip(clip, y1, x1, h, w) + + return cropped + + +class ThreeCrop(object): + """Extract random crop at the same location for a list of images + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + if w != im_w and h != im_h: + clip = FF.resize_clip(clip, self.size, interpolation="bilinear") + im_h, im_w, im_c = clip[0].shape + + step = np.max((np.max((im_w, im_h)) - self.size[0]) // 2, 0) + cropped = [] + for i in range(3): + if (im_h > self.size[0]): + x1 = 0 + y1 = i * step + cropped.extend(FF.crop_clip(clip, y1, x1, h, w)) + else: + x1 = i * step + y1 = 0 + cropped.extend(FF.crop_clip(clip, y1, x1, h, w)) + return cropped + + +class RandomRotation(object): + """Rotate entire 
clip randomly by a random angle within + given bounds + Args: + degrees (sequence or int): Range of degrees to select from + If degrees is a number instead of sequence like (min, max), + the range of degrees, will be (-degrees, +degrees). + """ + + def __init__(self, degrees): + if isinstance(degrees, numbers.Number): + if degrees < 0: + raise ValueError('If degrees is a single number,' + 'must be positive') + degrees = (-degrees, degrees) + else: + if len(degrees) != 2: + raise ValueError('If degrees is a sequence,' + 'it must be of len 2.') + + self.degrees = degrees + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + import skimage + angle = random.uniform(self.degrees[0], self.degrees[1]) + if isinstance(clip[0], np.ndarray): + rotated = [skimage.transform.rotate(img, angle) for img in clip] + elif isinstance(clip[0], PIL.Image.Image): + rotated = [img.rotate(angle) for img in clip] + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + + return rotated + + +class CenterCrop(object): + """Extract center crop at the same location for a list of images + Args: + size (sequence or int): Desired output size for the + crop in format (h, w) + """ + + def __init__(self, size): + if isinstance(size, numbers.Number): + size = (size, size) + + self.size = size + + def __call__(self, clip): + """ + Args: + img (PIL.Image or numpy.ndarray): List of images to be cropped + in format (h, w, c) in numpy.ndarray + Returns: + PIL.Image or numpy.ndarray: Cropped list of images + """ + h, w = self.size + if isinstance(clip[0], np.ndarray): + im_h, im_w, im_c = clip[0].shape + elif isinstance(clip[0], PIL.Image.Image): + im_w, im_h = clip[0].size + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + if w > im_w or h > im_h: + error_msg = ( + 'Initial image size should be larger then ' + 'cropped size but got cropped sizes : ({w}, {h}) while ' + 'initial image is ({im_w}, {im_h})'.format( + im_w=im_w, im_h=im_h, w=w, h=h)) + raise ValueError(error_msg) + + x1 = int(round((im_w - w) / 2.)) + y1 = int(round((im_h - h) / 2.)) + cropped = FF.crop_clip(clip, y1, x1, h, w) + + return cropped + + +class ColorJitter(object): + """ + Randomly change the brightness, contrast and saturation and hue of the clip + + Args: + brightness (float): How much to jitter brightness. brightness_factor + is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. + contrast (float): How much to jitter contrast. contrast_factor + is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. + saturation (float): How much to jitter saturation. saturation_factor + is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. + hue(float): How much to jitter hue. hue_factor is chosen uniformly from + [-hue, hue]. Should be >=0 and <= 0.5. 
+ """ + + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + self.brightness = brightness + self.contrast = contrast + self.saturation = saturation + self.hue = hue + + def get_params(self, brightness, contrast, saturation, hue): + if brightness > 0: + brightness_factor = random.uniform( + max(0, 1 - brightness), 1 + brightness) + else: + brightness_factor = None + + if contrast > 0: + contrast_factor = random.uniform( + max(0, 1 - contrast), 1 + contrast) + else: + contrast_factor = None + + if saturation > 0: + saturation_factor = random.uniform( + max(0, 1 - saturation), 1 + saturation) + else: + saturation_factor = None + + if hue > 0: + hue_factor = random.uniform(-hue, hue) + else: + hue_factor = None + return brightness_factor, contrast_factor, saturation_factor, hue_factor + + def __call__(self, clip): + """ + Args: + clip (list): list of PIL.Image + Returns: + list PIL.Image : list of transformed PIL.Image + """ + if isinstance(clip[0], np.ndarray): + raise TypeError( + 'Color jitter not yet implemented for numpy arrays') + elif isinstance(clip[0], PIL.Image.Image): + brightness, contrast, saturation, hue = self.get_params( + self.brightness, self.contrast, self.saturation, self.hue) + + # Create img transform function sequence + img_transforms = [] + if brightness is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_brightness(img, brightness)) + if saturation is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_saturation(img, saturation)) + if hue is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_hue(img, hue)) + if contrast is not None: + img_transforms.append(lambda img: torchvision.transforms.functional.adjust_contrast(img, contrast)) + random.shuffle(img_transforms) + + # Apply to all images + jittered_clip = [] + for img in clip: + for func in img_transforms: + jittered_img = func(img) + jittered_clip.append(jittered_img) + + else: + raise TypeError('Expected numpy.ndarray or PIL.Image' + + 'but got list of {0}'.format(type(clip[0]))) + return jittered_clip + + +class Normalize(object): + """Normalize a clip with mean and standard deviation. + Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform + will normalize each channel of the input ``torch.*Tensor`` i.e. + ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` + .. note:: + This transform acts out of place, i.e., it does not mutates the input tensor. + Args: + mean (sequence): Sequence of means for each channel. + std (sequence): Sequence of standard deviations for each channel. + """ + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, clip): + """ + Args: + clip (Tensor): Tensor clip of size (T, C, H, W) to be normalized. + Returns: + Tensor: Normalized Tensor clip. + """ + return FF.normalize(clip, self.mean, self.std) + + def __repr__(self): + return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) diff --git a/build/lib/datasets/utils/video/volume_transforms.py b/build/lib/datasets/utils/video/volume_transforms.py new file mode 100644 index 00000000..0a01bb36 --- /dev/null +++ b/build/lib/datasets/utils/video/volume_transforms.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
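# Illustrative usage sketch (not part of the committed diff): chaining the clip
# transforms defined above. The module path is assumed; point it at wherever
# these classes are packaged in the source tree, and treat all values below as
# placeholders rather than the training recipe.
from PIL import Image

import src.datasets.utils.video.transforms as VT  # assumed import path

clip = [Image.new('RGB', (320, 240)) for _ in range(16)]  # dummy 16-frame clip
augment = VT.Compose([
    VT.RandomHorizontalFlip(),   # flips every frame of the clip together
    VT.RandomCrop(224),          # same crop window applied to every frame
    VT.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
])
aug_clip = augment(clip)         # list of PIL images cropped to 224x224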
+# + +import numpy as np +from PIL import Image + +import torch + + +def convert_img(img): + """Converts (H, W, C) numpy.ndarray to (C, W, H) format""" + if len(img.shape) == 3: + img = img.transpose(2, 0, 1) + if len(img.shape) == 2: + img = np.expand_dims(img, 0) + return img + + +class ClipToTensor(object): + """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] + to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] + """ + + def __init__(self, channel_nb=3, div_255=True, numpy=False): + self.channel_nb = channel_nb + self.div_255 = div_255 + self.numpy = numpy + + def __call__(self, clip): + """ + Args: clip (list of numpy.ndarray): clip (list of images) + to be converted to tensor. + """ + # Retrieve shape + if isinstance(clip[0], np.ndarray): + h, w, ch = clip[0].shape + assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) + elif isinstance(clip[0], Image.Image): + w, h = clip[0].size + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + + np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) + + # Convert + for img_idx, img in enumerate(clip): + if isinstance(img, np.ndarray): + pass + elif isinstance(img, Image.Image): + img = np.array(img, copy=False) + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + img = convert_img(img) + np_clip[:, img_idx, :, :] = img + if self.numpy: + if self.div_255: + np_clip = np_clip / 255.0 + return np_clip + + else: + tensor_clip = torch.from_numpy(np_clip) + + if not isinstance(tensor_clip, torch.FloatTensor): + tensor_clip = tensor_clip.float() + if self.div_255: + tensor_clip = torch.div(tensor_clip, 255) + return tensor_clip + + +# Note this norms data to -1/1 +class ClipToTensor_K(object): + """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] + to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] + """ + + def __init__(self, channel_nb=3, div_255=True, numpy=False): + self.channel_nb = channel_nb + self.div_255 = div_255 + self.numpy = numpy + + def __call__(self, clip): + """ + Args: clip (list of numpy.ndarray): clip (list of images) + to be converted to tensor. 
+ """ + # Retrieve shape + if isinstance(clip[0], np.ndarray): + h, w, ch = clip[0].shape + assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) + elif isinstance(clip[0], Image.Image): + w, h = clip[0].size + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + + np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) + + # Convert + for img_idx, img in enumerate(clip): + if isinstance(img, np.ndarray): + pass + elif isinstance(img, Image.Image): + img = np.array(img, copy=False) + else: + raise TypeError( + "Expected numpy.ndarray or PIL.Image\ + but got list of {0}".format( + type(clip[0]) + ) + ) + img = convert_img(img) + np_clip[:, img_idx, :, :] = img + if self.numpy: + if self.div_255: + np_clip = (np_clip - 127.5) / 127.5 + return np_clip + + else: + tensor_clip = torch.from_numpy(np_clip) + + if not isinstance(tensor_clip, torch.FloatTensor): + tensor_clip = tensor_clip.float() + if self.div_255: + tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) + return tensor_clip + + +class ToTensor(object): + """Converts numpy array to tensor""" + + def __call__(self, array): + tensor = torch.from_numpy(array) + return tensor diff --git a/build/lib/datasets/utils/weighted_sampler.py b/build/lib/datasets/utils/weighted_sampler.py new file mode 100644 index 00000000..fd40825e --- /dev/null +++ b/build/lib/datasets/utils/weighted_sampler.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from typing import Iterator, Optional +from operator import itemgetter +import numpy as np + +import torch +from torch.utils.data import ( + Dataset, + Sampler, + DistributedSampler, + WeightedRandomSampler +) + + +class DatasetFromSampler(Dataset): + + def __init__(self, sampler: Sampler): + self.sampler = sampler + self.sampler_list = None + + def __getitem__(self, index: int): + if self.sampler_list is None: + self.sampler_list = list(self.sampler) + return self.sampler_list[index] + + def __len__(self) -> int: + return len(self.sampler) + + +class DistributedSamplerWrapper(DistributedSampler): + """ Convert any Pytorch Sampler to a DistributedSampler """ + + def __init__( + self, + sampler, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + ): + super(DistributedSamplerWrapper, self).__init__( + DatasetFromSampler(sampler), + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + ) + self.sampler = sampler + + def __iter__(self) -> Iterator[int]: + self.dataset = DatasetFromSampler(self.sampler) + indexes_of_indexes = super().__iter__() + subsampler_indexes = self.dataset + return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) + + +class CustomWeightedRandomSampler(WeightedRandomSampler): + """ Generalized WeightedRandomSampler to allow for more than 2^24 samples """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __iter__(self): + rand_tensor = np.random.choice( + range(0, len(self.weights)), + size=self.num_samples, + p=self.weights.numpy() / torch.sum(self.weights).numpy(), + replace=self.replacement + ) + rand_tensor = torch.from_numpy(rand_tensor) + return iter(rand_tensor.tolist()) + + +class DistributedWeightedSampler(DistributedSamplerWrapper): + + def __init__( + self, + weights, + num_replicas: Optional[int] = None, + 
rank: Optional[int] = None, + shuffle: bool = True, + ): + weighted_sampler = CustomWeightedRandomSampler( + weights=weights, + num_samples=len(weights), + replacement=False) + + super(DistributedWeightedSampler, self).__init__( + sampler=weighted_sampler, + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + ) diff --git a/build/lib/datasets/video_dataset.py b/build/lib/datasets/video_dataset.py new file mode 100644 index 00000000..b05cc701 --- /dev/null +++ b/build/lib/datasets/video_dataset.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import os +import pathlib +import warnings + +from logging import getLogger + +import numpy as np +import pandas as pd + +from decord import VideoReader, cpu + +import torch + +from src.datasets.utils.weighted_sampler import DistributedWeightedSampler + +_GLOBAL_SEED = 0 +logger = getLogger() + + +def make_videodataset( + data_paths, + batch_size, + frames_per_clip=8, + frame_step=4, + num_clips=1, + random_clip_sampling=True, + allow_clip_overlap=False, + filter_short_videos=False, + filter_long_videos=int(10**9), + transform=None, + shared_transform=None, + rank=0, + world_size=1, + datasets_weights=None, + collator=None, + drop_last=True, + num_workers=10, + pin_mem=True, + duration=None, + log_dir=None, +): + dataset = VideoDataset( + data_paths=data_paths, + datasets_weights=datasets_weights, + frames_per_clip=frames_per_clip, + frame_step=frame_step, + num_clips=num_clips, + random_clip_sampling=random_clip_sampling, + allow_clip_overlap=allow_clip_overlap, + filter_short_videos=filter_short_videos, + filter_long_videos=filter_long_videos, + duration=duration, + shared_transform=shared_transform, + transform=transform) + + logger.info('VideoDataset dataset created') + if datasets_weights is not None: + dist_sampler = DistributedWeightedSampler( + dataset.sample_weights, + num_replicas=world_size, + rank=rank, + shuffle=True) + else: + dist_sampler = torch.utils.data.distributed.DistributedSampler( + dataset, + num_replicas=world_size, + rank=rank, + shuffle=True) + + data_loader = torch.utils.data.DataLoader( + dataset, + collate_fn=collator, + sampler=dist_sampler, + batch_size=batch_size, + drop_last=drop_last, + pin_memory=pin_mem, + num_workers=num_workers, + persistent_workers=num_workers > 0) + logger.info('VideoDataset unsupervised data loader created') + + return dataset, data_loader, dist_sampler + + +class VideoDataset(torch.utils.data.Dataset): + """ Video classification dataset. 
""" + + def __init__( + self, + data_paths, + datasets_weights=None, + frames_per_clip=16, + frame_step=4, + num_clips=1, + transform=None, + shared_transform=None, + random_clip_sampling=True, + allow_clip_overlap=False, + filter_short_videos=False, + filter_long_videos=int(10**9), + duration=None, # duration in seconds + ): + self.data_paths = data_paths + self.datasets_weights = datasets_weights + self.frames_per_clip = frames_per_clip + self.frame_step = frame_step + self.num_clips = num_clips + self.transform = transform + self.shared_transform = shared_transform + self.random_clip_sampling = random_clip_sampling + self.allow_clip_overlap = allow_clip_overlap + self.filter_short_videos = filter_short_videos + self.filter_long_videos = filter_long_videos + self.duration = duration + + if VideoReader is None: + raise ImportError('Unable to import "decord" which is required to read videos.') + + # Load video paths and labels + samples, labels = [], [] + self.num_samples_per_dataset = [] + for data_path in self.data_paths: + + if data_path[-4:] == '.csv': + data = pd.read_csv(data_path, header=None, delimiter=" ") + samples += list(data.values[:, 0]) + labels += list(data.values[:, 1]) + num_samples = len(data) + self.num_samples_per_dataset.append(num_samples) + + elif data_path[-4:] == '.npy': + data = np.load(data_path, allow_pickle=True) + data = list(map(lambda x: repr(x)[1:-1], data)) + samples += data + labels += [0] * len(data) + num_samples = len(data) + self.num_samples_per_dataset.append(len(data)) + + # [Optional] Weights for each sample to be used by downstream + # weighted video sampler + self.sample_weights = None + if self.datasets_weights is not None: + self.sample_weights = [] + for dw, ns in zip(self.datasets_weights, self.num_samples_per_dataset): + self.sample_weights += [dw / ns] * ns + + self.samples = samples + self.labels = labels + + def __getitem__(self, index): + sample = self.samples[index] + + # Keep trying to load videos until you find a valid sample + loaded_video = False + while not loaded_video: + buffer, clip_indices = self.loadvideo_decord(sample) # [T H W 3] + loaded_video = len(buffer) > 0 + if not loaded_video: + index = np.random.randint(self.__len__()) + sample = self.samples[index] + + # Label/annotations for video + label = self.labels[index] + + def split_into_clips(video): + """ Split video into a list of clips """ + fpc = self.frames_per_clip + nc = self.num_clips + return [video[i*fpc:(i+1)*fpc] for i in range(nc)] + + # Parse video into frames & apply data augmentations + if self.shared_transform is not None: + buffer = self.shared_transform(buffer) + buffer = split_into_clips(buffer) + if self.transform is not None: + buffer = [self.transform(clip) for clip in buffer] + + return buffer, label, clip_indices + + def loadvideo_decord(self, sample): + """ Load video content using Decord """ + + fname = sample + if not os.path.exists(fname): + warnings.warn(f'video path not found {fname=}') + return [], None + + _fsize = os.path.getsize(fname) + if _fsize < 1 * 1024: # avoid hanging issue + warnings.warn(f'video too short {fname=}') + return [], None + if _fsize > self.filter_long_videos: + warnings.warn(f'skipping long video of size {_fsize=} (bytes)') + return [], None + + try: + vr = VideoReader(fname, num_threads=-1, ctx=cpu(0)) + except Exception: + return [], None + + fpc = self.frames_per_clip + fstp = self.frame_step + if self.duration is not None: + try: + fps = vr.get_avg_fps() + fstp = int(self.duration * fps / fpc) + except Exception 
as e: + warnings.warn(e) + clip_len = int(fpc * fstp) + + if self.filter_short_videos and len(vr) < clip_len: + warnings.warn(f'skipping video of length {len(vr)}') + return [], None + + vr.seek(0) # Go to start of video before sampling frames + + # Partition video into equal sized segments and sample each clip + # from a different segment + partition_len = len(vr) // self.num_clips + + all_indices, clip_indices = [], [] + for i in range(self.num_clips): + + if partition_len > clip_len: + # If partition_len > clip len, then sample a random window of + # clip_len frames within the segment + end_indx = clip_len + if self.random_clip_sampling: + end_indx = np.random.randint(clip_len, partition_len) + start_indx = end_indx - clip_len + indices = np.linspace(start_indx, end_indx, num=fpc) + indices = np.clip(indices, start_indx, end_indx-1).astype(np.int64) + # -- + indices = indices + i * partition_len + else: + # If partition overlap not allowed and partition_len < clip_len + # then repeatedly append the last frame in the segment until + # we reach the desired clip length + if not self.allow_clip_overlap: + indices = np.linspace(0, partition_len, num=partition_len // fstp) + indices = np.concatenate((indices, np.ones(fpc - partition_len // fstp) * partition_len,)) + indices = np.clip(indices, 0, partition_len-1).astype(np.int64) + # -- + indices = indices + i * partition_len + + # If partition overlap is allowed and partition_len < clip_len + # then start_indx of segment i+1 will lie within segment i + else: + sample_len = min(clip_len, len(vr)) - 1 + indices = np.linspace(0, sample_len, num=sample_len // fstp) + indices = np.concatenate((indices, np.ones(fpc - sample_len // fstp) * sample_len,)) + indices = np.clip(indices, 0, sample_len-1).astype(np.int64) + # -- + clip_step = 0 + if len(vr) > clip_len: + clip_step = (len(vr) - clip_len) // (self.num_clips - 1) + indices = indices + i * clip_step + + clip_indices.append(indices) + all_indices.extend(list(indices)) + + buffer = vr.get_batch(all_indices).asnumpy() + return buffer, clip_indices + + def __len__(self): + return len(self.samples) diff --git a/build/lib/masks/default.py b/build/lib/masks/default.py new file mode 100644 index 00000000..2810c0a1 --- /dev/null +++ b/build/lib/masks/default.py @@ -0,0 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from logging import getLogger + +import torch + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class DefaultCollator(object): + + def __call__(self, batch): + collated_batch = torch.utils.data.default_collate(batch) + return collated_batch, None, None diff --git a/build/lib/masks/multiblock3d.py b/build/lib/masks/multiblock3d.py new file mode 100644 index 00000000..a7bbc3e1 --- /dev/null +++ b/build/lib/masks/multiblock3d.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
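# Illustrative usage sketch (not part of the committed diff): constructing the
# video loader defined above. Each entry of data_paths is either a .csv of
# space-delimited "video_path label" rows or a .npy array of paths; the index
# file and sizes below are placeholders.
from src.datasets.video_dataset import make_videodataset  # assumed import path

dataset, loader, sampler = make_videodataset(
    data_paths=['/path/to/train_clips.csv'],  # hypothetical index file
    batch_size=4,
    frames_per_clip=16,
    frame_step=4,
    num_clips=1,
    num_workers=4,
    world_size=1,
    rank=0,
)
# Each dataset item is (clips, label, clip_indices); with no collator supplied,
# the DataLoader falls back to torch.utils.data.default_collate.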
+# + +import math + +from multiprocessing import Value + +from logging import getLogger + +import torch + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class MaskCollator(object): + + def __init__( + self, + cfgs_mask, + crop_size=(224, 224), + num_frames=16, + patch_size=(16, 16), + tubelet_size=2, + ): + super(MaskCollator, self).__init__() + + self.mask_generators = [] + for m in cfgs_mask: + mask_generator = _MaskGenerator( + crop_size=crop_size, + num_frames=num_frames, + spatial_patch_size=patch_size, + temporal_patch_size=tubelet_size, + spatial_pred_mask_scale=m.get('spatial_scale'), + temporal_pred_mask_scale=m.get('temporal_scale'), + aspect_ratio=m.get('aspect_ratio'), + npred=m.get('num_blocks'), + max_context_frames_ratio=m.get('max_temporal_keep', 1.0), + max_keep=m.get('max_keep', None), + ) + self.mask_generators.append(mask_generator) + + def step(self): + for mask_generator in self.mask_generators: + mask_generator.step() + + def __call__(self, batch): + + batch_size = len(batch) + collated_batch = torch.utils.data.default_collate(batch) + + collated_masks_pred, collated_masks_enc = [], [] + for i, mask_generator in enumerate(self.mask_generators): + masks_enc, masks_pred = mask_generator(batch_size) + collated_masks_enc.append(masks_enc) + collated_masks_pred.append(masks_pred) + + return collated_batch, collated_masks_enc, collated_masks_pred + + +class _MaskGenerator(object): + + def __init__( + self, + crop_size=(224, 224), + num_frames=16, + spatial_patch_size=(16, 16), + temporal_patch_size=2, + spatial_pred_mask_scale=(0.2, 0.8), + temporal_pred_mask_scale=(1.0, 1.0), + aspect_ratio=(0.3, 3.0), + npred=1, + max_context_frames_ratio=1.0, + max_keep=None, + ): + super(_MaskGenerator, self).__init__() + if not isinstance(crop_size, tuple): + crop_size = (crop_size, ) * 2 + self.crop_size = crop_size + self.height, self.width = crop_size[0] // spatial_patch_size, crop_size[1] // spatial_patch_size + self.duration = num_frames // temporal_patch_size + + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + + self.aspect_ratio = aspect_ratio + self.spatial_pred_mask_scale = spatial_pred_mask_scale + self.temporal_pred_mask_scale = temporal_pred_mask_scale + self.npred = npred + self.max_context_duration = max(1, int(self.duration * max_context_frames_ratio)) # maximum number of time-steps (frames) spanned by context mask + self.max_keep = max_keep # maximum number of patches to keep in context + self._itr_counter = Value('i', -1) # collator is shared across worker processes + + def step(self): + i = self._itr_counter + with i.get_lock(): + i.value += 1 + v = i.value + return v + + def _sample_block_size( + self, + generator, + temporal_scale, + spatial_scale, + aspect_ratio_scale + ): + # -- Sample temporal block mask scale + _rand = torch.rand(1, generator=generator).item() + min_t, max_t = temporal_scale + temporal_mask_scale = min_t + _rand * (max_t - min_t) + t = max(1, int(self.duration * temporal_mask_scale)) + + # -- Sample spatial block mask scale + _rand = torch.rand(1, generator=generator).item() + min_s, max_s = spatial_scale + spatial_mask_scale = min_s + _rand * (max_s - min_s) + spatial_num_keep = int(self.height * self.width * spatial_mask_scale) + + # -- Sample block aspect-ratio + _rand = torch.rand(1, generator=generator).item() + min_ar, max_ar = aspect_ratio_scale + aspect_ratio = min_ar + _rand * (max_ar - min_ar) + + # -- Compute block height and width (given scale and aspect-ratio) + h = 
int(round(math.sqrt(spatial_num_keep * aspect_ratio))) + w = int(round(math.sqrt(spatial_num_keep / aspect_ratio))) + h = min(h, self.height) + w = min(w, self.width) + + return (t, h, w) + + def _sample_block_mask(self, b_size): + t, h, w = b_size + top = torch.randint(0, self.height - h + 1, (1,)) + left = torch.randint(0, self.width - w + 1, (1,)) + start = torch.randint(0, self.duration - t + 1, (1,)) + + mask = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) + mask[start:start+t, top:top+h, left:left+w] = 0 + + # Context mask will only span the first X frames + # (X=self.max_context_frames) + if self.max_context_duration < self.duration: + mask[self.max_context_duration:, :, :] = 0 + + # -- + return mask + + def __call__(self, batch_size): + """ + Create encoder and predictor masks when collating imgs into a batch + # 1. sample pred block size using seed + # 2. sample several pred block locations for each image (w/o seed) + # 3. return pred masks and complement (enc mask) + """ + seed = self.step() + g = torch.Generator() + g.manual_seed(seed) + p_size = self._sample_block_size( + generator=g, + temporal_scale=self.temporal_pred_mask_scale, + spatial_scale=self.spatial_pred_mask_scale, + aspect_ratio_scale=self.aspect_ratio, + ) + + collated_masks_pred, collated_masks_enc = [], [] + min_keep_enc = min_keep_pred = self.duration * self.height * self.width + for _ in range(batch_size): + + empty_context = True + while empty_context: + + mask_e = torch.ones((self.duration, self.height, self.width), dtype=torch.int32) + for _ in range(self.npred): + mask_e *= self._sample_block_mask(p_size) + mask_e = mask_e.flatten() + + mask_p = torch.argwhere(mask_e == 0).squeeze() + mask_e = torch.nonzero(mask_e).squeeze() + + empty_context = len(mask_e) == 0 + if not empty_context: + min_keep_pred = min(min_keep_pred, len(mask_p)) + min_keep_enc = min(min_keep_enc, len(mask_e)) + collated_masks_pred.append(mask_p) + collated_masks_enc.append(mask_e) + + if self.max_keep is not None: + min_keep_enc = min(min_keep_enc, self.max_keep) + + collated_masks_pred = [cm[:min_keep_pred] for cm in collated_masks_pred] + collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) + # -- + collated_masks_enc = [cm[:min_keep_enc] for cm in collated_masks_enc] + collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) + + return collated_masks_enc, collated_masks_pred diff --git a/build/lib/masks/random_tube.py b/build/lib/masks/random_tube.py new file mode 100644 index 00000000..84c06402 --- /dev/null +++ b/build/lib/masks/random_tube.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
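# Illustrative usage sketch (not part of the committed diff): building the 3D
# multi-block mask collator defined above. Entries of cfgs_mask only need to
# support .get(), so plain dicts work; the scales, aspect ratios and block
# counts below are placeholders, not the published masking recipe.
from src.masks.multiblock3d import MaskCollator as MB3DMaskCollator  # assumed path

cfgs_mask = [
    dict(spatial_scale=(0.15, 0.15), temporal_scale=(1.0, 1.0),
         aspect_ratio=(0.75, 1.5), num_blocks=8, max_temporal_keep=1.0),
    dict(spatial_scale=(0.7, 0.7), temporal_scale=(1.0, 1.0),
         aspect_ratio=(0.75, 1.5), num_blocks=2, max_temporal_keep=1.0),
]
mask_collator = MB3DMaskCollator(
    cfgs_mask=cfgs_mask,
    crop_size=(224, 224),
    num_frames=16,
    patch_size=16,      # an int here, so the patch grid comes out as 14 x 14
    tubelet_size=2,
)
# Passed as collate_fn to a DataLoader, each batch then arrives as
# (collated_batch, masks_enc, masks_pred), one mask tensor per generator.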
+# + +from multiprocessing import Value + +from logging import getLogger + +import torch +import numpy as np + +_GLOBAL_SEED = 0 +logger = getLogger() + + +class MaskCollator(object): + + def __init__( + self, + cfgs_mask, + crop_size=(224, 224), + num_frames=16, + patch_size=(16, 16), + tubelet_size=2, + ): + super(MaskCollator, self).__init__() + + self.mask_generators = [] + for m in cfgs_mask: + mask_generator = _MaskGenerator( + crop_size=crop_size, + num_frames=num_frames, + spatial_patch_size=patch_size, + temporal_patch_size=tubelet_size, + ratio=m.get('ratio'), + ) + self.mask_generators.append(mask_generator) + + def step(self): + for mask_generator in self.mask_generators: + mask_generator.step() + + def __call__(self, batch): + + batch_size = len(batch) + collated_batch = torch.utils.data.default_collate(batch) + + collated_masks_pred, collated_masks_enc = [], [] + for i, mask_generator in enumerate(self.mask_generators): + masks_enc, masks_pred = mask_generator(batch_size) + collated_masks_enc.append(masks_enc) + collated_masks_pred.append(masks_pred) + + return collated_batch, collated_masks_enc, collated_masks_pred + + +class _MaskGenerator(object): + + def __init__( + self, + crop_size=(224, 224), + num_frames=16, + spatial_patch_size=(16, 16), + temporal_patch_size=2, + ratio=0.9, + ): + super(_MaskGenerator, self).__init__() + if not isinstance(crop_size, tuple): + crop_size = (crop_size, ) * 2 + self.crop_size = crop_size + self.height, self.width = crop_size[0] // spatial_patch_size, crop_size[1] // spatial_patch_size + self.duration = num_frames // temporal_patch_size + + self.spatial_patch_size = spatial_patch_size + self.temporal_patch_size = temporal_patch_size + self.num_patches_spatial = self.height*self.width + + self.ratio = ratio + + self.num_keep_spatial = int(self.num_patches_spatial*(1.-self.ratio)) + self.num_keep = self.num_keep_spatial * self.duration + + self._itr_counter = Value('i', -1) # collator is shared across worker processes + + def step(self): + i = self._itr_counter + with i.get_lock(): + i.value += 1 + v = i.value + return v + + def __call__(self, batch_size): + def sample_mask(): + mask = np.hstack([ + np.zeros(self.num_patches_spatial - self.num_keep_spatial), + np.ones(self.num_keep_spatial), + ]) + np.random.shuffle(mask) + mask = torch.tensor(np.tile(mask, (self.duration, 1))) + mask = mask.flatten() + mask_p = torch.argwhere(mask == 0).squeeze() + mask_e = torch.nonzero(mask).squeeze() + return mask_e, mask_p + + collated_masks_pred, collated_masks_enc = [], [] + for _ in range(batch_size): + mask_e, mask_p = sample_mask() + collated_masks_enc.append(mask_e) + collated_masks_pred.append(mask_p) + + collated_masks_enc = torch.utils.data.default_collate(collated_masks_enc) + collated_masks_pred = torch.utils.data.default_collate(collated_masks_pred) + + return collated_masks_enc, collated_masks_pred diff --git a/build/lib/masks/utils.py b/build/lib/masks/utils.py new file mode 100644 index 00000000..ca04af1f --- /dev/null +++ b/build/lib/masks/utils.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
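# Illustrative usage sketch (not part of the committed diff): the tube mask
# above keeps the same spatial positions across every time step. With a 224x224
# crop, 16x16 patches and ratio=0.9, each sample keeps
# int(14 * 14 * (1 - 0.9)) = 19 spatial patches, i.e. 19 * 8 = 152 context
# tokens for 16 frames with tubelet_size=2. Values below are placeholders.
from src.masks.random_tube import MaskCollator as TubeMaskCollator  # assumed path

tube_collator = TubeMaskCollator(
    cfgs_mask=[dict(ratio=0.9)],
    crop_size=(224, 224),
    num_frames=16,
    patch_size=16,
    tubelet_size=2,
)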
+# + +import torch + + +def apply_masks(x, masks, concat=True): + """ + :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] + :param masks: list of tensors of shape [B, K] containing indices of K patches in [N] to keep + """ + all_x = [] + for m in masks: + mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) + all_x += [torch.gather(x, dim=1, index=mask_keep)] + if not concat: + return all_x + + return torch.cat(all_x, dim=0) diff --git a/build/lib/models/attentive_pooler.py b/build/lib/models/attentive_pooler.py new file mode 100644 index 00000000..ecd9986a --- /dev/null +++ b/build/lib/models/attentive_pooler.py @@ -0,0 +1,136 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math + +import torch +import torch.nn as nn + +from src.models.utils.modules import ( + Block, + CrossAttention, + CrossAttentionBlock +) +from src.utils.tensors import trunc_normal_ + + +class AttentivePooler(nn.Module): + """ Attentive Pooler """ + def __init__( + self, + num_queries=1, + embed_dim=768, + num_heads=12, + mlp_ratio=4.0, + depth=1, + norm_layer=nn.LayerNorm, + init_std=0.02, + qkv_bias=True, + complete_block=True + ): + super().__init__() + self.query_tokens = nn.Parameter(torch.zeros(1, num_queries, embed_dim)) + + self.complete_block = complete_block + if complete_block: + self.cross_attention_block = CrossAttentionBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer) + else: + self.cross_attention_block = CrossAttention( + dim=embed_dim, + num_heads=num_heads, + qkv_bias=qkv_bias) + + self.blocks = None + if depth > 1: + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=False, + norm_layer=norm_layer) + for i in range(depth-1)]) + + self.init_std = init_std + trunc_normal_(self.query_tokens, std=self.init_std) + self.apply(self._init_weights) + self._rescale_blocks() + + def _rescale_blocks(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + if self.complete_block: + rescale(self.cross_attention_block.xattn.proj.weight.data, 1) + rescale(self.cross_attention_block.mlp.fc2.weight.data, 1) + else: + rescale(self.cross_attention_block.proj.weight.data, 1) + if self.blocks is not None: + for layer_id, layer in enumerate(self.blocks, 1): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=self.init_std) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=self.init_std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + q = self.query_tokens.repeat(len(x), 1, 1) + q = self.cross_attention_block(q, x) + if self.blocks is not None: + for blk in self.blocks: + q = blk(q) + return q + + +class AttentiveClassifier(nn.Module): + """ Attentive Classifier """ + def __init__( + self, + embed_dim=768, + num_heads=12, + mlp_ratio=4.0, + depth=1, + norm_layer=nn.LayerNorm, + init_std=0.02, + qkv_bias=True, + num_classes=1000, + complete_block=True, + ): + super().__init__() + 
self.pooler = AttentivePooler( + num_queries=1, + embed_dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + depth=depth, + norm_layer=norm_layer, + init_std=init_std, + qkv_bias=qkv_bias, + complete_block=complete_block, + ) + self.linear = nn.Linear(embed_dim, num_classes, bias=True) + + def forward(self, x): + x = self.pooler(x).squeeze(1) + x = self.linear(x) + return x diff --git a/build/lib/models/predictor.py b/build/lib/models/predictor.py new file mode 100644 index 00000000..2dd9a38b --- /dev/null +++ b/build/lib/models/predictor.py @@ -0,0 +1,246 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math +from functools import partial + +import torch +import torch.nn as nn + +from src.models.utils.modules import Block +from src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed +from src.utils.tensors import ( + trunc_normal_, + repeat_interleave_batch +) +from src.masks.utils import apply_masks + + +class VisionTransformerPredictor(nn.Module): + """ Vision Transformer """ + def __init__( + self, + img_size=224, + patch_size=16, + num_frames=1, + tubelet_size=2, + embed_dim=768, + predictor_embed_dim=384, + depth=6, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + norm_layer=nn.LayerNorm, + init_std=0.02, + uniform_power=False, + use_mask_tokens=False, + num_mask_tokens=2, + zero_init_mask_tokens=True, + **kwargs + ): + super().__init__() + # Map input to predictor dimension + self.predictor_embed = nn.Linear(embed_dim, predictor_embed_dim, bias=True) + + # Mask tokens + self.mask_tokens = None + self.num_mask_tokens = 0 + if use_mask_tokens: + self.num_mask_tokens = num_mask_tokens + self.mask_tokens = nn.ParameterList([ + nn.Parameter(torch.zeros(1, 1, predictor_embed_dim)) + for i in range(num_mask_tokens) + ]) + + # Determine positional embedding + self.input_size = img_size + self.patch_size = patch_size + # -- + self.num_frames = num_frames + self.tubelet_size = tubelet_size + self.is_video = num_frames > 1 + + grid_size = self.input_size // self.patch_size + grid_depth = self.num_frames // self.tubelet_size + + if self.is_video: + self.num_patches = num_patches = ( + (num_frames // tubelet_size) + * (img_size // patch_size) + * (img_size // patch_size) + ) + else: + self.num_patches = num_patches = ( + (img_size // patch_size) + * (img_size // patch_size) + ) + # Position embedding + self.uniform_power = uniform_power + self.predictor_pos_embed = None + self.predictor_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, predictor_embed_dim), + requires_grad=False) + + # Attention Blocks + self.predictor_blocks = nn.ModuleList([ + Block( + dim=predictor_embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=nn.GELU, + attn_drop=attn_drop_rate, + grid_size=grid_size, + grid_depth=grid_depth, + norm_layer=norm_layer) + for i in range(depth)]) + + # Normalize & project back to input dimension + self.predictor_norm = norm_layer(predictor_embed_dim) + self.predictor_proj = nn.Linear(predictor_embed_dim, embed_dim, bias=True) + + # ------ initialize weights + if self.predictor_pos_embed is not None: + self._init_pos_embed(self.predictor_pos_embed.data) # sincos pos-embed + self.init_std = init_std + if not zero_init_mask_tokens: + for mt in 
self.mask_tokens: + trunc_normal_(mt, std=init_std) + self.apply(self._init_weights) + self._rescale_blocks() + + def _init_pos_embed(self, pos_embed): + embed_dim = pos_embed.size(-1) + grid_size = self.input_size // self.patch_size + if self.is_video: + grid_depth = self.num_frames // self.tubelet_size + sincos = get_3d_sincos_pos_embed( + embed_dim, + grid_size, + grid_depth, + cls_token=False, + uniform_power=self.uniform_power + ) + else: + sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False) + pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=self.init_std) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _rescale_blocks(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.predictor_blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def diffusion(self, x, noise_beta=(0.5, 1.0), steps=1000): + + # Prepare diffusion noise schedule + b1, b2 = noise_beta + beta_scheduler = (b1 + i*(b2-b1)/steps for i in range(steps)) + alpha_scheduler = [] + _alpha = 1.0 + for _beta in beta_scheduler: + _alpha *= 1.-_beta + alpha_scheduler += [_alpha] + + # Sample diffusion time step + T = torch.randint(0, steps, (len(x),)) + alpha = torch.tensor(alpha_scheduler, device=x.device)[T].unsqueeze(-1).unsqueeze(-1) + + # Normalize features and apply noise + x = torch.nn.functional.layer_norm(x, (x.size(-1),)) + x = alpha**0.5 * x + (1.-alpha)**0.5 * torch.randn(x.shape, device=x.device) + return x + + def forward(self, ctxt, tgt, masks_ctxt, masks_tgt, mask_index=1): + """ + :param ctxt: context tokens + :param tgt: target tokens + :param masks_ctxt: indices of context tokens in input + :params masks_tgt: indices of target tokens in input + """ + + assert (masks_ctxt is not None) and (masks_tgt is not None), 'Cannot run predictor without mask indices' + + if not isinstance(masks_ctxt, list): + masks_ctxt = [masks_ctxt] + + if not isinstance(masks_tgt, list): + masks_tgt = [masks_tgt] + + # Batch Size + B = len(ctxt) // len(masks_ctxt) + + # Map context tokens to pedictor dimensions + x = self.predictor_embed(ctxt) + _, N_ctxt, D = x.shape + + # Add positional embedding to ctxt tokens + if self.predictor_pos_embed is not None: + ctxt_pos_embed = self.predictor_pos_embed.repeat(B, 1, 1) + x += apply_masks(ctxt_pos_embed, masks_ctxt) + + # Map target tokens to predictor dimensions & add noise (fwd diffusion) + if self.mask_tokens is None: + pred_tokens = self.predictor_embed(tgt) + pred_tokens = self.diffusion(pred_tokens) + else: + mask_index = mask_index % self.num_mask_tokens + pred_tokens = self.mask_tokens[mask_index] + pred_tokens = pred_tokens.repeat(B, self.num_patches, 1) + pred_tokens = apply_masks(pred_tokens, masks_tgt) + + # Add positional embedding to target tokens + if self.predictor_pos_embed is not None: + pos_embs = self.predictor_pos_embed.repeat(B, 1, 1) + pos_embs = apply_masks(pos_embs, masks_tgt) + pos_embs = repeat_interleave_batch(pos_embs, B, repeat=len(masks_ctxt)) + pred_tokens += pos_embs + + # Concatenate context & target tokens + x = x.repeat(len(masks_tgt), 1, 1) + x = torch.cat([x, pred_tokens], dim=1) + + # FIXME: this implementation currently assumes masks_ctxt and masks_tgt + # 
are alligned 1:1 (ok with MultiMask wrapper on predictor but + # otherwise will break) + masks_ctxt = torch.cat(masks_ctxt, dim=0) + masks_tgt = torch.cat(masks_tgt, dim=0) + masks = torch.cat([masks_ctxt, masks_tgt], dim=1) + + # Fwd prop + for blk in self.predictor_blocks: + x = blk(x, mask=masks) + x = self.predictor_norm(x) + + # Return output corresponding to target tokens + x = x[:, N_ctxt:] + x = self.predictor_proj(x) + + return x + + +def vit_predictor(**kwargs): + model = VisionTransformerPredictor( + mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), + **kwargs) + return model diff --git a/build/lib/models/utils/modules.py b/build/lib/models/utils/modules.py new file mode 100644 index 00000000..dc470d9b --- /dev/null +++ b/build/lib/models/utils/modules.py @@ -0,0 +1,183 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0. + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + proj_drop=0., + use_sdpa=True + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop_prob = proj_drop + self.proj_drop = nn.Dropout(proj_drop) + self.use_sdpa = use_sdpa + + def forward(self, x, mask=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [B, num_heads, N, D] + + if self.use_sdpa: + with torch.backends.cuda.sdp_kernel(): + x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.proj_drop_prob) + attn = None + else: + attn = (q @ k.transpose(-2, -1)) * self.scale # [B, num_heads, D, D] + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x, attn + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + qk_scale=None, + drop=0., + attn_drop=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + grid_size=None, + grid_depth=None, + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + def forward(self, x, return_attention=False, mask=None): + y, attn = self.attn(self.norm1(x), mask=mask) + if return_attention: + return attn + x 
= x + y + x = x + self.mlp(self.norm2(x)) + return x + + +class CrossAttention(nn.Module): + def __init__( + self, + dim, + num_heads=12, + qkv_bias=False, + use_sdpa=True + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.q = nn.Linear(dim, dim, bias=qkv_bias) + self.kv = nn.Linear(dim, int(dim*2), bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + self.use_sdpa = use_sdpa + + def forward(self, q, x): + B, n, C = q.shape + q = self.q(q).reshape(B, n, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + B, N, C = x.shape + kv = self.kv(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] # (batch_size, num_heads, seq_len, feature_dim_per_head) + + if self.use_sdpa: + with torch.backends.cuda.sdp_kernel(): + q = F.scaled_dot_product_attention(q, k, v) + else: + xattn = (q @ k.transpose(-2, -1)) * self.scale + xattn = xattn.softmax(dim=-1) # (batch_size, num_heads, query_len, seq_len) + q = (xattn @ v) + + q = q.transpose(1, 2).reshape(B, n, C) + q = self.proj(q) + + return q + + +class CrossAttentionBlock(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm + ): + super().__init__() + self.norm1 = norm_layer(dim) + self.xattn = CrossAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) + + def forward(self, q, x): + y = self.xattn(q, self.norm1(x)) + q = q + y + q = q + self.mlp(self.norm2(q)) + return q diff --git a/build/lib/models/utils/multimask.py b/build/lib/models/utils/multimask.py new file mode 100644 index 00000000..d4800869 --- /dev/null +++ b/build/lib/models/utils/multimask.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import torch.nn as nn + + +class MultiMaskWrapper(nn.Module): + + def __init__(self, backbone): + super().__init__() + self.backbone = backbone + + def forward(self, x, masks=None): + if masks is None: + return self.backbone(x) + + if (masks is not None) and not isinstance(masks, list): + masks = [masks] + outs = [] + for m in masks: + outs += [self.backbone(x, masks=m)] + return outs + + +class PredictorMultiMaskWrapper(nn.Module): + + def __init__(self, backbone): + super().__init__() + self.backbone = backbone + + def forward(self, ctxt, tgt, masks_ctxt, masks_tgt): + if type(ctxt) is not list: + ctxt = [ctxt] + if type(tgt) is not list: + tgt = [tgt] + if type(masks_ctxt) is not list: + masks_ctxt = [masks_ctxt] + if type(masks_tgt) is not list: + masks_tgt = [masks_tgt] + + outs = [] + for i, (zi, hi, mc, mt) in enumerate(zip(ctxt, tgt, masks_ctxt, masks_tgt)): + outs += [self.backbone(zi, hi, mc, mt, mask_index=i)] + return outs diff --git a/build/lib/models/utils/patch_embed.py b/build/lib/models/utils/patch_embed.py new file mode 100644 index 00000000..4ff4de51 --- /dev/null +++ b/build/lib/models/utils/patch_embed.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
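# Illustrative usage sketch (not part of the committed diff): a minimal shape
# check of the cross-attention block defined above, which is what the attentive
# pooler uses to pool patch tokens into its query tokens. Dimensions are
# placeholders; the import path is assumed.
import torch

from src.models.utils.modules import CrossAttentionBlock  # assumed path

xblk = CrossAttentionBlock(dim=768, num_heads=12)
queries = torch.zeros(2, 1, 768)    # one pooling query per sample
tokens = torch.randn(2, 1568, 768)  # e.g. 8 * 14 * 14 video patch tokens
pooled = xblk(queries, tokens)      # -> (2, 1, 768): queries attend over tokens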
+# + +import torch.nn as nn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding + """ + def __init__( + self, + patch_size=16, + in_chans=3, + embed_dim=768 + ): + super().__init__() + self.patch_size = patch_size + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class PatchEmbed3D(nn.Module): + """ + Image to Patch Embedding + """ + + def __init__( + self, + patch_size=16, + tubelet_size=2, + in_chans=3, + embed_dim=768, + ): + super().__init__() + self.patch_size = patch_size + self.tubelet_size = tubelet_size + + self.proj = nn.Conv3d( + in_channels=in_chans, + out_channels=embed_dim, + kernel_size=(tubelet_size, patch_size, patch_size), + stride=(tubelet_size, patch_size, patch_size), + ) + + def forward(self, x, **kwargs): + B, C, T, H, W = x.shape + x = self.proj(x).flatten(2).transpose(1, 2) + return x diff --git a/build/lib/models/utils/pos_embs.py b/build/lib/models/utils/pos_embs.py new file mode 100644 index 00000000..d1d82e21 --- /dev/null +++ b/build/lib/models/utils/pos_embs.py @@ -0,0 +1,99 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import numpy as np + + +def get_3d_sincos_pos_embed( + embed_dim, + grid_size, + grid_depth, + cls_token=False, + uniform_power=False +): + """ + grid_size: int of the grid height and width + grid_depth: int of the grid depth + returns: + pos_embed: [grid_depth*grid_size*grid_size, embed_dim] (w/o cls_token) + or [1+grid_depth*grid_size*grid_size, embed_dim] (w/ cls_token) + """ + grid_d = np.arange(grid_depth, dtype=float) + grid_h = np.arange(grid_size, dtype=float) + grid_w = np.arange(grid_size, dtype=float) + grid_h, grid_d, grid_w = np.meshgrid(grid_h, grid_d, grid_w) # order of meshgrid is very important for indexing as [d,h,w] + + if not uniform_power: + h_embed_dim = embed_dim // 4 + w_embed_dim = embed_dim // 4 + d_embed_dim = embed_dim // 2 + else: + h_embed_dim = w_embed_dim = d_embed_dim = int(np.ceil(embed_dim/6)*2) + + emb_h = get_1d_sincos_pos_embed_from_grid(h_embed_dim, grid_h) # (T*H*W, D1) + emb_w = get_1d_sincos_pos_embed_from_grid(w_embed_dim, grid_w) # (T*H*W, D2) + emb_d = get_1d_sincos_pos_embed_from_grid(d_embed_dim, grid_d) # (T*H*W, D3) + pos_embed = np.concatenate([emb_d, emb_h, emb_w], axis=1) + pos_embed = pos_embed[:, :embed_dim] + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + returns: + pos_embed: [grid_size*grid_size, embed_dim] (w/o cls_token) + or [1+grid_size*grid_size, embed_dim] (w/ cls_token) + """ + grid_h = np.arange(grid_size, dtype=float) + grid_w = np.arange(grid_size, dtype=float) + grid_w, grid_h = np.meshgrid(grid_w, grid_h) # order of meshgrid is very important for indexing as [h, w] + + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_h) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid_w) # (H*W, D/2) + pos_embed = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_1d_sincos_pos_embed(embed_dim, grid_size, 
cls_token=False): + """ + embed_dim: output dimension for each position + grid_size: int of the grid length + returns: + pos_embed: [grid_size, embed_dim] (w/o cls_token) + or [1+grid_size, embed_dim] (w/ cls_token) + """ + grid = np.arange(grid_size, dtype=float) + pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + returns: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=float) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb diff --git a/build/lib/models/vision_transformer.py b/build/lib/models/vision_transformer.py new file mode 100644 index 00000000..a8748dfd --- /dev/null +++ b/build/lib/models/vision_transformer.py @@ -0,0 +1,307 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math +from functools import partial + +import torch +import torch.nn as nn + +from src.models.utils.patch_embed import PatchEmbed, PatchEmbed3D +from src.models.utils.modules import Block +from src.models.utils.pos_embs import get_2d_sincos_pos_embed, get_3d_sincos_pos_embed +from src.utils.tensors import trunc_normal_ +from src.masks.utils import apply_masks + + +class VisionTransformer(nn.Module): + """ Vision Transformer """ + def __init__( + self, + img_size=224, + patch_size=16, + num_frames=1, + tubelet_size=2, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + norm_layer=nn.LayerNorm, + init_std=0.02, + out_layers=None, + uniform_power=False, + **kwargs + ): + super().__init__() + self.num_features = self.embed_dim = embed_dim + self.num_heads = num_heads + self.out_layers = out_layers + + self.input_size = img_size + self.patch_size = patch_size + + self.num_frames = num_frames + self.tubelet_size = tubelet_size + self.is_video = num_frames > 1 + + grid_size = self.input_size // self.patch_size + grid_depth = self.num_frames // self.tubelet_size + + # Tokenize pixels with convolution + if self.is_video: + self.patch_embed = PatchEmbed3D( + patch_size=patch_size, + tubelet_size=tubelet_size, + in_chans=in_chans, + embed_dim=embed_dim) + self.num_patches = ( + (num_frames // tubelet_size) + * (img_size // patch_size) + * (img_size // patch_size) + ) + else: + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim) + self.num_patches = ( + (img_size // patch_size) + * (img_size // patch_size) + ) + + # Position embedding + self.uniform_power = uniform_power + self.pos_embed = None + self.pos_embed = nn.Parameter( + torch.zeros(1, self.num_patches, embed_dim), + requires_grad=False) + + # Attention Blocks + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + act_layer=nn.GELU, + grid_size=grid_size, + 
grid_depth=grid_depth, + attn_drop=attn_drop_rate, + norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + + # ------ initialize weights + if self.pos_embed is not None: + self._init_pos_embed(self.pos_embed.data) # sincos pos-embed + self.init_std = init_std + self.apply(self._init_weights) + self._rescale_blocks() + + def _init_pos_embed(self, pos_embed): + embed_dim = pos_embed.size(-1) + grid_size = self.input_size // self.patch_size + if self.is_video: + grid_depth = self.num_frames // self.tubelet_size + sincos = get_3d_sincos_pos_embed( + embed_dim, + grid_size, + grid_depth, + cls_token=False, + uniform_power=self.uniform_power + ) + else: + sincos = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False) + pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=self.init_std) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=self.init_std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv3d): + trunc_normal_(m.weight, std=self.init_std) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _rescale_blocks(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def get_num_layers(self): + return len(self.blocks) + + def no_weight_decay(self): + return {} + + def forward(self, x, masks=None): + """ + :param x: input image/video + :param masks: indices of patch tokens to mask (remove) + """ + + if masks is not None and not isinstance(masks, list): + masks = [masks] + + # Tokenize input + pos_embed = self.pos_embed + if pos_embed is not None: + pos_embed = self.interpolate_pos_encoding(x, pos_embed) + x = self.patch_embed(x) + if pos_embed is not None: + x += pos_embed + B, N, D = x.shape + + # Mask away unwanted tokens (if masks provided) + if masks is not None: + x = apply_masks(x, masks) + masks = torch.cat(masks, dim=0) + + # Fwd prop + outs = [] + for i, blk in enumerate(self.blocks): + x = blk(x, mask=masks) + if self.out_layers is not None and i in self.out_layers: + outs.append(self.norm(x)) + + if self.out_layers is not None: + return outs + + if self.norm is not None: + x = self.norm(x) + + return x + + def interpolate_pos_encoding(self, x, pos_embed): + + _, N, dim = pos_embed.shape + + if self.is_video: + + # If pos_embed already corret size, just return + _, _, T, H, W = x.shape + if H == self.input_size and W == self.input_size and T == self.num_frames: + return pos_embed + + # Convert depth, height, width of input to be measured in patches + # instead of pixels/frames + T = T // self.tubelet_size + H = H // self.patch_size + W = W // self.patch_size + + # Compute the initialized shape of the positional embedding measured + # in patches + N_t = self.num_frames // self.tubelet_size + N_h = N_w = self.input_size // self.patch_size + assert N_h * N_w * N_t == N, 'Positional embedding initialized incorrectly' + + # Compute scale factor for spatio-temporal interpolation + scale_factor = (T/N_t, H/N_h, W/N_w) + + pos_embed = nn.functional.interpolate( + pos_embed.reshape(1, N_t, N_h, N_w, dim).permute(0, 4, 1, 2, 3), + 
+                scale_factor=scale_factor,
+                mode='trilinear')
+            pos_embed = pos_embed.permute(0, 2, 3, 4, 1).view(1, -1, dim)
+            return pos_embed
+
+        else:
+
+            # If pos_embed already correct size, just return
+            _, _, H, W = x.shape
+            if H == self.input_size and W == self.input_size:
+                return pos_embed
+
+            # Compute scale factor for spatial interpolation
+            npatch = (H // self.patch_size) * (W // self.patch_size)
+            scale_factor = math.sqrt(npatch / N)
+
+            pos_embed = nn.functional.interpolate(
+                pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+                scale_factor=scale_factor,
+                mode='bicubic')
+            pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+            return pos_embed
+
+
+def vit_tiny(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_small(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_base(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_large(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_huge(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_giant(patch_size=16, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    return model
+
+
+def vit_gigantic(patch_size=14, **kwargs):
+    model = VisionTransformer(
+        patch_size=patch_size, embed_dim=1664, depth=48, num_heads=16, mlp_ratio=64/13,
+        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
+    )
+    return model
+
+
+VIT_EMBED_DIMS = {
+    'vit_tiny': 192,
+    'vit_small': 384,
+    'vit_base': 768,
+    'vit_large': 1024,
+    'vit_huge': 1280,
+    'vit_giant': 1408,
+    'vit_gigantic': 1664,
+}
diff --git a/build/lib/utils/distributed.py b/build/lib/utils/distributed.py
new file mode 100644
index 00000000..cfba444d
--- /dev/null
+++ b/build/lib/utils/distributed.py
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
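For readers scanning this file: the VisionTransformer above turns a clip into (num_frames/tubelet_size) x (img_size/patch_size)^2 patch tokens, adds a fixed sin-cos positional embedding, and interpolates that embedding whenever the input resolution or clip length differs from what the model was built with. A minimal usage sketch, assuming the module is importable as src.models.vision_transformer (the build/lib copy above mirrors it); the shapes follow directly from the code:

    import torch
    from src.models.vision_transformer import vit_large

    # 16 frames at 224x224 with tubelet_size=2 and patch_size=16
    # -> (16 // 2) * (224 // 16) * (224 // 16) = 1568 tokens per clip
    encoder = vit_large(patch_size=16, img_size=224, num_frames=16, tubelet_size=2)
    clips = torch.randn(2, 3, 16, 224, 224)  # [batch, channels, frames, height, width]
    tokens = encoder(clips)                  # no masks, so every token is returned
    print(tokens.shape)                      # torch.Size([2, 1568, 1024]) for vit_large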
+# + +import os + +import torch +import torch.distributed as dist + +from logging import getLogger + +logger = getLogger() + + +def init_distributed(port=37123, rank_and_world_size=(None, None)): + + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size(), dist.get_rank() + + rank, world_size = rank_and_world_size + os.environ['MASTER_ADDR'] = 'localhost' + + if (rank is None) or (world_size is None): + try: + world_size = int(os.environ['SLURM_NTASKS']) + rank = int(os.environ['SLURM_PROCID']) + os.environ['MASTER_ADDR'] = os.environ['HOSTNAME'] + except Exception: + logger.info('SLURM vars not set (distributed training not available)') + world_size, rank = 1, 0 + return world_size, rank + + try: + os.environ['MASTER_PORT'] = str(port) + torch.distributed.init_process_group( + backend='nccl', + world_size=world_size, + rank=rank + ) + except Exception as e: + world_size, rank = 1, 0 + logger.info(f'Rank: {rank}. Distributed training not available {e}') + + return world_size, rank + + +class AllGather(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + x = x.contiguous() + outputs = [torch.zeros_like(x) for _ in range(dist.get_world_size())] + dist.all_gather(outputs, x) + return torch.cat(outputs, 0) + return x + + @staticmethod + def backward(ctx, grads): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + s = (grads.shape[0] // dist.get_world_size()) * dist.get_rank() + e = (grads.shape[0] // dist.get_world_size()) * (dist.get_rank() + 1) + grads = grads.contiguous() + dist.all_reduce(grads) + return grads[s:e] + return grads + + +class AllReduceSum(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + x = x.contiguous() + dist.all_reduce(x) + return x + + @staticmethod + def backward(ctx, grads): + return grads + + +class AllReduce(torch.autograd.Function): + + @staticmethod + def forward(ctx, x): + if ( + dist.is_available() + and dist.is_initialized() + and (dist.get_world_size() > 1) + ): + x = x.contiguous() / dist.get_world_size() + dist.all_reduce(x) + return x + + @staticmethod + def backward(ctx, grads): + return grads diff --git a/build/lib/utils/logging.py b/build/lib/utils/logging.py new file mode 100644 index 00000000..fcdd3faf --- /dev/null +++ b/build/lib/utils/logging.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import logging +import sys + +import torch + + +def gpu_timer(closure, log_timings=True): + """ Helper to time gpu-time to execute closure() """ + log_timings = log_timings and torch.cuda.is_available() + + elapsed_time = -1. 
+ if log_timings: + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + + result = closure() + + if log_timings: + end.record() + torch.cuda.synchronize() + elapsed_time = start.elapsed_time(end) + + return result, elapsed_time + + +LOG_FORMAT = "[%(levelname)-8s][%(asctime)s][%(funcName)-25s] %(message)s" +DATE_FORMAT = "%Y-%m-%d %H:%M:%S" + + +def get_logger(name=None, force=False): + logging.basicConfig(stream=sys.stdout, level=logging.INFO, + format=LOG_FORMAT, datefmt=DATE_FORMAT, force=force) + return logging.getLogger(name=name) + + +class CSVLogger(object): + + def __init__(self, fname, *argv): + self.fname = fname + self.types = [] + # -- print headers + with open(self.fname, '+a') as f: + for i, v in enumerate(argv, 1): + self.types.append(v[0]) + if i < len(argv): + print(v[1], end=',', file=f) + else: + print(v[1], end='\n', file=f) + + def log(self, *argv): + with open(self.fname, '+a') as f: + for i, tv in enumerate(zip(self.types, argv), 1): + end = ',' if i < len(argv) else '\n' + print(tv[0] % tv[1], end=end, file=f) + + +class AverageMeter(object): + """computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.max = float('-inf') + self.min = float('inf') + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + try: + self.max = max(val, self.max) + self.min = min(val, self.min) + except Exception: + pass + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def grad_logger(named_params): + stats = AverageMeter() + stats.first_layer = None + stats.last_layer = None + for n, p in named_params: + if (p.grad is not None) and not (n.endswith('.bias') or len(p.shape) == 1): + grad_norm = float(torch.norm(p.grad.data)) + stats.update(grad_norm) + if 'qkv' in n: + stats.last_layer = grad_norm + if stats.first_layer is None: + stats.first_layer = grad_norm + if stats.first_layer is None or stats.last_layer is None: + stats.first_layer = stats.last_layer = 0. + return stats + + +def adamw_logger(optimizer): + """ logging magnitude of first and second momentum buffers in adamw """ + # TODO: assert that optimizer is instance of torch.optim.AdamW + state = optimizer.state_dict().get('state') + exp_avg_stats = AverageMeter() + exp_avg_sq_stats = AverageMeter() + for key in state: + s = state.get(key) + exp_avg_stats.update(float(s.get('exp_avg').abs().mean())) + exp_avg_sq_stats.update(float(s.get('exp_avg_sq').abs().mean())) + return {'exp_avg': exp_avg_stats, 'exp_avg_sq': exp_avg_sq_stats} diff --git a/build/lib/utils/monitoring.py b/build/lib/utils/monitoring.py new file mode 100644 index 00000000..95a7845a --- /dev/null +++ b/build/lib/utils/monitoring.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
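The logging helpers above are small but easy to misread: CSVLogger takes (printf-format, column-name) pairs and appends one formatted row per log() call, gpu_timer returns the closure's result together with the elapsed GPU time in milliseconds (or -1 without CUDA), and AverageMeter tracks val/avg/min/max. A hedged sketch of how they fit together; the file path and the train_step closure are placeholders, not taken from this diff:

    from src.utils.logging import AverageMeter, CSVLogger, gpu_timer

    csv_logger = CSVLogger('/tmp/train_stats.csv',      # placeholder path
                           ('%d', 'epoch'),
                           ('%d', 'itr'),
                           ('%.5f', 'loss'),
                           ('%.1f', 'gpu-time(ms)'))
    loss_meter = AverageMeter()
    time_meter = AverageMeter()

    def train_step():          # stand-in for one optimization step
        return 0.123           # would normally return the loss value

    for itr in range(10):
        loss, step_ms = gpu_timer(train_step)
        loss_meter.update(loss)
        time_meter.update(step_ms)
        csv_logger.log(0, itr, loss_meter.avg, time_meter.avg)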
+#
+
+import dataclasses
+import threading
+import time
+from typing import Dict, Tuple
+
+import psutil
+
+
+@dataclasses.dataclass
+class ResourceStatsSample:
+    timestamp: float
+    cpu_percent: float
+    read_count: int
+    write_count: int
+    read_bytes: int
+    write_bytes: int
+    read_chars: int
+    write_chars: int
+    cpu_times_user: float
+    cpu_times_system: float
+    cpu_times_children_user: float
+    cpu_times_children_system: float
+    cpu_times_iowait: float
+    cpu_affinity: str
+    cpu_num: int
+    num_threads: int
+    num_voluntary_ctx_switches: int
+    num_involuntary_ctx_switches: int
+
+    def as_tuple(self) -> Dict:
+        """Return values mirroring fields."""
+        return dataclasses.astuple(self)
+
+    def fields(self) -> Tuple[dataclasses.Field, ...]:
+        """Return fields in this dataclass."""
+        return dataclasses.fields(self.__class__)
+
+
+class ResourceMonitoringThread(threading.Thread):
+    def __init__(self, pid=None, refresh_interval=None, stats_callback_fn=None):
+        """Starts a thread to monitor pid every refresh_interval seconds.
+
+        Passes a ResourceStatsSample object to the callback."""
+        super(ResourceMonitoringThread, self).__init__()
+        if refresh_interval is None:
+            refresh_interval = 5
+        self.is_running_event = threading.Event()
+        self.p = psutil.Process(pid)
+        self.refresh_interval = refresh_interval
+        if stats_callback_fn is None:
+            # Default callback: print the sample
+            def stats_callback_fn(resource_sample: ResourceStatsSample):
+                print(
+                    f"PID {self.p.pid} Stats: {resource_sample}")
+        elif not callable(stats_callback_fn):
+            raise ValueError("Callback needs to be callable, got {}".format(
+                type(stats_callback_fn)))
+        self.stats_callback_fn = stats_callback_fn
+
+    def stop(self) -> None:
+        self.is_running_event.set()
+
+    def run(self) -> None:
+        while not self.is_running_event.is_set():
+            self.sample_counters()
+            self.is_running_event.wait(self.refresh_interval)
+
+    def log_sample(self, resource_sample: ResourceStatsSample) -> None:
+        self.stats_callback_fn(resource_sample)
+
+    def sample_counters(self) -> None:
+        if not self.p.is_running():
+            self.stop()
+            return
+
+        with self.p.oneshot():
+            cpu_percent = self.p.cpu_percent()
+            cpu_times = self.p.cpu_times()
+            io_counters = self.p.io_counters()
+            cpu_affinity = self.p.cpu_affinity()
+            cpu_num = self.p.cpu_num()
+            num_threads = self.p.num_threads()
+            num_ctx_switches = self.p.num_ctx_switches()
+        timestamp = time.time()
+
+        read_count = io_counters.read_count
+        write_count = io_counters.write_count
+        read_bytes = io_counters.read_bytes
+        write_bytes = io_counters.write_bytes
+        read_chars = io_counters.read_chars
+        write_chars = io_counters.write_chars
+
+        def compress_cpu_affinity(cpu_affinity):
+            """Change list representation to interval/range representation."""
+            if not cpu_affinity:
+                return ""
+            cpu_affinity_compressed = []
+            min_x = None
+            max_x = None
+            last_x = None
+
+            # Find contiguous ranges
+            for x in cpu_affinity:
+                if last_x is None:
+                    # Start interval
+                    min_x = x
+                    max_x = x
+                    last_x = x
+                    continue
+                elif x == (last_x + 1):
+                    # Move interval up
+                    max_x = x
+                elif max_x is not None:
+                    # Interval ended, start again
+                    if min_x == max_x:
+                        cpu_affinity_compressed.append("{}".format(min_x))
+                    else:
+                        cpu_affinity_compressed.append(
+                            "{}-{}".format(min_x, max_x))
+                    min_x = x
+                    max_x = x
+                last_x = x
+            # Terminate last range
+            if max_x is not None:
+                if min_x == max_x:
+                    cpu_affinity_compressed.append("{}".format(min_x))
+                else:
+                    cpu_affinity_compressed.append(
+                        "{}-{}".format(min_x, max_x))
+
+            # Concat
+            cpu_affinity_compressed =
",".join(cpu_affinity_compressed) + + return cpu_affinity_compressed + + cpu_affinity = compress_cpu_affinity(cpu_affinity) + + resource_sample = ResourceStatsSample( + timestamp=timestamp, + cpu_percent=cpu_percent, + read_count=read_count, + write_count=write_count, + read_bytes=read_bytes, + write_bytes=write_bytes, + read_chars=read_chars, + write_chars=write_chars, + cpu_times_user=cpu_times.user, + cpu_times_system=cpu_times.system, + cpu_times_children_user=cpu_times.children_user, + cpu_times_children_system=cpu_times.children_system, + cpu_times_iowait=cpu_times.iowait, + cpu_affinity=cpu_affinity, + cpu_num=cpu_num, + num_threads=num_threads, + num_voluntary_ctx_switches=num_ctx_switches.voluntary, + num_involuntary_ctx_switches=num_ctx_switches.involuntary, + ) + self.log_sample(resource_sample) + + +if __name__ == "__main__": + import multiprocessing + import time + pid = multiprocessing.current_process().pid + monitor_thread = ResourceMonitoringThread(pid, 1) + monitor_thread.start() + time.sleep(5) + print("Shutdown") + monitor_thread.stop() diff --git a/build/lib/utils/schedulers.py b/build/lib/utils/schedulers.py new file mode 100644 index 00000000..df02e2b0 --- /dev/null +++ b/build/lib/utils/schedulers.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +import math + + +class WarmupCosineSchedule(object): + + def __init__( + self, + optimizer, + warmup_steps, + start_lr, + ref_lr, + T_max, + last_epoch=-1, + final_lr=0. + ): + self.optimizer = optimizer + self.start_lr = start_lr + self.ref_lr = ref_lr + self.final_lr = final_lr + self.warmup_steps = warmup_steps + self.T_max = T_max - warmup_steps + self._step = 0. + + def step(self): + self._step += 1 + if self._step < self.warmup_steps: + progress = float(self._step) / float(max(1, self.warmup_steps)) + new_lr = self.start_lr + progress * (self.ref_lr - self.start_lr) + else: + # -- progress after warmup + progress = float(self._step - self.warmup_steps) / float(max(1, self.T_max)) + new_lr = max(self.final_lr, + self.final_lr + (self.ref_lr - self.final_lr) * 0.5 * (1. + math.cos(math.pi * progress))) + + for group in self.optimizer.param_groups: + group['lr'] = new_lr + + return new_lr + + +class CosineWDSchedule(object): + + def __init__( + self, + optimizer, + ref_wd, + T_max, + final_wd=0. + ): + self.optimizer = optimizer + self.ref_wd = ref_wd + self.final_wd = final_wd + self.T_max = T_max + self._step = 0. + + def step(self): + self._step += 1 + progress = self._step / self.T_max + new_wd = self.final_wd + (self.ref_wd - self.final_wd) * 0.5 * (1. + math.cos(math.pi * progress)) + + if self.final_wd <= self.ref_wd: + new_wd = max(self.final_wd, new_wd) + else: + new_wd = min(self.final_wd, new_wd) + + for group in self.optimizer.param_groups: + if ('WD_exclude' not in group) or not group['WD_exclude']: + group['weight_decay'] = new_wd + return new_wd diff --git a/build/lib/utils/tensors.py b/build/lib/utils/tensors.py new file mode 100644 index 00000000..6ae28509 --- /dev/null +++ b/build/lib/utils/tensors.py @@ -0,0 +1,71 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# + +import math + +import torch + +from logging import getLogger + +logger = getLogger() + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def apply_masks(x, masks): + """ + :param x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)] + :param masks: list of tensors containing indices of patches [0,N) to keep + """ + all_x = [] + for m in masks: + mask_keep = m.unsqueeze(-1).repeat(1, 1, x.size(-1)) + all_x += [torch.gather(x, dim=1, index=mask_keep)] + return torch.cat(all_x, dim=0) + + +def repeat_interleave_batch(x, B, repeat): + N = len(x) // B + x = torch.cat([ + torch.cat([x[i*B:(i+1)*B] for _ in range(repeat)], dim=0) + for i in range(N) + ], dim=0) + return x diff --git a/configs/evals/vith16_k700_16x8x3.yaml b/configs/evals/vith16_k700_16x8x3.yaml new file mode 100644 index 00000000..0f9bda1c --- /dev/null +++ b/configs/evals/vith16_k700_16x8x3.yaml @@ -0,0 +1,39 @@ +nodes: 8 +tasks_per_node: 8 +tag: k700-16x8x3 +eval_name: video_classification_frozen +resume_checkpoint: false +data: + dataset_train: /beacon/data01/chengjie.zheng001/Projects/MGH/umb-jepa/data_csv/k700_train.csv + dataset_val: /beacon/data01/chengjie.zheng001/Projects/MGH/umb-jepa/data_csv/k700_val.csv + dataset_type: VideoDataset + num_classes: 700 + frames_per_clip: 16 + num_segments: 8 + num_views_per_segment: 3 + frame_step: 4 +optimization: + attend_across_segments: true + num_epochs: 20 + resolution: 224 + batch_size: 4 + weight_decay: 0.01 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. 
+ use_bfloat16: true +pretrain: + model_name: vit_huge + checkpoint_key: target_encoder + clip_duration: null + frames_per_clip: 16 + tubelet_size: 2 + uniform_power: true + use_silu: false + tight_silu: false + use_sdpa: true + patch_size: 16 + folder: /beacon/data01/chengjie.zheng001/Projects/MGH/umb-jepa/logs2/ + checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder + write_tag: jepa diff --git a/configs/evals/vitl16_k700_16x8x3.yaml b/configs/evals/vitl16_k700_16x8x3.yaml new file mode 100644 index 00000000..f6198cce --- /dev/null +++ b/configs/evals/vitl16_k700_16x8x3.yaml @@ -0,0 +1,39 @@ +nodes: 8 +tasks_per_node: 8 +tag: k700-16x8x3 +eval_name: video_classification_frozen +resume_checkpoint: false +data: + dataset_train: /beacon/data01/chengjie.zheng001/Projects/MGH/umb-jepa/data_csv/k700_train_2.csv + dataset_val: /beacon/data01/chengjie.zheng001/Projects/MGH/umb-jepa/data_csv/k700_val_2.csv + dataset_type: VideoDataset + num_classes: 700 + frames_per_clip: 16 + num_segments: 8 + num_views_per_segment: 3 + frame_step: 4 +optimization: + attend_across_segments: true + num_epochs: 20 + resolution: 224 + batch_size: 4 + weight_decay: 0.01 + lr: 0.001 + start_lr: 0.001 + final_lr: 0.0 + warmup: 0. + use_bfloat16: true +pretrain: + model_name: vit_large + checkpoint_key: target_encoder + clip_duration: null + frames_per_clip: 16 + tubelet_size: 2 + uniform_power: true + use_silu: false + tight_silu: false + use_sdpa: true + patch_size: 16 + folder: /beacon/data01/chengjie.zheng001/Projects/MGH/umb-jepa/log_k700/exp2_e300/ + checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder + write_tag: jepa diff --git a/configs/pretrain/vitl16.yaml b/configs/pretrain/vitl16.yaml index 4996b9de..066705a4 100644 --- a/configs/pretrain/vitl16.yaml +++ b/configs/pretrain/vitl16.yaml @@ -4,9 +4,7 @@ tasks_per_node: 8 data: dataset_type: VideoDataset datasets: - - /your_path_to_kinetics710_csv_file_index.csv - - /your_path_to_ssv2_csv_file_index.csv - - /your_path_to_howto100m_csv_file_index.csv + - /beacon/data01/chengjie.zheng001/Projects/MGH/umb-jepa/data_csv/k700_train_2.csv decode_one_clip: true batch_size: 24 num_clips: 1 @@ -30,7 +28,7 @@ data_aug: - 1.0 reprob: 0.0 logging: - folder: /your_absolute_file_path_for_saving_logs_and_checkpoints/ + folder: /beacon/data01/chengjie.zheng001/Projects/MGH/umb-jepa/log_k700/exp2_e300/ write_tag: jepa loss: loss_exp: 1.0 @@ -75,7 +73,7 @@ model: use_mask_tokens: true zero_init_mask_tokens: true optimization: - ipe: 300 + ipe: 1 ipe_scale: 1.25 clip_grad: 10.0 weight_decay: 0.04 diff --git a/configs/pretrain/vitl16_mgh.yaml b/configs/pretrain/vitl16_mgh.yaml new file mode 100644 index 00000000..595660c6 --- /dev/null +++ b/configs/pretrain/vitl16_mgh.yaml @@ -0,0 +1,89 @@ +app: vjepa +nodes: 16 +tasks_per_node: 8 +data: + dataset_type: VideoDataset + datasets: + # - /beacon/data01/chengjie.zheng001/Projects/MGH/umb-jepa/data_csv/k700_train_2.csv + - '/beacon/data01/chengjie.zheng001/Projects/MGH/umb-jepa/data_csv/MGH_train.csv' + decode_one_clip: true + batch_size: 3 + num_clips: 1 + num_frames: 32 + tubelet_size: 2 + sampling_rate: 2 + crop_size: 224 + patch_size: 16 + pin_mem: true + num_workers: 12 + filter_short_videos: false + clip_duration: null +data_aug: + auto_augment: true + motion_shift: false + random_resize_aspect_ratio: + - 0.75 + - 1.35 + random_resize_scale: + - 0.3 + - 1.0 + reprob: 0.0 +logging: + folder: /beacon/data01/chengjie.zheng001/Projects/MGH/umb-jepa/log_mgh/exp1/ + write_tag: jepa 
+loss:
+  loss_exp: 1.0
+  reg_coeff: 0.0
+mask:
+  - aspect_ratio:
+      - 0.75
+      - 1.5
+    num_blocks: 8
+    spatial_scale:
+      - 0.15
+      - 0.15
+    temporal_scale:
+      - 1.0
+      - 1.0
+    max_temporal_keep: 1.0
+    max_keep: null
+  - aspect_ratio:
+      - 0.75
+      - 1.5
+    num_blocks: 2
+    spatial_scale:
+      - 0.7
+      - 0.7
+    temporal_scale:
+      - 1.0
+      - 1.0
+    max_temporal_keep: 1.0
+    max_keep: null
+meta:
+  load_checkpoint: false
+  read_checkpoint: null
+  seed: 234
+  eval_freq: 100
+  use_sdpa: true
+  dtype: bfloat16
+model:
+  model_name: vit_large
+  pred_depth: 12
+  pred_embed_dim: 384
+  uniform_power: true
+  use_mask_tokens: true
+  zero_init_mask_tokens: true
+optimization:
+  ipe: 1
+  ipe_scale: 1.25
+  clip_grad: 10.0
+  weight_decay: 0.04
+  final_weight_decay: 0.4
+  epochs: 300
+  warmup: 40
+  start_lr: 0.0002
+  lr: 0.000625
+  final_lr: 1.0e-06
+  ema:
+    - 0.998
+    - 1.0
diff --git a/dist/jepa-0.0.1-py3.9.egg b/dist/jepa-0.0.1-py3.9.egg
new file mode 100644
index 00000000..f0c3d8b0
Binary files /dev/null and b/dist/jepa-0.0.1-py3.9.egg differ
diff --git a/evals/video_classification_frozen/eval.py b/evals/video_classification_frozen/eval.py
index f81f526d..26d64889 100644
--- a/evals/video_classification_frozen/eval.py
+++ b/evals/video_classification_frozen/eval.py
@@ -328,7 +328,12 @@ def run_one_epoch(
                 for di in data[0]  # iterate over temporal index of clip
             ]
             clip_indices = [d.to(device, non_blocking=True) for d in data[2]]
+            ### AIR Test1
+            # print(data[1])
+            # tensor_data_1 = torch.tensor(data[1], dtype=torch.float32)
+            # print(type(tensor_data_1))
             labels = data[1].to(device)
+            # labels = tensor_data_1.to(device)
             batch_size = len(labels)

             # Forward and prediction
diff --git a/src/datasets/video_dataset.py b/src/datasets/video_dataset.py
index b05cc701..3bb2d181 100644
--- a/src/datasets/video_dataset.py
+++ b/src/datasets/video_dataset.py
@@ -128,7 +128,9 @@ def __init__(
         for data_path in self.data_paths:

             if data_path[-4:] == '.csv':
-                data = pd.read_csv(data_path, header=None, delimiter=" ")
+                ### Changed: skip malformed rows instead of raising
+                data = pd.read_csv(data_path, header=None, delimiter=" ", on_bad_lines='skip')
+                # data = pd.read_csv(data_path, header=None, delimiter=" ")
                 samples += list(data.values[:, 0])
                 labels += list(data.values[:, 1])
                 num_samples = len(data)
diff --git a/src/jepa.egg-info/PKG-INFO b/src/jepa.egg-info/PKG-INFO
new file mode 100644
index 00000000..0f8fd1dc
--- /dev/null
+++ b/src/jepa.egg-info/PKG-INFO
@@ -0,0 +1,19 @@
+Metadata-Version: 2.1
+Name: jepa
+Version: 0.0.1
+Summary: JEPA research code.
+Requires-Python: >=3.9 +License-File: LICENSE +Requires-Dist: torch>=2 +Requires-Dist: torchvision +Requires-Dist: pyyaml +Requires-Dist: numpy +Requires-Dist: opencv-python +Requires-Dist: submitit +Requires-Dist: braceexpand +Requires-Dist: webdataset +Requires-Dist: timm +Requires-Dist: decord +Requires-Dist: pandas +Requires-Dist: einops +Requires-Dist: beartype diff --git a/src/jepa.egg-info/SOURCES.txt b/src/jepa.egg-info/SOURCES.txt new file mode 100644 index 00000000..9c7e7cd5 --- /dev/null +++ b/src/jepa.egg-info/SOURCES.txt @@ -0,0 +1,33 @@ +LICENSE +README.md +setup.py +src/datasets/data_manager.py +src/datasets/image_dataset.py +src/datasets/video_dataset.py +src/datasets/utils/weighted_sampler.py +src/datasets/utils/video/functional.py +src/datasets/utils/video/randaugment.py +src/datasets/utils/video/randerase.py +src/datasets/utils/video/transforms.py +src/datasets/utils/video/volume_transforms.py +src/jepa.egg-info/PKG-INFO +src/jepa.egg-info/SOURCES.txt +src/jepa.egg-info/dependency_links.txt +src/jepa.egg-info/requires.txt +src/jepa.egg-info/top_level.txt +src/masks/default.py +src/masks/multiblock3d.py +src/masks/random_tube.py +src/masks/utils.py +src/models/attentive_pooler.py +src/models/predictor.py +src/models/vision_transformer.py +src/models/utils/modules.py +src/models/utils/multimask.py +src/models/utils/patch_embed.py +src/models/utils/pos_embs.py +src/utils/distributed.py +src/utils/logging.py +src/utils/monitoring.py +src/utils/schedulers.py +src/utils/tensors.py \ No newline at end of file diff --git a/src/jepa.egg-info/dependency_links.txt b/src/jepa.egg-info/dependency_links.txt new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/jepa.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/src/jepa.egg-info/requires.txt b/src/jepa.egg-info/requires.txt new file mode 100644 index 00000000..d2970710 --- /dev/null +++ b/src/jepa.egg-info/requires.txt @@ -0,0 +1,13 @@ +torch>=2 +torchvision +pyyaml +numpy +opencv-python +submitit +braceexpand +webdataset +timm +decord +pandas +einops +beartype diff --git a/src/jepa.egg-info/top_level.txt b/src/jepa.egg-info/top_level.txt new file mode 100644 index 00000000..b421b2a2 --- /dev/null +++ b/src/jepa.egg-info/top_level.txt @@ -0,0 +1,4 @@ +datasets +masks +models +utils
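A practical note on the video_dataset.py change above: VideoDataset reads each CSV index with header=None and delimiter=" ", taking column 0 as the video path and column 1 as an integer class label, and on_bad_lines='skip' now drops rows that fail to parse instead of raising. The paths and labels below are made up, but they show the expected shape of index files like k700_train.csv:

    import pandas as pd

    # Hypothetical "<video path> <integer label>" rows
    with open('/tmp/toy_index.csv', 'w') as f:
        f.write('/datasets/k700/clips/abseiling_0001.mp4 0\n')
        f.write('/datasets/k700/clips/zumba_0042.mp4 699\n')

    data = pd.read_csv('/tmp/toy_index.csv', header=None, delimiter=" ", on_bad_lines='skip')
    samples = list(data.values[:, 0])   # video paths
    labels = list(data.values[:, 1])    # integer labels

Because the delimiter is a single space, a path that itself contains a space now gets skipped silently rather than reported, so it is worth comparing len(samples) against the raw line count after switching to on_bad_lines='skip'.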