From 05d2a72f9535c598ad7f1634e3f1206f7b039941 Mon Sep 17 00:00:00 2001 From: HUAI GUOWEI <54146923+gray-wei@users.noreply.github.com> Date: Mon, 21 Jul 2025 19:08:48 +0800 Subject: [PATCH] Revert "Update model loading and add new features" --- .gitignore | 6 ------ geort/dataset.py | 2 +- geort/export.py | 25 ++++--------------------- 3 files changed, 5 insertions(+), 28 deletions(-) diff --git a/.gitignore b/.gitignore index a86cb60..0d20b64 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1 @@ *.pyc - -# checkpoint directories -checkpoint/* - -# data -data/* \ No newline at end of file diff --git a/geort/dataset.py b/geort/dataset.py index 25ea4af..2166387 100644 --- a/geort/dataset.py +++ b/geort/dataset.py @@ -47,7 +47,7 @@ def __init__(self, qpos_keypoint_file, keypoint_names): self.qpos = np_array["qpos"] self.keypoints = np_array["keypoint"].item() self.keypoint_names = keypoint_names - # print("Keypoint Names", self.keypoint_names) # Commented out to reduce verbose output + print("Keypoint Names", self.keypoint_names) self.n = len(self.qpos) return diff --git a/geort/export.py b/geort/export.py index 004374b..6b016c0 100644 --- a/geort/export.py +++ b/geort/export.py @@ -7,7 +7,6 @@ import torch import os from pathlib import Path -import json from geort.formatter import HandFormatter from geort.model import IKModel from geort.utils.path import to_package_root, get_checkpoint_root @@ -22,7 +21,7 @@ def __init__(self, model_path, config_path): config = load_json(config_path) keypoint_info = parse_config_keypoint_info(config) joint_lower_limit, joint_upper_limit = parse_config_joint_limit(config) - # print(keypoint_info["joint"]) # Commented out to reduce verbose output + print(keypoint_info["joint"]) self.human_ids = keypoint_info["human_id"] self.model = IKModel(keypoint_joints=keypoint_info["joint"]).cuda() self.model.load_state_dict(torch.load(model_path)) @@ -37,14 +36,9 @@ def forward(self, keypoints): return joint_raw[0] -def load_model(tag='', epoch=0, use_best=True): +def load_model(tag='', epoch=0): ''' - Loading API with best model preference. - - Args: - tag: checkpoint tag to search for - epoch: specific epoch to load (0 means latest) - use_best: if True, prefer best.pth over last.pth when epoch=0 + Loading API. ''' checkpoint_root = get_checkpoint_root() all_checkpoints = os.listdir(checkpoint_root) @@ -56,21 +50,10 @@ def load_model(tag='', epoch=0, use_best=True): break checkpoint_root = Path(checkpoint_root) / checkpoint_name - if epoch > 0: model_path = checkpoint_root / f"epoch_{epoch}.pth" else: - # Try to load best model first if use_best is True - if use_best and (checkpoint_root / "best.pth").exists(): - model_path = checkpoint_root / "best.pth" - # Load training history to show best model info - history_path = checkpoint_root / "training_history.json" - if history_path.exists(): - with open(history_path, 'r') as f: - history = json.load(f) - print(f"Loading best model from epoch {history['best_epoch']} with validation loss: {history['best_val_loss']:.4f}") - else: - model_path = checkpoint_root / "last.pth" + model_path = checkpoint_root / f"last.pth" config_path = checkpoint_root / "config.json" return GeoRTRetargetingModel(model_path=model_path, config_path=config_path)