From a0a94a703ad51acb3d764aae0976128c9b0004c1 Mon Sep 17 00:00:00 2001
From: Gray
Date: Mon, 21 Jul 2025 17:07:37 +0800
Subject: [PATCH] Improves model loading with best checkpoint preference

Enhances the model loading API to prioritize best-performing checkpoints
over latest ones when available. Displays validation metrics for
transparency when loading the best model.

Reduces console verbosity by commenting out debug print statements that
were cluttering output during normal operation.

Updates gitignore to exclude checkpoint and data directories from
version control.
---
 .gitignore       |  6 ++++++
 geort/dataset.py |  2 +-
 geort/export.py  | 25 +++++++++++++++++++++----
 3 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/.gitignore b/.gitignore
index 0d20b64..a86cb60 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,7 @@
 *.pyc
+
+# checkpoint directories
+checkpoint/*
+
+# data
+data/*
\ No newline at end of file
diff --git a/geort/dataset.py b/geort/dataset.py
index 2166387..25ea4af 100644
--- a/geort/dataset.py
+++ b/geort/dataset.py
@@ -47,7 +47,7 @@ def __init__(self, qpos_keypoint_file, keypoint_names):
         self.qpos = np_array["qpos"]
         self.keypoints = np_array["keypoint"].item()
         self.keypoint_names = keypoint_names
-        print("Keypoint Names", self.keypoint_names)
+        # print("Keypoint Names", self.keypoint_names) # Commented out to reduce verbose output
         self.n = len(self.qpos)
         return
 
diff --git a/geort/export.py b/geort/export.py
index 6b016c0..004374b 100644
--- a/geort/export.py
+++ b/geort/export.py
@@ -7,6 +7,7 @@
 import torch
 import os
 from pathlib import Path
+import json
 from geort.formatter import HandFormatter
 from geort.model import IKModel
 from geort.utils.path import to_package_root, get_checkpoint_root
@@ -21,7 +22,7 @@ def __init__(self, model_path, config_path):
         config = load_json(config_path)
         keypoint_info = parse_config_keypoint_info(config)
         joint_lower_limit, joint_upper_limit = parse_config_joint_limit(config)
-        print(keypoint_info["joint"])
+        # print(keypoint_info["joint"]) # Commented out to reduce verbose output
         self.human_ids = keypoint_info["human_id"]
         self.model = IKModel(keypoint_joints=keypoint_info["joint"]).cuda()
         self.model.load_state_dict(torch.load(model_path))
@@ -36,9 +37,14 @@ def forward(self, keypoints):
         return joint_raw[0]
 
 
-def load_model(tag='', epoch=0):
+def load_model(tag='', epoch=0, use_best=True):
     '''
-    Loading API.
+    Loading API with best model preference.
+
+    Args:
+        tag: checkpoint tag to search for
+        epoch: specific epoch to load (0 means latest)
+        use_best: if True, prefer best.pth over last.pth when epoch=0
     '''
     checkpoint_root = get_checkpoint_root()
     all_checkpoints = os.listdir(checkpoint_root)
@@ -50,10 +56,21 @@ def load_model(tag='', epoch=0):
             break
 
     checkpoint_root = Path(checkpoint_root) / checkpoint_name
+
     if epoch > 0:
         model_path = checkpoint_root / f"epoch_{epoch}.pth"
     else:
-        model_path = checkpoint_root / f"last.pth"
+        # Try to load best model first if use_best is True
+        if use_best and (checkpoint_root / "best.pth").exists():
+            model_path = checkpoint_root / "best.pth"
+            # Load training history to show best model info
+            history_path = checkpoint_root / "training_history.json"
+            if history_path.exists():
+                with open(history_path, 'r') as f:
+                    history = json.load(f)
+                print(f"Loading best model from epoch {history['best_epoch']} with validation loss: {history['best_val_loss']:.4f}")
+        else:
+            model_path = checkpoint_root / "last.pth"
 
     config_path = checkpoint_root / "config.json"
     return GeoRTRetargetingModel(model_path=model_path, config_path=config_path)