diff --git a/lib/vggflow/trainer.py b/lib/vggflow/trainer.py
index 7194a46..c560d73 100644
--- a/lib/vggflow/trainer.py
+++ b/lib/vggflow/trainer.py
@@ -320,7 +320,7 @@ def _aggregate_and_log_metrics(self, global_step, num_processes):
                 aggregated_info[k] = torch.max(torch.stack(v))
             elif '_all' in k:
                 aggregated_info[k] = torch.stack(v)
-                cache[k] = [torch.zeros_like(aggregated_info[k])] * num_processes
+                cache[k] = [torch.zeros_like(aggregated_info[k]) for _ in range(num_processes)]
             else:
                 aggregated_info[k] = torch.mean(torch.stack(v))
 
@@ -334,45 +334,49 @@ def _aggregate_and_log_metrics(self, global_step, num_processes):
                 dist.all_reduce(v, op=dist.ReduceOp.MAX)
             elif '_median' in k:
                 dist.all_gather(cache[k], v)
-                new_info[k.replace('_all', '')] = torch.median(cache[k][self.local_rank])
+                gathered = torch.cat([t.reshape(-1) for t in cache[k]], dim=0)
+                new_info[k.replace('_all', '')] = torch.median(gathered)
             elif '_08quantile' in k:
                 dist.all_gather(cache[k], v)
-                new_info[k.replace('_all', '')] = torch.quantile(cache[k][self.local_rank], 0.8)
+                gathered = torch.cat([t.reshape(-1) for t in cache[k]], dim=0)
+                new_info[k.replace('_all', '')] = torch.quantile(gathered, 0.8)
             elif '_std' in k:
                 dist.all_gather(cache[k], v)
-                new_info[k.replace('_all', '')] = torch.std(cache[k][self.local_rank])
+                gathered = torch.cat([t.reshape(-1) for t in cache[k]], dim=0)
+                new_info[k.replace('_all', '')] = torch.std(gathered)
             else:
                 dist.all_reduce(v, op=dist.ReduceOp.SUM)
 
-        # Remove '_all' keys and add computed statistics
-        for k in list(aggregated_info.keys()):
+        # Average reduced means across processes; keep min/max as-is; drop raw "_all" tensors.
+        final_info = {}
+        for k, v in aggregated_info.items():
             if '_all' in k:
-                aggregated_info.pop(k, None)
-        aggregated_info.update(new_info)
+                continue
+            if '_min' in k or '_max' in k:
+                final_info[k] = v
+            else:
+                final_info[k] = v / num_processes
 
-        # Average across processes
-        aggregated_info = {
-            k: v / num_processes if ('_min' not in k and '_max' not in k) else v
-            for k, v in aggregated_info.items()
-        }
+        # Add computed statistics (already global across all ranks; must not be divided)
+        final_info.update(new_info)
 
         # Update rgrad threshold
-        rgrad_threshold = aggregated_info.get('rgrad_08quantile', 1.0).item()
+        rgrad_threshold = final_info.get('rgrad_08quantile', torch.tensor(1.0)).item()
 
         # Add epoch and step info
         if self._is_main_process():
             if self.scaler:
-                aggregated_info["grad_scale"] = self.scaler.get_scale()
-            aggregated_info["global_step"] = float(global_step)
+                final_info["grad_scale"] = self.scaler.get_scale()
+            final_info["global_step"] = float(global_step)
 
             # Log to wandb
             if self.config.logging.use_wandb:
                 import wandb
-                wandb.log(aggregated_info, step=global_step)
+                wandb.log(final_info, step=global_step)
 
             # Log to console
             self.logger.info(f"global_step={global_step} " +
-                             " ".join([f"{k}={v:.6f}" for k, v in aggregated_info.items()]))
+                             " ".join([f"{k}={v:.6f}" for k, v in final_info.items()]))
 
         # Reset info dict
         self.info = defaultdict(list)