From bfb87df83ad850e308c8ebec42700b9ec90efa60 Mon Sep 17 00:00:00 2001 From: bw4sz Date: Thu, 29 Jan 2026 15:38:00 -0500 Subject: [PATCH] fix: gather validation predictions from all ranks in DDP In on_validation_epoch_end, when world_size > 1, gather predictions from all ranks via torch.distributed.all_gather_object before building the predictions DataFrame and calling __evaluate__. This ensures box_recall, box_precision, etc. are computed on the full validation set (matching 1-GPU behavior) instead of only one rank's subset. Note that pd.concat is now called with ignore_index=True in both the DDP and single-process paths, so the combined predictions DataFrame gets a renumbered index rather than repeated per-batch indices. See docs/multi_gpu_validation_fix.md for problem description and verification scripts. --- src/deepforest/main.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 45de757ea..1101d7f68 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -853,8 +853,23 @@ def on_validation_epoch_end(self): self.log_epoch_metrics() if (self.current_epoch + 1) % self.config.validation.val_accuracy_interval == 0: - if len(self.predictions) > 0: - predictions = pd.concat(self.predictions) + # In DDP, each rank only has its share of validation batches in + # self.predictions. Gather from all ranks so metrics use the full set. + if self.trainer.world_size > 1 and torch.distributed.is_initialized(): + object_list = [None] * self.trainer.world_size + torch.distributed.all_gather_object(object_list, self.predictions) + all_predictions = [ + df + for rank_list in object_list + for df in (rank_list if rank_list is not None else []) + ] + predictions = ( + pd.concat(all_predictions, ignore_index=True) + if all_predictions + else pd.DataFrame() + ) + elif len(self.predictions) > 0: + predictions = pd.concat(self.predictions, ignore_index=True) else: predictions = pd.DataFrame()