Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions classy_vision/dataset/transforms/util_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class VideoConstants:
MEAN = ImagenetConstants.MEAN #
STD = ImagenetConstants.STD
SIZE_RANGE = (128, 160)
RESCALE_SIZE = (224, 224)
CROP_SIZE = 112


Expand Down Expand Up @@ -141,6 +142,44 @@ def __call__(self, clip):
return clip


@register_transform("video_simple_resize")
class VideoSimpleResize(ClassyTransform):
"""Given an input size, rescale the clip to the given size both in
height and width.
"""

def __init__(self, rescale_size: List[int], interpolation_mode: str = "bilinear"):
"""The constructor method of VideoClipResize class.

Args:
rescale_size: size of the rescaled clip
interpolation_mode: Default: "bilinear". See valid values in
(https://pytorch.org/docs/stable/nn.functional.html#torch.nn.
functional.interpolate)

"""
self.interpolation_mode = interpolation_mode
assert (
len(rescale_size) == 2
), "rescale_size should be a list of size 2 (height, width)"
self.rescale_size = rescale_size

def __call__(self, clip):
"""Callable function which applies the tranform to the input clip.

Args:
clip (torch.Tensor): input clip tensor

"""
# clip size: C x T x H x W
clip = torch.nn.functional.interpolate(
clip,
size=self.rescale_size,
mode=self.interpolation_mode,
)
return clip


@register_transform("video_default_augment")
class VideoDefaultAugmentTransform(ClassyTransform):
"""This is the default video transform with data augmentation which is useful for
Expand Down Expand Up @@ -190,6 +229,53 @@ def __call__(self, video):
return self._transform(video)


@register_transform("video_resize_augment")
class VideoResizeAugmentTransform(ClassyTransform):
"""This is the resize video transform with data augmentation which is useful for
training.

It sequentially prepares a torch.Tensor of video data,
resizes the video clip to specified size, randomly flips the
video clip horizontally, and normalizes the pixel values by mean subtraction
and standard deviation division.

"""

def __init__(
self,
rescale_size: List[int] = VideoConstants.RESCALE_SIZE,
mean: List[float] = VideoConstants.MEAN,
std: List[float] = VideoConstants.STD,
):
"""The constructor method of VideoResizeAugmentTransform class.

Args:
size: the short edge of rescaled video clip
mean: a 3-tuple denoting the pixel RGB mean
std: a 3-tuple denoting the pixel RGB standard deviation

"""

self._transform = transforms.Compose(
[
transforms_video.ToTensorVideo(),
# TODO(zyan3): migrate VideoClipRandomResizeCrop to TorchVision
VideoSimpleResize(rescale_size),
transforms_video.RandomHorizontalFlipVideo(),
transforms_video.NormalizeVideo(mean=mean, std=std),
]
)

def __call__(self, video):
"""Apply the resize transform with data augmentation to video.

Args:
video: input video that will undergo the transform

"""
return self._transform(video)


@register_transform("video_default_no_augment")
class VideoDefaultNoAugmentTransform(ClassyTransform):
"""This is the default video transform without data augmentation which is useful
Expand Down