class PadIfNeeded(GeometricAugmentationBase2D):
    r"""Pad the image up to a minimum size, leaving larger images untouched.

    Padding is added only on the right and bottom edges, so pixel coordinates
    of existing content (and therefore bounding boxes) do not move. Images
    that already meet or exceed the target size in a dimension are not
    cropped in that dimension.

    Args:
        size: Tuple of (height, width) specifying the minimum output size.
        pad_mode: Padding mode (constant, reflect, replicate, circular).
        pad_value: Fill value for constant padding mode.
        p: Probability of applying the transform. Defaults to 1.0 so that
            padding is deterministic whenever the image is too small.
            NOTE(review): with p < 1 only part of a batch may be padded,
            which would yield mixed shapes within one batch - confirm
            before lowering p.
        same_on_batch: Apply same transformation to all batch elements.
        keepdim: Maintain shape (not used for this transform).
    """

    def __init__(
        self,
        size: tuple[int, int],
        pad_mode: str = "constant",
        pad_value: float = 0,
        p: float = 1.0,
        same_on_batch: bool = False,
        keepdim: bool = False,
    ) -> None:
        super().__init__(p=p, same_on_batch=same_on_batch, p_batch=1.0, keepdim=keepdim)
        self.flags = {"pad_mode": pad_mode, "pad_value": pad_value, "size": size}

    def compute_transformation(
        self, input: Tensor, params: dict[str, Tensor], flags: dict[str, Any]
    ) -> Tensor:
        """Return the identity matrix.

        Right/bottom padding does not shift existing pixel coordinates.
        """
        return self.identity_matrix(input)

    def apply_transform(
        self,
        input: Tensor,
        params: dict[str, Tensor],
        flags: dict[str, Any],
        transform: Tensor | None = None,
    ) -> Tensor:
        """Pad ``input`` on the right/bottom up to ``flags["size"]``.

        Returns ``input`` unchanged when it is already at least the target
        size in both spatial dimensions.
        """
        target_h, target_w = flags["size"]
        # Spatial dimensions are the last two axes of the image tensor.
        h, w = input.shape[-2:]

        # Negative differences mean the image is already large enough.
        pad_height = max(0, target_h - h)
        pad_width = max(0, target_w - w)

        if pad_height == 0 and pad_width == 0:
            return input

        # F.pad order for the last two dims: (left, right, top, bottom).
        return torch.nn.functional.pad(
            input,
            [0, pad_width, 0, pad_height],
            mode=flags["pad_mode"],
            value=flags["pad_value"],
        )
def test_pad_if_needed():
    """Test PadIfNeeded augmentation.

    PadIfNeeded should pad images smaller than the target size and leave
    larger images untouched, while PadTo keeps its pad-or-crop behaviour.
    """
    # Pin p=1.0 so padding is always applied and the shape assertions cannot
    # depend on the transform's default probability.
    transform_pad = get_transform(
        augmentations={"PadIfNeeded": {"size": (800, 800), "p": 1.0}}
    )
    bboxes = torch.tensor([[[10.0, 10.0, 50.0, 50.0]]])

    # Case 1: image smaller than target -> padded up to the target size.
    image_small = torch.randn(1, 3, 600, 600)
    out_small, boxes_small = transform_pad(image_small, bboxes)
    assert out_small.shape == (1, 3, 800, 800), "Failed to pad small image"
    # Padding is added on the right/bottom only, so boxes must not move.
    assert torch.equal(boxes_small, bboxes)

    # Case 2: image larger than target -> must NOT be cropped.
    image_large = torch.randn(1, 3, 1000, 1000)
    out_large, boxes_large = transform_pad(image_large, bboxes)
    assert out_large.shape == (1, 3, 1000, 1000), "Incorrectly cropped large image"
    assert torch.equal(boxes_large, bboxes)

    # Case 3: the PadTo entry is still registered and crops larger images
    # (backward compatibility).
    transform_crop = get_transform(augmentations={"PadTo": {"size": (800, 800)}})
    out_crop, _ = transform_crop(image_large, bboxes)
    assert out_crop.shape == (1, 3, 800, 800), "PadTo failed to crop"