From 467edb9d0a6f96dffb9494b9ed35858f9bd5f4f2 Mon Sep 17 00:00:00 2001 From: bw4sz Date: Tue, 30 Dec 2025 09:46:25 -0800 Subject: [PATCH] Simplify evaluation and prediction to mirror training - Make trainer.validate the preferred evaluation method and standardize train, eval and predict to accept lists not batches. - Add a sublist concept for MultiImage datasets - Ensure predict_file follows the root_dir criteria for read_file --- docs/user_guide/12_evaluation.md | 17 +- src/deepforest/datasets/prediction.py | 276 ++++++++++++-------------- src/deepforest/main.py | 151 +++++++++----- src/deepforest/model.py | 10 +- src/deepforest/predict.py | 20 +- tests/test_datasets_prediction.py | 34 +--- tests/test_evaluate.py | 103 +++++++++- tests/test_main.py | 39 +++- 8 files changed, 403 insertions(+), 247 deletions(-) diff --git a/docs/user_guide/12_evaluation.md b/docs/user_guide/12_evaluation.md index ee3527ae0..1d856a712 100644 --- a/docs/user_guide/12_evaluation.md +++ b/docs/user_guide/12_evaluation.md @@ -11,17 +11,24 @@ DeepForest allows users to assess model performance compared to ground-truth dat 5. mAP - Mean-Average-Precision, a computer vision metric that assesses the performance of the model incoporating precision, recall and average score of true positives. See below. ## Evaluation code - -The model's .evaluate method takes a set of labels in the form of a CSV file that includes paths to images and the coordinates of associated labels as well as thresholds to determine if a prediction is close enough to a label to be considered a match. +We can use the trainer.validate method to generate the validation predictions and get the evaluation metrics. ```python from deepforest import main, get_data +import os +import pandas as pd m = main.deepforest() m.load_model("Weecology/deepforest-tree") -# Sample data csv_file = get_data("OSBS_029.csv") -results = m.evaluate(csv_file, iou_threshold=0.4) +m.config.validation.csv_file = csv_file +m.config.validation.root_dir = os.path.dirname(csv_file) +m.create_trainer() +# Runs validation, logs IoU and overall mAP +m.trainer.validate(m) + +# Access predictions dataframe if desired +predictions = pd.concat(m.predictions) ``` This produces a dictionary that contains a detailed result comparison for each label, the aggregate metrics, the predictions data frame, and the ground truth data frame. @@ -45,6 +52,8 @@ mAP is the standard COCO evaluation metric and the most common for comparing com For information on how to calculate mAP, see the [torchmetrics documentation](https://torchmetrics.readthedocs.io/en/stable/detection/mean_average_precision.html) and further reading below. + + ### Precision and Recall at a set IoU threshold. This was the original DeepForest metric, set to an IoU of 0.4. This means that all predictions that overlap a ground truth box at IoU > 0.4 are true positives. As opposed to the torchmetrics above, it is intuitive and matches downstream ecological tasks. The drawback is that it is slow, coarse, and does not fully reward the model for having high confidence scores on true positives. diff --git a/src/deepforest/datasets/prediction.py b/src/deepforest/datasets/prediction.py index f366be5c2..bb5b2d3b5 100644 --- a/src/deepforest/datasets/prediction.py +++ b/src/deepforest/datasets/prediction.py @@ -1,8 +1,6 @@ import os -import cv2 import numpy as np -import numpy.typing as npt import pandas as pd import rasterio as rio import slidingwindow @@ -10,10 +8,10 @@ from PIL import Image from rasterio.windows import Window from torch.nn import functional as F -from torch.utils.data import Dataset, default_collate +from torch.utils.data import Dataset from deepforest import preprocess -from deepforest.utilities import format_geometry +from deepforest.utilities import format_geometry, read_file # Base prediction class @@ -27,7 +25,6 @@ class PredictionDataset(Dataset): image (PIL.Image.Image): A single image. path (str): Path to a single image. images (List[PIL.Image.Image]): A list of images. - paths (List[str]): A list of image paths. patch_size (int): Size of the patches to extract. patch_overlap (float): Overlap between patches. size (int): Target size to resize images to. Optional; if not provided, no resizing is performed. @@ -38,83 +35,46 @@ def __init__( image=None, path=None, images=None, - paths=None, patch_size=400, patch_overlap=0, - size=None, ): self.image = image self.images = images self.path = path - self.paths = paths self.patch_size = patch_size self.patch_overlap = patch_overlap - self.size = size self.items = self.prepare_items() - def _load_and_preprocess_image( - self, - image_path: str | None = None, - image: Image.Image | npt.NDArray | None = None, - size: int | None = None, - preprocess_image: bool = True, + def load_and_preprocess_image( + self, image_path: str = None, image: np.ndarray | Image.Image = None ): - """Load and preprocess an image. Either an image path or PIL image must - be provided. - - Datasets should load using PIL and transpose the image to - (C, H, W) before main.model.forward() is called. - - Args: - image_path: (str) path to image, optional - image: (PIL image), optional - size: (int) output size - preprocess_image: (bool) Whether to convert the image to a float32 array between 0 and 1. - - Returns: - CHW float32 numpy array, normalized to be in [0, 1] - """ - if image is None and image_path is None: - raise ValueError("Either image or image_path must be provided") - elif image is None: - image = Image.open(image_path) - - if isinstance(image, Image.Image) and image.mode != "RGB": - raise ValueError( - f"Expected 8-bit 3-channel RGB, got {image.mode}, {len(image.getbands())} channels and size: {image.size}." - "Check for transparent alpha channel and remove if present." - ) - elif isinstance(image, np.ndarray) and ( - image.ndim != 3 or image.shape[2] != 3 or image.dtype != np.uint8 - ): - raise ValueError( - f"Expected 8-bit 3-channel RGB numpy array, got {image.ndim} dimensions and shape: {image.shape}." - "Check for transparent alpha channel and remove if present." - ) - - image = np.array(image) - - if preprocess_image: - image = self.preprocess_image(image, size) - - return image - - def preprocess_image(self, image: npt.NDArray, size=None) -> npt.NDArray: - """Preprocess an 8-bit image to a float32 array between 0 and 1.""" - image = image.astype(np.float32) - image /= 255.0 - - if size is not None: - image = self.resize_image(image, size) + if image is None: + if image_path is None: + raise ValueError("Either image_path or image must be provided") + image = np.array(Image.open(image_path).convert("RGB")) + else: + image = np.array(image) + # If dtype is not float32, convert to float32 + if image.dtype != "float32": + image = image.astype("float32") + + # If image is not normalized, normalize to [0, 1] + if image.max() > 1 or image.min() < 0: + image = image / 255.0 + + # If image is not in CHW format, convert to CHW + if image.shape[0] != 3: + if image.shape[-1] != 3: + raise ValueError( + f"Expected 3 channel image, got image shape {image.shape}" + ) + else: + image = np.rollaxis(image, 2, 0) - image = np.transpose(image, (2, 0, 1)) + image = torch.from_numpy(image) return image - def resize_image(self, image: npt.NDArray, size: int) -> npt.NDArray: - """Resize an image to a new (square) size.""" - return cv2.resize(image, dsize=(size, size)) - def prepare_items(self): """Prepare the items for the dataset. @@ -131,15 +91,9 @@ def __getitem__(self, idx): return self.get_crop(idx) def collate_fn(self, batch): - """Collate the batch into a single tensor.""" - # Check if all images in batch have same dimensions - try: - return default_collate(batch) - except RuntimeError: - raise RuntimeError( - "Images in batch have different dimensions. " - "Set validation.size in config.yaml to resize all images to a common size." - ) from None + """Collate the batch into a list.""" + + return batch def get_crop_bounds(self, idx): """Get the crop bounds at the given index, needed to mosaic @@ -171,12 +125,12 @@ def determine_geometry_type(self, batched_result): return geom_type def format_batch(self, batch, idx, sub_idx=None): - """Format the batch into a single dataframe. + """Format a single prediction dict into a dataframe with metadata. Args: - batch (list): The batch to format. - idx (int): The index of the batch. - sub_idx (int): The index of the subbatch. If None, the index is the subbatch index. + batch: A single prediction dict (keys: boxes, labels, scores, etc.) + idx: The dataset index (image index for windowed datasets) + sub_idx: The sub-index (window index). If None, uses idx. """ if sub_idx is None: sub_idx = idx @@ -184,38 +138,26 @@ def format_batch(self, batch, idx, sub_idx=None): result = format_geometry(batch, geom_type=geom_type) if result is None: return None - result["window_xmin"] = self.get_crop_bounds(sub_idx)[0] - result["window_ymin"] = self.get_crop_bounds(sub_idx)[1] + + crop_bounds = self.get_crop_bounds(sub_idx) + if crop_bounds is not None: + result["window_xmin"] = crop_bounds[0] + result["window_ymin"] = crop_bounds[1] result["image_path"] = self.get_image_basename(idx) return result - def postprocess(self, batched_result): - """Postprocess the batched result into a single dataframe. + def postprocess(self, batch, prediction_index): + """Postprocess a single prediction result into a dataframe. - In the case of sub-batches, the index is the sub-batch index. + Args: + batch: A single prediction dict (keys: boxes, labels, scores, etc.) + prediction_index: The index of this item in the dataset """ - formatted_result = [] - for idx, batch in enumerate(batched_result): - if isinstance(batch, list): - for sub_idx, sub_batch in enumerate(batch): - result = self.format_batch(sub_batch, idx, sub_idx) - if result is not None: - formatted_result.append(result) - else: - result = self.format_batch(batch, idx) - if result is not None: - formatted_result.append(result) - - if len(formatted_result) > 0: - formatted_result = pd.concat(formatted_result) - else: - formatted_result = pd.DataFrame() - - # reset index - formatted_result = formatted_result.reset_index(drop=True) - - return formatted_result + result = self.format_batch(batch, prediction_index) + if result is None: + return pd.DataFrame() + return result.reset_index(drop=True) class SingleImage(PredictionDataset): @@ -227,13 +169,7 @@ def __init__(self, path=None, image=None, patch_size=400, patch_overlap=0): ) def prepare_items(self): - self.image = self._load_and_preprocess_image( - self.path, self.image, preprocess_image=False - ) - - # Seperately transpose the image to channels first - self.image = np.transpose(self.image, (2, 0, 1)) - + self.image = self.load_and_preprocess_image(self.path, image=self.image) self.windows = preprocess.compute_windows( self.image, self.patch_size, self.patch_overlap ) @@ -246,9 +182,6 @@ def window_list(self): def get_crop(self, idx): crop = self.image[self.windows[idx].indices()] - crop = self.preprocess_image(crop) - if crop.shape[0] != 3: - crop = np.transpose(crop, (1, 2, 0)) return crop @@ -266,14 +199,15 @@ class FromCSVFile(PredictionDataset): """Take in a csv file with image paths and preprocess and batch together.""" - def __init__(self, csv_file: str, root_dir: str, size: int = None): + def __init__(self, csv_file: str, root_dir: str): self.csv_file = csv_file self.root_dir = root_dir - super().__init__(size=size) - self.prepare_items() + super().__init__() def prepare_items(self): - self.annotations = pd.read_csv(self.csv_file) + self.annotations = read_file(self.csv_file) + if self.root_dir is None: + self.root_dir = self.annotations.root_dir self.image_names = self.annotations.image_path.unique() self.image_paths = [os.path.join(self.root_dir, x) for x in self.image_names] @@ -281,7 +215,7 @@ def __len__(self): return len(self.image_paths) def get_crop(self, idx): - image = self._load_and_preprocess_image(self.image_paths[idx], size=self.size) + image = self.load_and_preprocess_image(image_path=self.image_paths[idx]) return image def get_image_basename(self, idx): @@ -291,19 +225,20 @@ def get_crop_bounds(self, idx): return None def format_batch(self, batch, idx, sub_idx=None): - """Format the batch into a single dataframe. + """Format a single prediction dict into a dataframe with metadata. + Override of base class to skip window coordinates (not applicable for + full images). Args: - batch (list): The batch to format. - idx (int): The index of the batch. - sub_idx (int): The index of the subbatch. If None, the index is the subbatch index. + batch: A single prediction dict (keys: boxes, labels, scores, etc.) + idx: The dataset index (image index) + sub_idx: Unused (kept for compatibility with base class signature) """ - if sub_idx is None: - sub_idx = idx geom_type = self.determine_geometry_type(batch) result = format_geometry(batch, geom_type=geom_type) if result is None: return None + result["image_path"] = self.get_image_basename(idx) return result @@ -312,7 +247,7 @@ def format_batch(self, batch, idx, sub_idx=None): class MultiImage(PredictionDataset): """Take in a list of image paths, preprocess and batch together. - Note: This dataset will load the first image to determine the image dimensions. + Note: This dataset will load the first image to determine the image dimensions. Images are expected to be the same size. For variable sized images, write a csv file and use the FromCSVFile dataset. """ def __init__(self, paths: list[str], patch_size: int, patch_overlap: float): @@ -322,15 +257,12 @@ def __init__(self, paths: list[str], patch_size: int, patch_overlap: float): patch_size (int): Size of the patches to extract. patch_overlap (float): Overlap between patches. """ - # Runtime type checking - if not isinstance(paths, list): - raise TypeError(f"paths must be a list, got {type(paths)}") - self.paths = paths self.patch_size = patch_size self.patch_overlap = patch_overlap + self.sublist_lengths = [] - image = self._load_and_preprocess_image(self.paths[0]) + image = self.load_and_preprocess_image(image_path=self.paths[0]) self.image_height = image.shape[1] self.image_width = image.shape[2] @@ -379,7 +311,7 @@ def create_overlapping_views(self, input_tensor, size, overlap): return output def _create_patches(self, image): - image_tensor = torch.tensor(image).unsqueeze(0) # Convert to (N, C, H, W) + image_tensor = image.unsqueeze(0) # Convert to (N, C, H, W) patch_overlap_size = int(self.patch_size * self.patch_overlap) patches = self.create_overlapping_views( image_tensor, self.patch_size, patch_overlap_size @@ -416,15 +348,37 @@ def window_list(self): return windows def collate_fn(self, batch): - # Comes pre-batched - return batch + """Collate the batch into a single list of crops. + + Keep track of the lengths of each sublist. + """ + # Create a list of lengths of each sublist + sub_list_length = [ + [idx, sub_idx] + for idx, sublist in enumerate(batch) + for sub_idx in range(len(sublist)) + ] + self.sublist_lengths.append(sub_list_length) + + # Flatten list of lists of crops + flattened_batch = [crop for sublist in batch for crop in sublist] + sublist_lengths = [ + [idx, sub_idx] + for idx, sublist in enumerate(batch) + for sub_idx in range(len(sublist)) + ] + + return {"images": flattened_batch, "sublist_lengths": sublist_lengths} def __len__(self): return len(self.paths) def get_crop(self, idx): - image = self._load_and_preprocess_image(self.paths[idx]) - return self._create_patches(image) + image = self.load_and_preprocess_image(image_path=self.paths[idx]) + crops = self._create_patches(image) + + # Return as a list of crops, each with shape (3, 300, 300) + return [crops[i] for i in range(crops.shape[0])] def get_image_basename(self, idx): return os.path.basename(self.paths[idx]) @@ -432,6 +386,36 @@ def get_image_basename(self, idx): def get_crop_bounds(self, idx): return self.window_list()[idx] + def postprocess(self, batch, prediction_index, original_batch_structure): + """Postprocess flattened batch of predictions from multiple images. + + Args: + batch: List of prediction dicts (all windows from all images in batch) + prediction_index: Index of this batch from trainer.predict + """ + if prediction_index >= len(original_batch_structure): + raise ValueError( + f"prediction_index {prediction_index} exceeds sublist_lengths length {len(original_batch_structure)}. " + "This may indicate a mismatch between collate_fn calls and postprocess calls." + ) + + batch_sublist_lengths = original_batch_structure[prediction_index] + formatted_results = [] + + # batch_sublist_lengths[i] = [image_idx, window_idx] corresponds to batch[i] + for batch_position, (image_idx, window_idx) in enumerate(batch_sublist_lengths): + prediction = batch[batch_position] + + # Format with correct image index and window index + result = self.format_batch(prediction, image_idx, window_idx) + if result is not None: + formatted_results.append(result) + + if len(formatted_results) > 0: + return pd.concat(formatted_results).reset_index(drop=True) + else: + return pd.DataFrame() + class TiledRaster(PredictionDataset): """Dataset for predicting on raster windows. @@ -447,13 +431,9 @@ class TiledRaster(PredictionDataset): """ def __init__(self, path, patch_size, patch_overlap): - self.path = path - self.patch_size = patch_size - self.patch_overlap = patch_overlap - self.prepare_items() - if path is None: raise ValueError("path is required for a memory raster dataset") + super().__init__(path=path, patch_size=patch_size, patch_overlap=patch_overlap) def prepare_items(self): # Get raster shape without keeping file open @@ -492,9 +472,13 @@ def get_crop(self, idx): with rio.open(self.path) as src: window_data = src.read(window=Window(window.x, window.y, window.w, window.h)) - # Convert to torch tensor and rearrange dimensions - window_data = torch.from_numpy(window_data).float() # Convert to torch tensor - window_data = window_data / 255.0 # Normalize + # Rasterio already returns (C, H, W), just normalize and convert + window_data = window_data.astype("float32") / 255.0 + window_data = torch.from_numpy(window_data).float() + if window_data.shape[0] != 3: + raise ValueError( + f"Expected 3 channel image, got {window_data.shape[0]} channels" + ) return window_data diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 81575d950..33e2dca82 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -82,6 +82,7 @@ def __init__( self.create_trainer() self.model = model + self.original_batch_structure = [] if self.model is None: self.create_model() @@ -475,7 +476,10 @@ def predict_image(self, image: np.ndarray | None = None, path: str | None = None return result def predict_file( - self, csv_file, root_dir, crop_model=None, size=None, batch_size=None + self, + csv_file, + root_dir, + crop_model=None, ): """Create a dataset and predict entire annotation file CSV file format is .csv file with the columns "image_path", "xmin","ymin","xmax","ymax" @@ -492,8 +496,8 @@ def predict_file( df: pandas dataframe with bounding boxes, label and scores for each image in the csv file """ - ds = prediction.FromCSVFile(csv_file=csv_file, root_dir=root_dir, size=size) - dataloader = self.predict_dataloader(ds, batch_size=batch_size) + ds = prediction.FromCSVFile(csv_file=csv_file, root_dir=root_dir) + dataloader = self.predict_dataloader(ds, batch_size=self.config.batch_size) results = predict._dataloader_wrapper_( model=self, crop_model=crop_model, @@ -583,31 +587,45 @@ def predict_tile( patch_size=patch_size, ) - batched_results = self.trainer.predict(self, self.predict_dataloader(ds)) + dataloader = self.predict_dataloader(ds) + batched_results = self.trainer.predict(self, dataloader) # Flatten list from batched prediction - prediction_list = [] - for batch in batched_results: - for images in batch: - prediction_list.append(images) - image_results.append(ds.postprocess(prediction_list)) + # Track global window index across batches + global_window_idx = 0 + for _idx, batch in enumerate(batched_results): + for _window_idx, window_result in enumerate(batch): + formatted_result = ds.postprocess( + window_result, global_window_idx + ) + image_results.append(formatted_result) + global_window_idx += 1 - results = pd.concat(image_results) + if not image_results: + results = pd.DataFrame() + else: + results = pd.concat(image_results) elif dataloader_strategy == "batch": + self.original_batch_structure.clear() ds = prediction.MultiImage( paths=paths, patch_overlap=patch_overlap, patch_size=patch_size ) - batched_results = self.trainer.predict(self, self.predict_dataloader(ds)) + dataloader = self.predict_dataloader(ds) + batched_results = self.trainer.predict(self, dataloader) # Flatten list from batched prediction - prediction_list = [] - for batch in batched_results: - for images in batch: - prediction_list.append(images) - image_results.append(ds.postprocess(prediction_list)) - results = pd.concat(image_results) + for idx, batch in enumerate(batched_results): + formatted_result = ds.postprocess( + batch, idx, self.original_batch_structure + ) + image_results.append(formatted_result) + + if not image_results: + results = pd.DataFrame() + else: + results = pd.concat(image_results) else: raise ValueError(f"Invalid dataloader_strategy: {dataloader_strategy}") @@ -711,7 +729,7 @@ def validation_step(self, batch, batch_idx): pass # In eval model, return predictions to calculate prediction metrics - preds = self.model.eval() + self.model.eval() with torch.no_grad(): preds = self.model.forward(images, targets) @@ -816,8 +834,13 @@ def log_epoch_metrics(self): self.iou_metric.reset() output = self.mAP_metric.compute() - # Remove classes from output dict - output = {key: value for key, value in output.items() if not key == "classes"} + # Keep only overall mAP; drop extra map_* and classes clutter + if isinstance(output, dict): + # Remove classes entry if present + if "classes" in output: + output.pop("classes", None) + # Reduce to only overall 'map' and map_50 if available + output = {k: v for k, v in output.items() if k in ["map", "map_50"]} try: self.log_dict(output) except MisconfigurationException: @@ -849,10 +872,9 @@ def on_validation_epoch_end(self): else: predictions = pd.DataFrame() - results = self.evaluate( + results = self.__evaluate__( self.config.validation.csv_file, root_dir=self.config.validation.root_dir, - size=self.config.validation.size, predictions=predictions, ) @@ -871,28 +893,18 @@ def predict_step(self, batch, batch_idx): Returns: """ - split_results = False - # If batch is a list, concatenate the images, predict and then split the results - if isinstance(batch, list): - original_list_length = len(batch) - combined_batch = torch.cat(batch, dim=0) - split_results = True - else: - combined_batch = batch - - batch_results = self.model(combined_batch) - - # If batch is a list, split the results - if split_results: - results = [] - batch_size = len(batch_results) // original_list_length - for i in range(original_list_length): - start_idx = i * batch_size - end_idx = start_idx + batch_size - results.append(batch_results[start_idx:end_idx]) - return results + if isinstance(batch, dict): + images = batch["images"] + sublist_lengths = batch["sublist_lengths"] + self.original_batch_structure.append(sublist_lengths) else: - return batch_results + sublist_lengths = None + images = batch + + self.model.eval() + with torch.no_grad(): + preds = self.model.forward(images) + return preds def predict_batch(self, images, preprocess_fn=None): """Predict a batch of images with the deepforest model. @@ -994,23 +1006,19 @@ def lr_lambda(epoch): else: return optimizer - def evaluate( + def __evaluate__( self, csv_file, iou_threshold=None, root_dir=None, - size=None, - batch_size=None, predictions=None, ): - """Compute intersection-over-union and precision/recall for a given - iou_threshold. + """Internal method to compute intersection-over-union and + precision/recall for a given iou_threshold. Args: csv_file: location of a csv file with columns "name","xmin","ymin","xmax","ymax","label" iou_threshold: float [0,1] intersection-over-union threshold for true positive - batch_size: int, the batch size to use for prediction. If None, uses the batch size of the model. - size: int, the size to resize the images to. If None, no resizing is done. predictions: list of predictions to use for evaluation. If None, predictions are generated from the model. Returns: @@ -1018,6 +1026,8 @@ def evaluate( """ self.model.eval() if root_dir is None: + if self.config.validation.root_dir is None: + raise ValueError("root_dir must be specified if not provided in config") root_dir = self.config.validation.root_dir ground_df = utilities.read_file(csv_file, root_dir=root_dir) @@ -1026,7 +1036,8 @@ def evaluate( if predictions is None: # Get the predict dataloader and use predict_batch predictions = self.predict_file( - csv_file, ground_df.root_dir, size=size, batch_size=batch_size + csv_file, + root_dir, ) if iou_threshold is None: @@ -1047,6 +1058,42 @@ def evaluate( return results + def evaluate( + self, + csv_file, + iou_threshold=None, + root_dir=None, + predictions=None, + ): + """Compute intersection-over-union and precision/recall for a given + iou_threshold. + + .. deprecated:: 2.0.0 + This method is deprecated. Users should use `trainer.validate()` instead + to get evaluation statistics during training. This method will be removed + in a future version. + + Args: + csv_file: location of a csv file with columns "name","xmin","ymin","xmax","ymax","label" + iou_threshold: float [0,1] intersection-over-union threshold for true positive + predictions: list of predictions to use for evaluation. If None, predictions are generated from the model. + + Returns: + dict: Results dictionary containing precision, recall and other metrics + """ + warnings.warn( + "deepforest.evaluate() is deprecated and will be removed in a future version. " + "Please use trainer.validate() instead to get evaluation statistics during training.", + DeprecationWarning, + stacklevel=2, + ) + return self.__evaluate__( + csv_file=csv_file, + iou_threshold=iou_threshold, + root_dir=root_dir, + predictions=predictions, + ) + def __evaluation_logs__(self, results): """Log metrics from evaluation results.""" # Log metrics @@ -1066,7 +1113,7 @@ def __evaluation_logs__(self, results): pass # Log each key value pair of the results dict - if results["class_recall"] is not None: + if results["class_recall"] is not None and self.config.num_classes > 1: for key, value in results.items(): if key in ["class_recall"]: for _, row in value.iterrows(): diff --git a/src/deepforest/model.py b/src/deepforest/model.py index 95dd45196..f5addad21 100644 --- a/src/deepforest/model.py +++ b/src/deepforest/model.py @@ -407,10 +407,12 @@ def validation_step(self, batch, batch_idx): def on_validation_epoch_end(self): metric_dict = self.metrics.compute() - for index, value in enumerate(metric_dict["Class Accuracy"]): - key = self.numeric_to_label_dict[index] - metric_name = f"Class Accuracy_{key}" - self.log(metric_name, value, on_step=False, on_epoch=True) + # Only log per-class metrics when there are multiple classes + if len(self.numeric_to_label_dict) > 1: + for index, value in enumerate(metric_dict["Class Accuracy"]): + key = self.numeric_to_label_dict[index] + metric_name = f"Class Accuracy_{key}" + self.log(metric_name, value, on_step=False, on_epoch=True) self.log( "Micro-Average Accuracy", diff --git a/src/deepforest/predict.py b/src/deepforest/predict.py index f31e49b6d..1be09dae2 100644 --- a/src/deepforest/predict.py +++ b/src/deepforest/predict.py @@ -201,12 +201,20 @@ def _dataloader_wrapper_(model, trainer, dataloader, root_dir, crop_model): # Flatten list from batched prediction prediction_list = [] - for batch in batched_results: - for images in batch: - prediction_list.append(images) - - # Postprocess predictions - results = dataloader.dataset.postprocess(prediction_list) + global_image_idx = 0 + for _idx, batch in enumerate(batched_results): + for _image_idx, image_result in enumerate(batch): + formatted_result = dataloader.dataset.postprocess( + image_result, global_image_idx + ) + global_image_idx += 1 + prediction_list.append(formatted_result) + + # Postprocess predictions, return empty dataframe if no predictions + if not prediction_list: + return pd.DataFrame() + + results = pd.concat(prediction_list) if results.empty: return results diff --git a/tests/test_datasets_prediction.py b/tests/test_datasets_prediction.py index 2948ea93d..f98ab3f72 100644 --- a/tests/test_datasets_prediction.py +++ b/tests/test_datasets_prediction.py @@ -31,34 +31,12 @@ def test_SingleImage_path(): for i in range(len(ds)): assert ds.get_crop(i).shape == (3, 300, 300) -def test_invalid_array_dtype(): - # Not explicitly 8-bit - test_data = np.random.random((300,300, 3)) - with pytest.raises(ValueError): - SingleImage(image=test_data) - -def test_invalid_image_dtype(): - # Not explicitly 8-bit - test_data = (np.random.random((300,300)) * 255).astype(np.float32) - with pytest.raises(ValueError): - SingleImage(image=Image.fromarray(test_data)) - -def test_invalid_array_shape(): - # Not HWC - test_data = np.random.random((3,300,300)) - with pytest.raises(ValueError): - SingleImage(image=test_data) - def test_invalid_image_shape(): # Not 3 channels test_data = (np.random.rand(300, 300, 4) * 255).astype(np.uint8) - with pytest.raises(ValueError, match="4 channels"): + with pytest.raises(ValueError): SingleImage(image=Image.fromarray(test_data)) -def test_no_image_or_path(): - with pytest.raises(ValueError, match="image or image_path"): - SingleImage() - def test_valid_image(): # 8-bit, HWC test_data = np.random.randint(0, 256, (300,300,3)).astype(np.uint8) @@ -69,19 +47,13 @@ def test_valid_array(): test_data = np.random.randint(0, 256, (300,300,3)).astype(np.uint8) SingleImage(image=test_data) -def test_image_resize(): - # Resize + transpose - test_data = np.random.randint(0, 256, (300,300,3)).astype(np.uint8) - ds = SingleImage(image=test_data) - assert ds._load_and_preprocess_image(image=test_data, size=200).shape == (3, 200, 200) - def test_MultiImage(): ds = MultiImage(paths=[get_data("OSBS_029.png"), get_data("OSBS_029.png")], patch_size=300, patch_overlap=0) - assert len(ds) == 2 # 2 windows each image 2 * 2 = 4 - assert ds[0].shape == (4, 3, 300, 300) + assert len(ds) == 2 + assert ds[0][0].shape == (3, 300, 300) def test_FromCSVFile(): ds = FromCSVFile(csv_file=get_data("example.csv"), diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index 9d93bbb22..ca4fdc257 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -12,6 +12,7 @@ from deepforest import main from deepforest.utilities import read_file +from PIL import Image from shapely.geometry import box @@ -79,7 +80,7 @@ def test_evaluate_empty(m, tmp_path): m = main.deepforest(config_args={"model": {"name": None}, "log_root": str(tmp_path)}) csv_file = get_data("OSBS_029.csv") - results = m.evaluate(csv_file, iou_threshold=0.4) + results = m.evaluate(csv_file, iou_threshold=0.4, root_dir=os.path.dirname(csv_file)) # Does this make reasonable predictions, we know the model works. assert np.isnan(results["box_precision"]) @@ -202,3 +203,103 @@ def test_evaluate_boxes_no_predictions_for_image(): assert results["box_recall"] == 0 # No predictions for image1.jpg assert all(results["results"].match == False) # No matches assert all(results["results"].prediction_id.isna()) # No prediction IDs + + +def test_validate_predictions_match_predict_file(m): + """trainer.validate should populate predictions equivalent to main.predict_file.""" + csv_file = get_data("example.csv") + root_dir = os.path.dirname(csv_file) + + # Run validation to populate m.predictions + m.create_trainer() + m.trainer.validate(m) + + # Concatenate predictions collected during validation + if len(m.predictions) > 0: + val_preds = pd.concat(m.predictions, ignore_index=True) + else: + val_preds = pd.DataFrame( + columns=["image_path", "xmin", "ymin", "xmax", "ymax", "label"]) + + # Predict through inference path + infer_preds = m.predict_file(csv_file=csv_file, root_dir=root_dir) + + assert infer_preds.empty == val_preds.empty + assert infer_preds.shape == val_preds.shape + + # Compare the first row of the predictions + assert infer_preds.iloc[0].xmin == val_preds.iloc[0].xmin + assert infer_preds.iloc[0].ymin == val_preds.iloc[0].ymin + assert infer_preds.iloc[0].xmax == val_preds.iloc[0].xmax + assert infer_preds.iloc[0].ymax == val_preds.iloc[0].ymax + assert infer_preds.iloc[0].label == val_preds.iloc[0].label + + +def test_validate_predictions_match_predict_file_mixed_sizes(m, tmp_path): + """trainer.validate should populate predictions equivalent to main.predict_file for mixed-size images.""" + # Prepare two images at different sizes + src_path = get_data("OSBS_029.tif") + img = Image.open(src_path).convert("RGB") + + # Create a smaller and a larger variant (bounded to avoid extreme sizes) + w, h = img.size + small = img.resize((max(64, w // 2), max(64, h // 2))) + large = img.resize((min(w * 2, 2 * w), min(h * 2, 2 * h))) + + # Save both to tmp directory as PNGs + small_name = "mixed_small.png" + large_name = "mixed_large.png" + small_path = os.path.join(tmp_path, small_name) + large_path = os.path.join(tmp_path, large_name) + small.save(small_path) + large.save(large_path) + + # Build a CSV with empty annotations (0,0,0,0) for validation compatibility + # predict_file only needs image_path, but validation requires full annotation format + csv_path = os.path.join(tmp_path, "mixed_images.csv") + df = pd.DataFrame({ + "image_path": [small_name, large_name], + "xmin": [0, 0], + "ymin": [0, 0], + "xmax": [0, 0], + "ymax": [0, 0], + "label": ["Tree", "Tree"] + }) + df.to_csv(csv_path, index=False) + + # Configure validation to use the mixed-size images + m.config.validation.csv_file = csv_path + m.config.validation.root_dir = str(tmp_path) + # Don't set validation.size to avoid resizing - use batch_size=1 to handle mixed sizes + m.config.batch_size = 1 + + # Run validation to populate m.predictions + m.create_trainer() + m.trainer.validate(m) + + # Concatenate predictions collected during validation + if len(m.predictions) > 0: + val_preds = pd.concat(m.predictions, ignore_index=True) + else: + val_preds = pd.DataFrame( + columns=["image_path", "xmin", "ymin", "xmax", "ymax", "label"]) + + # Predict through inference path + # Don't pass size to avoid resizing - coordinates should be in original image space + infer_preds = m.predict_file(csv_file=csv_path, root_dir=str(tmp_path)) + + assert infer_preds.empty == val_preds.empty + assert infer_preds.shape == val_preds.shape + + # Compare predictions by image_path to ensure matching + # Sort both dataframes for comparison + val_preds_sorted = val_preds.sort_values(by=["image_path", "xmin", "ymin"]).reset_index(drop=True) + infer_preds_sorted = infer_preds.sort_values(by=["image_path", "xmin", "ymin"]).reset_index(drop=True) + + # Compare the first row of the predictions if both are non-empty + if not val_preds_sorted.empty and not infer_preds_sorted.empty: + assert infer_preds_sorted.iloc[0].xmin == val_preds_sorted.iloc[0].xmin + assert infer_preds_sorted.iloc[0].ymin == val_preds_sorted.iloc[0].ymin + assert infer_preds_sorted.iloc[0].xmax == val_preds_sorted.iloc[0].xmax + assert infer_preds_sorted.iloc[0].ymax == val_preds_sorted.iloc[0].ymax + assert infer_preds_sorted.iloc[0].label == val_preds_sorted.iloc[0].label diff --git a/tests/test_main.py b/tests/test_main.py index 75149e84b..3dabcdb15 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -443,7 +443,7 @@ def test_predict_dataloader(m, batch_size, path): ds = prediction.SingleImage(image=tile, path=path, patch_overlap=0.1, patch_size=100) dl = m.predict_dataloader(ds) batch = next(iter(dl)) - assert batch.shape[0] == batch_size + assert len(batch) == batch_size def test_predict_tile_empty(m_without_release, path): m = m_without_release @@ -467,7 +467,7 @@ def test_predict_tile(m, path, dataloader_strategy): prediction = m.predict_tile(path=image_path, patch_size=300, dataloader_strategy=dataloader_strategy, - patch_overlap=0.1) + patch_overlap=0) assert isinstance(prediction, pd.DataFrame) assert set(prediction.columns) == { @@ -958,7 +958,7 @@ def test_batch_inference_consistency(m, path): single_predictions = [] for image in ds: - image = np.rollaxis(image, 0, 3) * 255 + image = np.rollaxis(image.numpy(), 0, 3) * 255.0 single_prediction = m.predict_image(image=image) single_predictions.append(single_prediction) @@ -1138,3 +1138,36 @@ def test_set_labels_invalid_length(m): # Expect a ValueError when setting an inv invalid_mapping = {"Object": 0, "Extra": 1} with pytest.raises(ValueError): m.set_labels(invalid_mapping) + +def test_predict_file_mixed_sizes(m, tmp_path): + """Mixed-size images should yield predictions in original image coordinates.""" + # Prepare two images at different sizes + src_path = get_data("OSBS_029.tif") + img = Image.open(src_path).convert("RGB") + + # Create a smaller and a larger variant (bounded to avoid extreme sizes) + w, h = img.size + small = img.resize((max(64, w // 2), max(64, h // 2))) + large = img.resize((min(w * 2, 2 * w), min(h * 2, 2 * h))) + + # Save both to tmp directory as PNGs + small_name = "mixed_small.png" + large_name = "mixed_large.png" + small_path = os.path.join(tmp_path, small_name) + large_path = os.path.join(tmp_path, large_name) + small.save(small_path) + large.save(large_path) + + # Build a CSV with just image_path column (prediction path) + csv_path = os.path.join(tmp_path, "mixed_images.csv") + df = pd.DataFrame({"image_path": [small_name, large_name]}) + df["label"] = "Tree" + # Borrow the geometry from the OSBS_029.csv file + geometry = read_file(get_data("OSBS_029.csv"))["geometry"] + df["geometry"] = [geometry.iloc[0] for _ in range(len(df))] + df.to_csv(csv_path, index=False) + + m.config.validation.size = 200 + preds = m.predict_file(csv_file=csv_path, root_dir=str(tmp_path)) + + assert preds.ymax.max() > 200 # The larger image should have predictions outside the 200px limit