Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/deepforest/datasets/cropmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from torch.utils.data import Dataset
from torchvision import transforms

from deepforest.utilities import apply_nodata_mask


def bounding_box_transform(augmentations=None, resize=None):
"""Create transform pipeline for bounding box data.
Expand Down Expand Up @@ -121,7 +123,10 @@ def __getitem__(self, idx):
row_off = int(ymin)
width = int(max(1, xmax - xmin))
height = int(max(1, ymax - ymin))
box = self.src.read(window=Window(col_off, row_off, width, height))
# Clip window to image bounds to avoid out-of-bounds errors
window = Window(col_off, row_off, width, height)
window = window.intersection(Window(0, 0, self._image_width, self._image_height))
box = apply_nodata_mask(self.src, window)
box = np.rollaxis(box, 0, 3)

if self.transform:
Expand Down
6 changes: 4 additions & 2 deletions src/deepforest/datasets/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch.utils.data import Dataset

from deepforest import preprocess
from deepforest.utilities import format_geometry, read_file
from deepforest.utilities import apply_nodata_mask, format_geometry, read_file


# Base prediction class
Expand Down Expand Up @@ -470,7 +470,9 @@ def window_list(self):
def get_crop(self, idx):
window = self.windows[idx]
with rio.open(self.path) as src:
window_data = src.read(window=Window(window.x, window.y, window.w, window.h))
window_data = apply_nodata_mask(
src, Window(window.x, window.y, window.w, window.h)
)

# Rasterio already returns (C, H, W), just normalize and convert
window_data = window_data.astype("float32") / 255.0
Expand Down
4 changes: 3 additions & 1 deletion src/deepforest/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,9 @@ def write_crops(self, root_dir, images, boxes, labels, savedir):
xmin, ymin, xmax, ymax = square_box

# Crop the image using the square box coordinates
img = src.read(window=((int(ymin), int(ymax)), (int(xmin), int(xmax))))
img = utilities.apply_nodata_mask(
src, ((int(ymin), int(ymax)), (int(xmin), int(xmax)))
)
# Save the cropped image as a PNG file using opencv
image_basename = os.path.splitext(os.path.basename(image))[0]
img_path = os.path.join(savedir, label, f"{image_basename}_{index}.png")
Expand Down
61 changes: 57 additions & 4 deletions src/deepforest/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import xmltodict
from omegaconf import DictConfig, OmegaConf
from PIL import Image
from rasterio.windows import Window
from tqdm import tqdm

from deepforest import _ROOT
Expand Down Expand Up @@ -82,6 +83,59 @@ def load_config(
return config


def apply_nodata_mask(src, window):
"""Read raster window and apply no-data value masking.

This function reads a window from a rasterio dataset and masks no-data
values to 0 for more consistent predictions. If no nodata value is set,
the data are returned unmodified.

Args:
src: rasterio.DatasetReader opened in 'r' mode
window: rasterio.windows.Window or tuple defining the window to read

Returns:
numpy.ndarray: Raster data with shape (bands, height, width). No-data values
are set to 0.
"""
# Clip window to image bounds before reading to ensure consistent dimensions
if isinstance(window, Window):
full_window = Window(0, 0, src.width, src.height)
try:
window = window.intersection(full_window)
except rasterio.errors.WindowError as exc:
# Window is completely outside image bounds
raise ValueError(
f"Window {window} is completely outside image bounds "
f"(width={src.width}, height={src.height})"
) from exc

data = src.read(window=window)

# Use rasterio's dataset_mask to get the mask (True = valid, False = nodata)
if src.nodata is not None:
mask = src.dataset_mask(window=window)
expected_height, expected_width = data.shape[1], data.shape[2]
if mask.shape[0] > expected_height:
mask = mask[:expected_height, :]
if mask.shape[1] > expected_width:
mask = mask[:, :expected_width]
if mask.shape != (expected_height, expected_width):
return data
nodata_mask = ~mask
assert nodata_mask.shape == (expected_height, expected_width), (
f"nodata_mask shape {nodata_mask.shape} != expected {(expected_height, expected_width)}"
)
# Set nodata pixels to 0 for all bands
# Apply 2D mask to each band: data shape is (bands, height, width)
# Use np.where to safely apply mask without indexing issues
for band_idx in range(data.shape[0]):
# np.where(condition, x, y): where condition is True, use x (0), else use y (data)
data[band_idx] = np.where(nodata_mask, 0, data[band_idx])

return data


class DownloadProgressBar(tqdm):
"""Download progress bar class."""

Expand Down Expand Up @@ -654,11 +708,10 @@ def crop_raster(bounds, rgb_path=None, savedir=None, filename=None, driver="GTif
driver = "PNG"
else:
# Read projected data using rasterio and crop
img = src.read(
window=rasterio.windows.from_bounds(
left, bottom, right, top, transform=src.transform
)
window = rasterio.windows.from_bounds(
left, bottom, right, top, transform=src.transform
)
img = apply_nodata_mask(src, window)
cropped_transform = rasterio.windows.transform(
rasterio.windows.from_bounds(
left, bottom, right, top, transform=src.transform
Expand Down
32 changes: 32 additions & 0 deletions tests/test_model_prediction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import numpy as np
import pandas as pd
import pytest

from deepforest import main


@pytest.mark.parametrize(
"model_name",
[
"weecology/deepforest-bird",
# "weecology/deepforest-everglades-bird-species-detector",
# "weecology/deepforest-tree",
# "weecology/deepforest-livestock",
# "weecology/cropmodel-deadtrees",
],
)
def test_white_image_predict_tile_no_predictions_bird_model(model_name):
"""All-white image should yield no detections with various models."""
m = main.deepforest()
m.create_trainer()
m.load_model(model_name)
# Create a white image (uint8 RGB)
white = np.full((2048, 2048, 3), 255, dtype=np.uint8)
res = m.predict_tile(
image=white,
patch_size=128,
patch_overlap=0.0,
iou_threshold=m.config.nms_thresh,
)
assert len(res) == 0
#assert (res is None) or (isinstance(res, pd.DataFrame) and res.empty)
Loading