From cd082bb645c36dc96fc08a23a60d9f20d7901782 Mon Sep 17 00:00:00 2001 From: Arkabandhu Chowdhury Date: Fri, 12 Aug 2022 07:26:20 -0700 Subject: [PATCH] Implement resize and train XRayVideo A/V with only resizing (#796) Summary: Pull Request resolved: https://github.com/facebookresearch/ClassyVision/pull/796 We want to check whether training XRayVideo with simply video resizing (in addition to other existing transformation like horizontal flipping and normalization) without random corp is sufficient. The resize dimension is used as 224*224. workflow: f362077622 (Note: in the workflow `fcc_mvit_dataset_v4p2_arkc.yaml` is used which I renamed to `fcc_mvit_dataset_v4p2_onlyresize.yaml` in this diff.) As can be seen, the validation MAP goes to around .422 as opposed to 0.46 when random resized crop is used (f355567669) and rest of the configuration is kept the same. Hence, it is better to keep random resized crop. Differential Revision: D38522980 fbshipit-source-id: 037690a4dccf9c3ee66b353792cb30b22bae8161 --- .../dataset/transforms/util_video.py | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/classy_vision/dataset/transforms/util_video.py b/classy_vision/dataset/transforms/util_video.py index 4c99b1232..8b4216086 100644 --- a/classy_vision/dataset/transforms/util_video.py +++ b/classy_vision/dataset/transforms/util_video.py @@ -32,6 +32,7 @@ class VideoConstants: MEAN = ImagenetConstants.MEAN # STD = ImagenetConstants.STD SIZE_RANGE = (128, 160) + RESCALE_SIZE = (224, 224) CROP_SIZE = 112 @@ -141,6 +142,44 @@ def __call__(self, clip): return clip +@register_transform("video_simple_resize") +class VideoSimpleResize(ClassyTransform): + """Given an input size, rescale the clip to the given size both in + height and width. + """ + + def __init__(self, rescale_size: List[int], interpolation_mode: str = "bilinear"): + """The constructor method of VideoClipResize class. + + Args: + rescale_size: size of the rescaled clip + interpolation_mode: Default: "bilinear". See valid values in + (https://pytorch.org/docs/stable/nn.functional.html#torch.nn. + functional.interpolate) + + """ + self.interpolation_mode = interpolation_mode + assert ( + len(rescale_size) == 2 + ), "rescale_size should be a list of size 2 (height, width)" + self.rescale_size = rescale_size + + def __call__(self, clip): + """Callable function which applies the tranform to the input clip. + + Args: + clip (torch.Tensor): input clip tensor + + """ + # clip size: C x T x H x W + clip = torch.nn.functional.interpolate( + clip, + size=self.rescale_size, + mode=self.interpolation_mode, + ) + return clip + + @register_transform("video_default_augment") class VideoDefaultAugmentTransform(ClassyTransform): """This is the default video transform with data augmentation which is useful for @@ -190,6 +229,53 @@ def __call__(self, video): return self._transform(video) +@register_transform("video_resize_augment") +class VideoResizeAugmentTransform(ClassyTransform): + """This is the resize video transform with data augmentation which is useful for + training. + + It sequentially prepares a torch.Tensor of video data, + resizes the video clip to specified size, randomly flips the + video clip horizontally, and normalizes the pixel values by mean subtraction + and standard deviation division. + + """ + + def __init__( + self, + rescale_size: List[int] = VideoConstants.RESCALE_SIZE, + mean: List[float] = VideoConstants.MEAN, + std: List[float] = VideoConstants.STD, + ): + """The constructor method of VideoResizeAugmentTransform class. + + Args: + size: the short edge of rescaled video clip + mean: a 3-tuple denoting the pixel RGB mean + std: a 3-tuple denoting the pixel RGB standard deviation + + """ + + self._transform = transforms.Compose( + [ + transforms_video.ToTensorVideo(), + # TODO(zyan3): migrate VideoClipRandomResizeCrop to TorchVision + VideoSimpleResize(rescale_size), + transforms_video.RandomHorizontalFlipVideo(), + transforms_video.NormalizeVideo(mean=mean, std=std), + ] + ) + + def __call__(self, video): + """Apply the resize transform with data augmentation to video. + + Args: + video: input video that will undergo the transform + + """ + return self._transform(video) + + @register_transform("video_default_no_augment") class VideoDefaultNoAugmentTransform(ClassyTransform): """This is the default video transform without data augmentation which is useful