Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion src/deepforest/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,59 @@ def inverse_transform(
return input[..., : size[0], : size[1]]


class PadIfNeeded(GeometricAugmentationBase2D):
r"""Pad the image to a fixed size if it is smaller than the given size.

Args:
size: Tuple of (height, width) specifying the target size.
pad_mode: Padding mode (constant, reflect, replicate, circular).
pad_value: Fill value for constant padding mode.
p: Probability of applying the transform.
same_on_batch: Apply same transformation to all batch elements.
keepdim: Maintain shape (not used for this transform).
"""

def __init__(
self,
size: tuple[int, int],
pad_mode: str = "constant",
pad_value: float = 0,
p: float = 0.5,
same_on_batch: bool = False,
keepdim: bool = False,
) -> None:
super().__init__(p=p, same_on_batch=same_on_batch, p_batch=1.0, keepdim=keepdim)
self.flags = {"pad_mode": pad_mode, "pad_value": pad_value, "size": size}

def compute_transformation(
self, input: Tensor, params: dict[str, Tensor], flags: dict[str, Any]
) -> Tensor:
return self.identity_matrix(input)

def apply_transform(
self,
input: Tensor,
params: dict[str, Tensor],
flags: dict[str, Any],
transform: Tensor | None = None,
) -> Tensor:
target_h, target_w = flags["size"]
_, _, h, w = input.shape

pad_height = max(0, target_h - h)
pad_width = max(0, target_w - w)

if pad_height == 0 and pad_width == 0:
return input

return torch.nn.functional.pad(
input,
[0, pad_width, 0, pad_height],
mode=flags["pad_mode"],
value=flags["pad_value"],
)


class ZoomBlur(IntensityAugmentationBase2D):
"""Apply zoom blur effect by averaging multiple zoomed versions of the
image.
Expand Down Expand Up @@ -165,7 +218,11 @@ def apply_transform(self, input, params, flags, transform=None):
RandomPadTo,
{"pad_range": (0, 10), "pad_mode": "constant", "pad_value": 0, "p": 0.5},
),
"PadIfNeeded": (K.PadTo, {"size": (800, 800)}),
"PadIfNeeded": (
PadIfNeeded,
{"size": (800, 800), "pad_mode": "constant", "pad_value": 0},
),
"PadTo": (K.PadTo, {"size": (800, 800)}),
"Rotate": (K.RandomRotation, {"degrees": 15, "p": 0.5}),
"RandomBrightnessContrast": (
K.ColorJiggle,
Expand Down
25 changes: 25 additions & 0 deletions tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,35 @@ def test_random_pad_to():

output_image, output_bboxes = transform(image, bboxes)

assert output_image.shape == torch.Size([1, 3, 105, 105])
assert output_image.shape == torch.Size([1, 3, 105, 105])
assert torch.equal(output_bboxes, bboxes)


def test_pad_if_needed():
"""Test PadIfNeeded augmentation.

Should pad if smaller, do nothing if larger.
"""
# Case 1: Image smaller than target -> Should Pad
transform_pad = get_transform(augmentations={"PadIfNeeded": {"size": (800, 800)}})
image_small = torch.randn(1, 3, 600, 600)
bboxes = torch.tensor([[[10., 10., 50., 50.]]])

out_small, _ = transform_pad(image_small, bboxes)
assert out_small.shape == (1, 3, 800, 800), "Failed to pad small image"

# Case 2: Image larger than target -> Should NOT Crop (The Fix)
image_large = torch.randn(1, 3, 1000, 1000)
out_large, _ = transform_pad(image_large, bboxes)
assert out_large.shape == (1, 3, 1000, 1000), "Incorrectly cropped large image"

# Case 3: Verify PadTo alias still exists and crops (Backward Comp)
transform_crop = get_transform(augmentations={"PadTo": {"size": (800, 800)}})
out_crop, _ = transform_crop(image_large, bboxes)
assert out_crop.shape == (1, 3, 800, 800), "PadTo alias failed to crop"


def test_filter_boxes():
"""Test box filtering after augmentation."""
boxes = torch.tensor([
Expand Down