@register_transform("video_simple_resize")
class VideoSimpleResize(ClassyTransform):
    """Rescale a video clip to a fixed spatial size.

    Both the height and the width of the clip are set to the requested
    values, so the aspect ratio of the input is not necessarily preserved.
    """

    # Interpolation modes for which torch.nn.functional.interpolate accepts an
    # explicit ``align_corners`` value; passing one with any other mode
    # (e.g. "nearest", "area") raises a ValueError.
    _ALIGN_CORNERS_MODES = ("linear", "bilinear", "bicubic", "trilinear")

    def __init__(self, rescale_size: List[int], interpolation_mode: str = "bilinear"):
        """The constructor method of VideoSimpleResize class.

        Args:
            rescale_size: (height, width) of the rescaled clip
            interpolation_mode: Default: "bilinear". See valid values in
                (https://pytorch.org/docs/stable/nn.functional.html#torch.nn.
                functional.interpolate)

        """
        self.interpolation_mode = interpolation_mode
        assert (
            len(rescale_size) == 2
        ), "rescale_size should be a list of size 2 (height, width)"
        self.rescale_size = rescale_size

    def __call__(self, clip):
        """Callable function which applies the transform to the input clip.

        Args:
            clip (torch.Tensor): input clip tensor

        Returns:
            The clip resized to ``rescale_size`` in its last two dimensions.

        """
        # clip size: C x T x H x W
        # interpolate() reads a 4-D input as (batch, channels, H, W), so C and T
        # occupy the first two slots and only the trailing H and W are resized.
        #
        # Spell out align_corners=False (the value interpolate() falls back to
        # when it is omitted) so the sampling convention is explicit and older
        # PyTorch releases do not warn about the implicit default.
        align_corners = (
            False if self.interpolation_mode in self._ALIGN_CORNERS_MODES else None
        )
        return torch.nn.functional.interpolate(
            clip,
            size=self.rescale_size,
            mode=self.interpolation_mode,
            align_corners=align_corners,
        )


@register_transform("video_resize_augment")
class VideoResizeAugmentTransform(ClassyTransform):
    """This is the resize video transform with data augmentation which is useful for
    training.

    It sequentially prepares a torch.Tensor of video data,
    resizes the video clip to specified size, randomly flips the
    video clip horizontally, and normalizes the pixel values by mean subtraction
    and standard deviation division.

    """

    def __init__(
        self,
        rescale_size: List[int] = VideoConstants.RESCALE_SIZE,
        mean: List[float] = VideoConstants.MEAN,
        std: List[float] = VideoConstants.STD,
    ):
        """The constructor method of VideoResizeAugmentTransform class.

        Args:
            rescale_size: (height, width) of the rescaled video clip
            mean: a 3-tuple denoting the pixel RGB mean
            std: a 3-tuple denoting the pixel RGB standard deviation

        """

        # Order matters: the clip is converted to a tensor first, then resized
        # to exactly rescale_size, flipped at random, and finally normalized.
        self._transform = transforms.Compose(
            [
                transforms_video.ToTensorVideo(),
                VideoSimpleResize(rescale_size),
                transforms_video.RandomHorizontalFlipVideo(),
                transforms_video.NormalizeVideo(mean=mean, std=std),
            ]
        )

    def __call__(self, video):
        """Apply the resize transform with data augmentation to video.

        Args:
            video: input video that will undergo the transform

        Returns:
            The output of the composed transform pipeline applied to ``video``.

        """
        return self._transform(video)