diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 45de757e..1101d7f6 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -853,8 +853,23 @@ def on_validation_epoch_end(self): self.log_epoch_metrics() if (self.current_epoch + 1) % self.config.validation.val_accuracy_interval == 0: - if len(self.predictions) > 0: - predictions = pd.concat(self.predictions) + # In DDP, each rank only has its share of validation batches in + # self.predictions. Gather from all ranks so metrics use the full set. + if self.trainer.world_size > 1 and torch.distributed.is_initialized(): + object_list = [None] * self.trainer.world_size + torch.distributed.all_gather_object(object_list, self.predictions) + all_predictions = [ + df + for rank_list in object_list + for df in (rank_list if rank_list is not None else []) + ] + predictions = ( + pd.concat(all_predictions, ignore_index=True) + if all_predictions + else pd.DataFrame() + ) + elif len(self.predictions) > 0: + predictions = pd.concat(self.predictions, ignore_index=True) else: predictions = pd.DataFrame()