From aba78b2dd529ad6364ee4f6d8ec1c457739e6d2d Mon Sep 17 00:00:00 2001
From: levtelyatnikov
Date: Fri, 26 Dec 2025 03:23:41 -0800
Subject: [PATCH] updated the testing logic

---
 topobench/run.py | 122 ++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 95 insertions(+), 27 deletions(-)

diff --git a/topobench/run.py b/topobench/run.py
index ab6f8602..142ea2cc 100755
--- a/topobench/run.py
+++ b/topobench/run.py
@@ -1,6 +1,7 @@
 """Main entry point for training and testing models."""
 
 import random
+from pathlib import Path
 from typing import Any
 
 import hydra
@@ -9,7 +10,9 @@ import rootutils
 import torch
 from lightning import Callback, LightningModule, Trainer
+from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.loggers import Logger
+from lightning.pytorch.loggers.wandb import WandbLogger
 from omegaconf import DictConfig, OmegaConf
 
 from topobench.data.preprocessor import PreProcessor
 
@@ -227,39 +230,104 @@ def run(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
 
     if cfg.get("test"):
         log.info("Starting testing!")
-        test_best_model_path = True
-        if cfg.get("ckpt_path"):
-            ckpt_path = cfg.ckpt_path
+
+        rerun_best_model_checkpoint(
+            checkpoint_model=model,
+            cfg=cfg,
+            datamodule=datamodule,
+            device=model.device,
+            callbacks=callbacks,
+            logger=logger,
+        )
+
+    # Merge train and test metrics
+    metric_dict = {**train_metrics}
+
+    return metric_dict, object_dict
+
+
+def rerun_best_model_checkpoint(
+    checkpoint_model: LightningModule,
+    cfg: DictConfig,
+    datamodule: LightningModule,
+    device: torch.device,
+    callbacks: list[Callback],
+    logger: list[Logger],
+) -> None:
+    """Rerun the best model checkpoint on validation and test datasets to log final metrics.
+
+    This function iterates through the callbacks to locate the `ModelCheckpoint`, loads the
+    best model weights, and runs a test pass on both the validation and test dataloaders.
+    Metrics are logged with `val_best_rerun/` and `test_best_rerun/` prefixes to ensure
+    metrics reflect the best model state rather than the final epoch.
+
+    Parameters
+    ----------
+    checkpoint_model : LightningModule
+        The model instance to load weights into.
+    cfg : DictConfig
+        Configuration composed by Hydra.
+    datamodule : LightningModule
+        The data module providing `val_dataloader` and `test_dataloader`.
+    device : torch.device
+        The target device (CPU/GPU) for the model.
+    callbacks : list[Callback]
+        A list of callbacks to search for the `ModelCheckpoint`.
+    logger : list[Logger]
+        A list of loggers (e.g., WandbLogger) to record the re-run metrics.
+    """
+    for callback in callbacks:
+        if isinstance(callback, ModelCheckpoint):
             log.info(
-                f"Attempting to load weights from the provided ckpt_path: {ckpt_path}"
+                f"Loading best model from checkpoint at {callback.best_model_path}"
             )
-            try:
-                trainer.test(
-                    model=model, datamodule=datamodule, ckpt_path=ckpt_path
-                )
-                test_best_model_path = False  # do not test "best model" if a valid ckpt_path is provided
-            except FileNotFoundError:
-                log.warning(
-                    f"No checkpoint file found at the provided ckpt_path: {ckpt_path}."
-                )
-                log.info("Trying with best model instead...")
-        if test_best_model_path:
-            ckpt_path = trainer.checkpoint_callback.best_model_path
-            if ckpt_path == "":
-                log.warning(
-                    "Best ckpt not found! Using current weights for testing..."
-                )
-                ckpt_path = None
-        trainer.test(
-            model=model, datamodule=datamodule, ckpt_path=ckpt_path
+            model_path = Path(callback.best_model_path)
+            ckpt = torch.load(
+                model_path, map_location="cpu", weights_only=False
             )
 
-        test_metrics = trainer.callback_metrics
+            checkpoint_model.load_state_dict(ckpt["state_dict"], strict=True)
+            checkpoint_model.to(device)
+            break  # there is only one checkpoint callback
 
-    # Merge train and test metrics
-    metric_dict = {**train_metrics, **test_metrics}
+    # New trainer to log final metrics on validation set
+    # Because wandb displays validation metrics from the final, not the best epoch.
+    checkpoint_trainer: Trainer = hydra.utils.instantiate(
+        cfg.trainer,
+        num_sanity_val_steps=0,
+        enable_progress_bar=cfg.get("enable_progress_bar", True),
+        logger=False,
+    )
 
-    return metric_dict, object_dict
+    log.info("Re-testing best model checkpoint on validation set!")
+    val_loader = datamodule.val_dataloader()
+    results = checkpoint_trainer.test(
+        model=checkpoint_model, dataloaders=val_loader
+    )
+    if results:
+        logged = {}
+        for k, v in results[0].items():
+            suffix = k.split("/", 1)[1] if "/" in k else k
+            logged[f"val_best_rerun/{suffix}"] = v
+        log.info(logged)
+        for lgr in logger:
+            if isinstance(lgr, WandbLogger):
+                lgr.log_metrics(logged)
+
+    log.info("Re-testing best model checkpoint on test set!")
+    test_loader = datamodule.test_dataloader()
+    results = checkpoint_trainer.test(
+        model=checkpoint_model, dataloaders=test_loader
+    )
+    if results:
+        logged = {}
+        for k, v in results[0].items():
+            suffix = k.split("/", 1)[1] if "/" in k else k
+            logged[f"test_best_rerun/{suffix}"] = v
+        log.info(logged)
+        for lgr in logger:
+            if isinstance(lgr, WandbLogger):
+                lgr.log_metrics(logged)
 
 
 def count_number_of_parameters(
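
Note: the re-run logging above keeps whatever metric names `Trainer.test` returns and only swaps
the prefix. Below is a minimal standalone sketch of that re-prefixing step, assuming results shaped
like Lightning's `test()` output (a list holding one dict of metric names to values); the helper
name `reprefix_metrics` is illustrative and not part of the patch itself:

    def reprefix_metrics(results: list[dict[str, float]], prefix: str) -> dict[str, float]:
        """Re-key metrics such as 'test/accuracy' to f'{prefix}/accuracy'."""
        if not results:
            return {}
        logged = {}
        for k, v in results[0].items():
            # Keep everything after the first '/', or the whole key if it has no prefix.
            suffix = k.split("/", 1)[1] if "/" in k else k
            logged[f"{prefix}/{suffix}"] = v
        return logged

    # Example: the 'test/' prefix written by Trainer.test is replaced wholesale.
    print(reprefix_metrics([{"test/acc": 0.91, "test/loss": 0.30}], "val_best_rerun"))
    # -> {'val_best_rerun/acc': 0.91, 'val_best_rerun/loss': 0.3}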