diff --git a/.gitignore b/.gitignore index 0d20b64..a86cb60 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,7 @@ *.pyc + +# checkpoint directories +checkpoint/* + +# data +data/* \ No newline at end of file diff --git a/geort/dataset.py b/geort/dataset.py index 2166387..25ea4af 100644 --- a/geort/dataset.py +++ b/geort/dataset.py @@ -47,7 +47,7 @@ def __init__(self, qpos_keypoint_file, keypoint_names): self.qpos = np_array["qpos"] self.keypoints = np_array["keypoint"].item() self.keypoint_names = keypoint_names - print("Keypoint Names", self.keypoint_names) + # print("Keypoint Names", self.keypoint_names) # Commented out to reduce verbose output self.n = len(self.qpos) return diff --git a/geort/export.py b/geort/export.py index 6b016c0..004374b 100644 --- a/geort/export.py +++ b/geort/export.py @@ -7,6 +7,7 @@ import torch import os from pathlib import Path +import json from geort.formatter import HandFormatter from geort.model import IKModel from geort.utils.path import to_package_root, get_checkpoint_root @@ -21,7 +22,7 @@ def __init__(self, model_path, config_path): config = load_json(config_path) keypoint_info = parse_config_keypoint_info(config) joint_lower_limit, joint_upper_limit = parse_config_joint_limit(config) - print(keypoint_info["joint"]) + # print(keypoint_info["joint"]) # Commented out to reduce verbose output self.human_ids = keypoint_info["human_id"] self.model = IKModel(keypoint_joints=keypoint_info["joint"]).cuda() self.model.load_state_dict(torch.load(model_path)) @@ -36,9 +37,14 @@ def forward(self, keypoints): return joint_raw[0] -def load_model(tag='', epoch=0): +def load_model(tag='', epoch=0, use_best=True): ''' - Loading API. + Loading API with best model preference. + + Args: + tag: checkpoint tag to search for + epoch: specific epoch to load (0 means latest) + use_best: if True, prefer best.pth over last.pth when epoch=0 ''' checkpoint_root = get_checkpoint_root() all_checkpoints = os.listdir(checkpoint_root) @@ -50,10 +56,21 @@ def load_model(tag='', epoch=0): break checkpoint_root = Path(checkpoint_root) / checkpoint_name + if epoch > 0: model_path = checkpoint_root / f"epoch_{epoch}.pth" else: - model_path = checkpoint_root / f"last.pth" + # Try to load best model first if use_best is True + if use_best and (checkpoint_root / "best.pth").exists(): + model_path = checkpoint_root / "best.pth" + # Load training history to show best model info + history_path = checkpoint_root / "training_history.json" + if history_path.exists(): + with open(history_path, 'r') as f: + history = json.load(f) + print(f"Loading best model from epoch {history['best_epoch']} with validation loss: {history['best_val_loss']:.4f}") + else: + model_path = checkpoint_root / "last.pth" config_path = checkpoint_root / "config.json" return GeoRTRetargetingModel(model_path=model_path, config_path=config_path)