topobench/run.py (122 changes: 95 additions & 27 deletions)
@@ -1,6 +1,7 @@
"""Main entry point for training and testing models."""

import random
from pathlib import Path
from typing import Any

import hydra
@@ -9,7 +10,9 @@
 import rootutils
 import torch
 from lightning import Callback, LightningModule, Trainer
+from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.loggers import Logger
+from lightning.pytorch.loggers.wandb import WandbLogger
 from omegaconf import DictConfig, OmegaConf

 from topobench.data.preprocessor import PreProcessor
@@ -227,39 +230,104 @@ def run(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:

if cfg.get("test"):
log.info("Starting testing!")
test_best_model_path = True
if cfg.get("ckpt_path"):
ckpt_path = cfg.ckpt_path

rerun_best_model_checkpoint(
checkpoint_model=model,
cfg=cfg,
datamodule=datamodule,
device=model.device,
callbacks=callbacks,
logger=logger,
)

# Merge train and test metrics
metric_dict = {**train_metrics}

return metric_dict, object_dict


def rerun_best_model_checkpoint(
checkpoint_model: LightningModule,
cfg: DictConfig,
datamodule: LightningModule,
device: torch.device,
callbacks: list[Callback],
logger: list[Logger],
) -> None:
"""Rerun the best model checkpoint on validation and test datasets to log final metrics.

This function iterates through the callbacks to locate the `ModelCheckpoint`, loads the
best model weights, and runs a test pass on both the validation and test dataloaders.
Metrics are logged with `val_best_rerun/` and `test_best_rerun/` prefixes to ensure
metrics reflect the best model state rather than the final epoch.

Parameters
----------
checkpoint_model : LightningModule
The model instance to load weights into.
cfg : DictConfig
Configuration composed by Hydra.
datamodule : LightningModule
The data module providing `val_dataloader` and `test_dataloader`.
device : torch.device
The target device (CPU/GPU) for the model.
callbacks : list[Callback]
A list of callbacks to search for the `ModelCheckpoint`.
logger : list[Logger]
A list of loggers (e.g., WandbLogger) to record the re-run metrics.
"""
for callback in callbacks:
if isinstance(callback, ModelCheckpoint):
log.info(
f"Attempting to load weights from the provided ckpt_path: {ckpt_path}"
f"Loading best model from checkpoint at {callback.best_model_path}"
)
try:
trainer.test(
model=model, datamodule=datamodule, ckpt_path=ckpt_path
)
test_best_model_path = False # do not test "best model" if a valid ckpt_path is provided
except FileNotFoundError:
log.warning(
f"No checkpoint file found at the provided ckpt_path: {ckpt_path}."
)
log.info("Trying with best model instead...")
if test_best_model_path:
ckpt_path = trainer.checkpoint_callback.best_model_path
if ckpt_path == "":
log.warning(
"Best ckpt not found! Using current weights for testing..."
)
ckpt_path = None
trainer.test(
model=model, datamodule=datamodule, ckpt_path=ckpt_path
model_path = Path(callback.best_model_path)
ckpt = torch.load(
model_path, map_location="cpu", weights_only=False
)

test_metrics = trainer.callback_metrics
checkpoint_model.load_state_dict(ckpt["state_dict"], strict=True)
checkpoint_model.to(device)
break # there is only one checkpoint callback

# Merge train and test metrics
metric_dict = {**train_metrics, **test_metrics}
# New trainer to log final metrics on validation set
# Because wandb displays validation metrics from the final, not the best epoch.
checkpoint_trainer: Trainer = hydra.utils.instantiate(
cfg.trainer,
num_sanity_val_steps=0,
enable_progress_bar=cfg.get("enable_progress_bar", True),
logger=False,
)

return metric_dict, object_dict
log.info("Re-testing best model checkpoint on validation set!")
val_loader = datamodule.val_dataloader()
results = checkpoint_trainer.test(
model=checkpoint_model, dataloaders=val_loader
)
if results:
logged = {}
for k, v in results[0].items():
suffix = k.split("/", 1)[1] if "/" in k else k
logged[f"val_best_rerun/{suffix}"] = v
log.info(logged)
for lgr in logger:
if isinstance(lgr, WandbLogger):
lgr.log_metrics(logged)

log.info("Re-testing best model checkpoint on test set!")
test_loader = datamodule.test_dataloader()
results = checkpoint_trainer.test(
model=checkpoint_model, dataloaders=test_loader
)
if results:
logged = {}
for k, v in results[0].items():
suffix = k.split("/", 1)[1] if "/" in k else k
logged[f"test_best_rerun/{suffix}"] = v
log.info(logged)
for lgr in logger:
if isinstance(lgr, WandbLogger):
lgr.log_metrics(logged)


def count_number_of_parameters(
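Note on the fresh `checkpoint_trainer` in the diff: it is built with `hydra.utils.instantiate`, where keyword arguments override the corresponding fields of the composed `cfg.trainer` (here forcing `logger=False` so the re-run does not pollute the W&B history). A minimal sketch of that override behavior, using a hypothetical `trainer_cfg` in place of TopoBench's composed Hydra config:

import hydra
from omegaconf import OmegaConf

# Hypothetical stand-in for cfg.trainer; in TopoBench the config is composed by Hydra.
trainer_cfg = OmegaConf.create(
    {"_target_": "lightning.Trainer", "max_epochs": 5, "logger": True}
)

# Keyword arguments passed to instantiate override config values, so this builds
# lightning.Trainer(max_epochs=5, num_sanity_val_steps=0, logger=False).
trainer = hydra.utils.instantiate(
    trainer_cfg, num_sanity_val_steps=0, logger=False
)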
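Taken together, `rerun_best_model_checkpoint` follows a simple pattern: find the `ModelCheckpoint` callback, reload the best weights, re-evaluate the validation and test dataloaders, and re-log the metrics under `val_best_rerun/` and `test_best_rerun/` prefixes so W&B shows the best epoch rather than the last one. The sketch below illustrates that pattern outside TopoBench; the helper name `rerun_split` and its signature are hypothetical, it assumes an already fitted Lightning `Trainer`, model, and datamodule, and for brevity it reuses the fitted trainer instead of instantiating a fresh one as the PR does.

from pathlib import Path

import torch
from lightning import LightningDataModule, LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger


def rerun_split(
    trainer: Trainer,
    model: LightningModule,
    datamodule: LightningDataModule,
    split_name: str,
    wandb_logger: WandbLogger | None = None,
) -> dict[str, float]:
    """Re-evaluate the best checkpoint on one split and re-prefix its metrics.

    Hypothetical helper, not part of TopoBench; it only illustrates the
    pattern used by rerun_best_model_checkpoint in the diff above.
    """
    # Locate the (single) ModelCheckpoint callback to find the best weights.
    best_path = ""
    for callback in trainer.callbacks:
        if isinstance(callback, ModelCheckpoint):
            best_path = callback.best_model_path
            break

    if best_path:
        # Load the checkpoint on CPU and copy its weights into the model.
        ckpt = torch.load(Path(best_path), map_location="cpu", weights_only=False)
        model.load_state_dict(ckpt["state_dict"], strict=True)

    # Pick the dataloader for the requested split.
    loader = (
        datamodule.val_dataloader()
        if split_name == "val"
        else datamodule.test_dataloader()
    )

    # Trainer.test returns one metrics dict per dataloader.
    results = trainer.test(model=model, dataloaders=loader)
    if not results:
        return {}

    # Re-prefix metrics, e.g. "test/accuracy" -> "val_best_rerun/accuracy".
    relogged = {}
    for key, value in results[0].items():
        suffix = key.split("/", 1)[1] if "/" in key else key
        relogged[f"{split_name}_best_rerun/{suffix}"] = value

    if wandb_logger is not None:
        wandb_logger.log_metrics(relogged)
    return relogged

A caller would invoke this once per split, e.g. rerun_split(trainer, model, datamodule, "val", wandb_logger) and then again with "test", mirroring the two re-test passes in the function above.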