diff --git a/src/deepforest/datasets/cropmodel.py b/src/deepforest/datasets/cropmodel.py index 7ae0e0399..dd851d9e4 100644 --- a/src/deepforest/datasets/cropmodel.py +++ b/src/deepforest/datasets/cropmodel.py @@ -12,6 +12,8 @@ from torch.utils.data import Dataset from torchvision import transforms +from deepforest.utilities import apply_nodata_mask + def bounding_box_transform(augmentations=None, resize=None): """Create transform pipeline for bounding box data. @@ -121,7 +123,10 @@ def __getitem__(self, idx): row_off = int(ymin) width = int(max(1, xmax - xmin)) height = int(max(1, ymax - ymin)) - box = self.src.read(window=Window(col_off, row_off, width, height)) + # Clip window to image bounds to avoid out-of-bounds errors + window = Window(col_off, row_off, width, height) + window = window.intersection(Window(0, 0, self._image_width, self._image_height)) + box = apply_nodata_mask(self.src, window) box = np.rollaxis(box, 0, 3) if self.transform: diff --git a/src/deepforest/datasets/prediction.py b/src/deepforest/datasets/prediction.py index bb5b2d3b5..ee4d2cb16 100644 --- a/src/deepforest/datasets/prediction.py +++ b/src/deepforest/datasets/prediction.py @@ -11,7 +11,7 @@ from torch.utils.data import Dataset from deepforest import preprocess -from deepforest.utilities import format_geometry, read_file +from deepforest.utilities import apply_nodata_mask, format_geometry, read_file # Base prediction class @@ -470,7 +470,9 @@ def window_list(self): def get_crop(self, idx): window = self.windows[idx] with rio.open(self.path) as src: - window_data = src.read(window=Window(window.x, window.y, window.w, window.h)) + window_data = apply_nodata_mask( + src, Window(window.x, window.y, window.w, window.h) + ) # Rasterio already returns (C, H, W), just normalize and convert window_data = window_data.astype("float32") / 255.0 diff --git a/src/deepforest/model.py b/src/deepforest/model.py index f5addad21..01c77f3eb 100644 --- a/src/deepforest/model.py +++ b/src/deepforest/model.py @@ -296,7 +296,9 @@ def write_crops(self, root_dir, images, boxes, labels, savedir): xmin, ymin, xmax, ymax = square_box # Crop the image using the square box coordinates - img = src.read(window=((int(ymin), int(ymax)), (int(xmin), int(xmax)))) + img = utilities.apply_nodata_mask( + src, ((int(ymin), int(ymax)), (int(xmin), int(xmax))) + ) # Save the cropped image as a PNG file using opencv image_basename = os.path.splitext(os.path.basename(image))[0] img_path = os.path.join(savedir, label, f"{image_basename}_{index}.png") diff --git a/src/deepforest/utilities.py b/src/deepforest/utilities.py index 397515dd7..6d55c19f0 100644 --- a/src/deepforest/utilities.py +++ b/src/deepforest/utilities.py @@ -10,6 +10,7 @@ import xmltodict from omegaconf import DictConfig, OmegaConf from PIL import Image +from rasterio.windows import Window from tqdm import tqdm from deepforest import _ROOT @@ -82,6 +83,59 @@ def load_config( return config +def apply_nodata_mask(src, window): + """Read raster window and apply no-data value masking. + + This function reads a window from a rasterio dataset and masks no-data + values to 0 for more consistent predictions. If no nodata value is set, + the data are returned unmodified. + + Args: + src: rasterio.DatasetReader opened in 'r' mode + window: rasterio.windows.Window or tuple defining the window to read + + Returns: + numpy.ndarray: Raster data with shape (bands, height, width). No-data values + are set to 0. + """ + # Clip window to image bounds before reading to ensure consistent dimensions + if isinstance(window, Window): + full_window = Window(0, 0, src.width, src.height) + try: + window = window.intersection(full_window) + except rasterio.errors.WindowError as exc: + # Window is completely outside image bounds + raise ValueError( + f"Window {window} is completely outside image bounds " + f"(width={src.width}, height={src.height})" + ) from exc + + data = src.read(window=window) + + # Use rasterio's dataset_mask to get the mask (True = valid, False = nodata) + if src.nodata is not None: + mask = src.dataset_mask(window=window) + expected_height, expected_width = data.shape[1], data.shape[2] + if mask.shape[0] > expected_height: + mask = mask[:expected_height, :] + if mask.shape[1] > expected_width: + mask = mask[:, :expected_width] + if mask.shape != (expected_height, expected_width): + return data + nodata_mask = ~mask + assert nodata_mask.shape == (expected_height, expected_width), ( + f"nodata_mask shape {nodata_mask.shape} != expected {(expected_height, expected_width)}" + ) + # Set nodata pixels to 0 for all bands + # Apply 2D mask to each band: data shape is (bands, height, width) + # Use np.where to safely apply mask without indexing issues + for band_idx in range(data.shape[0]): + # np.where(condition, x, y): where condition is True, use x (0), else use y (data) + data[band_idx] = np.where(nodata_mask, 0, data[band_idx]) + + return data + + class DownloadProgressBar(tqdm): """Download progress bar class.""" @@ -654,11 +708,10 @@ def crop_raster(bounds, rgb_path=None, savedir=None, filename=None, driver="GTif driver = "PNG" else: # Read projected data using rasterio and crop - img = src.read( - window=rasterio.windows.from_bounds( - left, bottom, right, top, transform=src.transform - ) + window = rasterio.windows.from_bounds( + left, bottom, right, top, transform=src.transform ) + img = apply_nodata_mask(src, window) cropped_transform = rasterio.windows.transform( rasterio.windows.from_bounds( left, bottom, right, top, transform=src.transform diff --git a/tests/test_model_prediction.py b/tests/test_model_prediction.py new file mode 100644 index 000000000..cc3b8eda2 --- /dev/null +++ b/tests/test_model_prediction.py @@ -0,0 +1,32 @@ +import numpy as np +import pandas as pd +import pytest + +from deepforest import main + + +@pytest.mark.parametrize( + "model_name", + [ + "weecology/deepforest-bird", + # "weecology/deepforest-everglades-bird-species-detector", + # "weecology/deepforest-tree", + # "weecology/deepforest-livestock", + # "weecology/cropmodel-deadtrees", + ], +) +def test_white_image_predict_tile_no_predictions_bird_model(model_name): + """All-white image should yield no detections with various models.""" + m = main.deepforest() + m.create_trainer() + m.load_model(model_name) + # Create a white image (uint8 RGB) + white = np.full((2048, 2048, 3), 255, dtype=np.uint8) + res = m.predict_tile( + image=white, + patch_size=128, + patch_overlap=0.0, + iou_threshold=m.config.nms_thresh, + ) + assert len(res) == 0 + #assert (res is None) or (isinstance(res, pd.DataFrame) and res.empty)