From b33339d7b2947eb564716aa34be687109e3856f4 Mon Sep 17 00:00:00 2001 From: woodenbirds <1979309725@qq.com> Date: Thu, 24 Oct 2024 07:56:03 -0700 Subject: [PATCH 1/8] use aria hand tracking as the detector --- 2_run_hamer_on_vrs.py | 236 +++++++++++++++++++++++++++++++++++++++++- aria_utils.py | 211 +++++++++++++++++++++++++++++++++++++ 2 files changed, 446 insertions(+), 1 deletion(-) create mode 100755 aria_utils.py diff --git a/2_run_hamer_on_vrs.py b/2_run_hamer_on_vrs.py index fc0b2e3..afbc721 100644 --- a/2_run_hamer_on_vrs.py +++ b/2_run_hamer_on_vrs.py @@ -18,9 +18,12 @@ VrsDataProvider, create_vrs_data_provider, ) +import projectaria_tools.core.mps as mps +from projectaria_tools.core.sensor_data import DEVICE_TIME from tqdm.auto import tqdm from egoallo.inference_utils import InferenceTrajectoryPaths +from aria_utils import per_image_hand_tracking,get_online_calib,x_y_around def main(traj_root: Path, overwrite: bool = False) -> None: @@ -39,7 +42,10 @@ def main(traj_root: Path, overwrite: bool = False) -> None: assert vrs_path.exists() pickle_out = traj_root / "hamer_outputs.pkl" hamer_render_out = traj_root / "hamer_outputs_render" # This is just for debugging. - run_hamer_and_save(vrs_path, pickle_out, hamer_render_out, overwrite) + wrist_and_palm_poses_path = traj_root / "hand_tracking/wrist_and_palm_poses.csv" + online_path = traj_root / "slam/online_calibration.jsonl" + # run_hamer_and_save(vrs_path, pickle_out, hamer_render_out, overwrite) + run_aria_hamer_and_save(vrs_path, pickle_out, hamer_render_out, wrist_and_palm_poses_path, online_path, overwrite) def run_hamer_and_save( @@ -200,6 +206,234 @@ def run_hamer_and_save( pickle.dump(outputs, f) +def run_aria_hamer_and_save( + vrs_path: Path, pickle_out: Path, hamer_render_out: Path, wrist_and_palm_poses_path: Path, online_calib_path: Path, overwrite: bool +) -> None: + if not overwrite: + assert not pickle_out.exists() + assert not hamer_render_out.exists() + else: + pickle_out.unlink(missing_ok=True) + shutil.rmtree(hamer_render_out, ignore_errors=True) + + hamer_render_out.mkdir(exist_ok=True) + hamer_helper = HamerHelper() + + # VRS data provider setup. + provider = create_vrs_data_provider(str(vrs_path.absolute())) + assert isinstance(provider, VrsDataProvider) + rgb_stream_id = provider.get_stream_id_from_label("camera-rgb") + assert rgb_stream_id is not None + + num_images = provider.get_num_data(rgb_stream_id) + print(f"Found {num_images=}") + + # Get calibrations. + device_calib = provider.get_device_calibration() + assert device_calib is not None + camera_calib = device_calib.get_camera_calib("camera-rgb") + assert camera_calib is not None + pinhole = calibration.get_linear_camera_calibration(1408, 1408, 450) + + # Compute camera extrinsics! + sophus_T_device_camera = device_calib.get_transform_device_sensor("camera-rgb") + sophus_T_cpf_camera = device_calib.get_transform_cpf_sensor("camera-rgb") + assert sophus_T_device_camera is not None + assert sophus_T_cpf_camera is not None + T_device_cam = np.concatenate( + [ + sophus_T_device_camera.rotation().to_quat().squeeze(axis=0), + sophus_T_device_camera.translation().squeeze(axis=0), + ] + ) + T_cpf_cam = np.concatenate( + [ + sophus_T_cpf_camera.rotation().to_quat().squeeze(axis=0), + sophus_T_cpf_camera.translation().squeeze(axis=0), + ] + ) + assert T_device_cam.shape == T_cpf_cam.shape == (7,) + + # Dict from capture timestamp in nanoseconds to fields we care about. + detections_left_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None] = {} + detections_right_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None] = {} + + wrist_and_palm_poses_path = str(wrist_and_palm_poses_path) + online_calib_path = str(online_calib_path) + + rgb_calib = get_online_calib(online_calib_path, "camera-rgb") + wrist_and_palm_poses = mps.hand_tracking.read_wrist_and_palm_poses(wrist_and_palm_poses_path) + + pbar = tqdm(range(num_images)) + + l_point_queue=[] + r_point_queue=[] + queue_length=5 + + for i in pbar: + image_data, image_data_record = provider.get_image_data_by_index( + rgb_stream_id, i + ) + undistorted_image = calibration.distort_by_calibration( + image_data.to_numpy_array(), pinhole, camera_calib + ) + + timestamp_ns = image_data_record.capture_timestamp_ns + l_existed, r_existed, l_point, r_point = per_image_hand_tracking(timestamp_ns, wrist_and_palm_poses, pinhole, camera_calib, rgb_calib) + if l_existed: + l_box = x_y_around(l_point[0], l_point[1],pinhole,offset=80) + l_point_queue.append(l_point) + else: + l_box=None + # for index_l1 in range(len(l_point_queue)-1,-1,-1): + # if l_point_queue[index_l1] is not None: + # for index_l2 in range(index_l1-1,-1,-1): + # if l_point_queue[index_l2] is not None: + # l_point = (len(l_point_queue)-index_l1)*(l_point_queue[index_l1]-l_point_queue[index_l2])/(index_l1-index_l2)+l_point_queue[index_l1] + # l_box = x_y_around(l_point[0], l_point[1],pinhole) + # l_existed=True + # # print("use previous l:",len(l_point_queue)-index_l1,len(l_point_queue)-index_l2) + # break + # if l_existed: + # break + l_point_queue.append(None) + + if r_existed: + r_box = x_y_around(r_point[0], r_point[1],pinhole,offset=80) + r_point_queue.append(r_point) + else: + r_box=None + # for index_r1 in range(len(r_point_queue)-1,-1,-1): + # if r_point_queue[index_r1] is not None: + # for index_r2 in range(index_r1-1,-1,-1): + # if r_point_queue[index_r2] is not None: + # r_point = (len(r_point_queue)-index_r1)*(r_point_queue[index_r1]-r_point_queue[index_r2])/(index_r1-index_r2)+r_point_queue[index_r1] + # r_box = x_y_around(r_point[0], r_point[1],pinhole) + # r_existed=True + # # print("use previous r:",len(r_point_queue)-index_r1,len(r_point_queue)-index_r2) + # break + # if r_existed: + # break + r_point_queue.append(None) + + if len(l_point_queue)>queue_length: + l_point_queue.pop(0) + r_point_queue.pop(0) + + hamer_out_left, hamer_out_right = hamer_helper.get_det_from_boxes( + undistorted_image, + l_existed, + r_existed, + l_box, + r_box, + focal_length=450, + ) + + if hamer_out_left is None: + detections_left_wrt_cam[timestamp_ns] = None + else: + detections_left_wrt_cam[timestamp_ns] = { + "verts": hamer_out_left["verts"], + "keypoints_3d": hamer_out_left["keypoints_3d"], + "mano_hand_pose": hamer_out_left["mano_hand_pose"], + "mano_hand_betas": hamer_out_left["mano_hand_betas"], + "mano_hand_global_orient": hamer_out_left["mano_hand_global_orient"], + } + + if hamer_out_right is None: + detections_right_wrt_cam[timestamp_ns] = None + else: + detections_right_wrt_cam[timestamp_ns] = { + "verts": hamer_out_right["verts"], + "keypoints_3d": hamer_out_right["keypoints_3d"], + "mano_hand_pose": hamer_out_right["mano_hand_pose"], + "mano_hand_betas": hamer_out_right["mano_hand_betas"], + "mano_hand_global_orient": hamer_out_right["mano_hand_global_orient"], + } + + composited = undistorted_image + composited = hamer_helper.composite_detections( + composited, + hamer_out_left, + border_color=(255, 100, 100), + focal_length=450, + ) + composited = hamer_helper.composite_detections( + composited, + hamer_out_right, + border_color=(100, 100, 255), + focal_length=450, + ) + composited = put_text( + composited, + "L detections: " + + ( + "0" if hamer_out_left is None else str(hamer_out_left["verts"].shape[0]) + ), + 0, + color=(255, 100, 100), + font_scale=10.0 / 2880.0 * undistorted_image.shape[0], + ) + composited = put_text( + composited, + "R detections: " + + ( + "0" + if hamer_out_right is None + else str(hamer_out_right["verts"].shape[0]) + ), + 1, + color=(100, 100, 255), + font_scale=10.0 / 2880.0 * undistorted_image.shape[0], + ) + composited = put_text( + composited, + f"ns={timestamp_ns}", + 2, + color=(255, 255, 255), + font_scale=10.0 / 2880.0 * undistorted_image.shape[0], + ) + + print(f"Saving image {i:06d} to {hamer_render_out / f'{i:06d}.jpeg'}") + # bbox on undistorted image + if l_existed: + min_l_p_x, min_l_p_y, max_l_p_x, max_l_p_y = l_box + max_l_p_x, min_l_p_x, max_l_p_y, min_l_p_y = int(max_l_p_x), int(min_l_p_x), int(max_l_p_y), int(min_l_p_y) + + cv2.rectangle(composited, (max_l_p_x, max_l_p_y), (min_l_p_x, min_l_p_y), (255, 100, 100),2) + if r_existed: + min_r_p_x, min_r_p_y, max_r_p_x, max_r_p_y = r_box + max_r_p_x, min_r_p_x, max_r_p_y, min_r_p_y = int(max_r_p_x), int(min_r_p_x), int(max_r_p_y), int(min_r_p_y) + + cv2.rectangle(composited, (max_r_p_x, max_r_p_y), (min_r_p_x, min_r_p_y), (100, 100, 255),2) + + + iio.imwrite( + str(hamer_render_out / f"{i:06d}.jpeg"), + np.concatenate( + [ + # Darken input image, just for contrast... + (undistorted_image * 0.6).astype(np.uint8), + composited, + ], + axis=1, + ), + quality=90, + ) + + outputs = SavedHamerOutputs( + mano_faces_right=hamer_helper.get_mano_faces("right"), + mano_faces_left=hamer_helper.get_mano_faces("left"), + detections_right_wrt_cam=detections_right_wrt_cam, + detections_left_wrt_cam=detections_left_wrt_cam, + T_device_cam=T_device_cam, + T_cpf_cam=T_cpf_cam, + ) + with open(pickle_out, "wb") as f: + pickle.dump(outputs, f) + + + def put_text( image: np.ndarray, text: str, diff --git a/aria_utils.py b/aria_utils.py new file mode 100755 index 0000000..50f827d --- /dev/null +++ b/aria_utils.py @@ -0,0 +1,211 @@ +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import projectaria_tools.core.mps as mps +from projectaria_tools.core import data_provider,calibration +# from projectaria_tools.core.mps.utils import get_nearest_pose +from projectaria_tools.core.stream_id import StreamId +from projectaria_tools.core.sensor_data import DEVICE_TIME,CLOSEST +from projectaria_tools.core.mps.utils import get_nearest_wrist_and_palm_pose +import os + +def x_y_rot90(x, y, w, h): + return h-y, x + +def x_y_undistort(x, y, w, h, pinhole, calib): + x, y = int(x), int(y) + # a zero numpy array of shape (w, h) with coordinate (x,y) having value 1 + point = np.zeros((w, h)) + offset = 2 + for i in range(x-offset, x+offset+1): + for j in range(y-offset, y+offset+1): + if i>=0 and i=0 and j Date: Sun, 27 Oct 2024 08:01:01 -0700 Subject: [PATCH 2/8] aria_detector --- 2_run_hamer_on_vrs.py | 172 ++++++++- _hamer_helper.py | 830 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 999 insertions(+), 3 deletions(-) create mode 100644 _hamer_helper.py diff --git a/2_run_hamer_on_vrs.py b/2_run_hamer_on_vrs.py index afbc721..0c935f2 100644 --- a/2_run_hamer_on_vrs.py +++ b/2_run_hamer_on_vrs.py @@ -12,7 +12,7 @@ SavedHamerOutputs, SingleHandHamerOutputWrtCamera, ) -from hamer_helper import HamerHelper +from _hamer_helper import HamerHelper from projectaria_tools.core import calibration from projectaria_tools.core.data_provider import ( VrsDataProvider, @@ -26,13 +26,14 @@ from aria_utils import per_image_hand_tracking,get_online_calib,x_y_around -def main(traj_root: Path, overwrite: bool = False) -> None: +def main(traj_root: Path, detector: str = "hamer",overwrite: bool = False) -> None: """Run HaMeR for on trajectory. We'll save outputs to `traj_root/hamer_outputs.pkl` and `traj_root/hamer_outputs_render". Arguments: traj_root: The root directory of the trajectory. We assume that there's a VRS file in this directory. + detector: The detector to use. Can be "WiLoR", "aria", or "hamer". overwrite: If True, overwrite any existing HaMeR outputs. """ @@ -45,7 +46,14 @@ def main(traj_root: Path, overwrite: bool = False) -> None: wrist_and_palm_poses_path = traj_root / "hand_tracking/wrist_and_palm_poses.csv" online_path = traj_root / "slam/online_calibration.jsonl" # run_hamer_and_save(vrs_path, pickle_out, hamer_render_out, overwrite) - run_aria_hamer_and_save(vrs_path, pickle_out, hamer_render_out, wrist_and_palm_poses_path, online_path, overwrite) + if detector == "WiLoR": + run_wilor_and_save(vrs_path, pickle_out, hamer_render_out, overwrite) + elif detector == "aria": + run_aria_hamer_and_save(vrs_path, pickle_out, hamer_render_out, wrist_and_palm_poses_path, online_path, overwrite) + elif detector == "hamer": + run_hamer_and_save(vrs_path, pickle_out, hamer_render_out, overwrite) + else: + raise ValueError(f"Unknown detector: {detector}") def run_hamer_and_save( @@ -205,6 +213,164 @@ def run_hamer_and_save( with open(pickle_out, "wb") as f: pickle.dump(outputs, f) +def run_wilor_and_save( + vrs_path: Path, pickle_out: Path, hamer_render_out: Path, overwrite: bool +) -> None: + raise NotImplementedError("WiLoR is not implemented yet.") + if not overwrite: + assert not pickle_out.exists() + assert not hamer_render_out.exists() + else: + pickle_out.unlink(missing_ok=True) + shutil.rmtree(hamer_render_out, ignore_errors=True) + + hamer_render_out.mkdir(exist_ok=True) + hamer_helper = HamerHelper() + + # VRS data provider setup. + provider = create_vrs_data_provider(str(vrs_path.absolute())) + assert isinstance(provider, VrsDataProvider) + rgb_stream_id = provider.get_stream_id_from_label("camera-rgb") + assert rgb_stream_id is not None + + num_images = provider.get_num_data(rgb_stream_id) + print(f"Found {num_images=}") + + # Get calibrations. + device_calib = provider.get_device_calibration() + assert device_calib is not None + camera_calib = device_calib.get_camera_calib("camera-rgb") + assert camera_calib is not None + pinhole = calibration.get_linear_camera_calibration(1408, 1408, 450) + + # Compute camera extrinsics! + sophus_T_device_camera = device_calib.get_transform_device_sensor("camera-rgb") + sophus_T_cpf_camera = device_calib.get_transform_cpf_sensor("camera-rgb") + assert sophus_T_device_camera is not None + assert sophus_T_cpf_camera is not None + T_device_cam = np.concatenate( + [ + sophus_T_device_camera.rotation().to_quat().squeeze(axis=0), + sophus_T_device_camera.translation().squeeze(axis=0), + ] + ) + T_cpf_cam = np.concatenate( + [ + sophus_T_cpf_camera.rotation().to_quat().squeeze(axis=0), + sophus_T_cpf_camera.translation().squeeze(axis=0), + ] + ) + assert T_device_cam.shape == T_cpf_cam.shape == (7,) + + # Dict from capture timestamp in nanoseconds to fields we care about. + detections_left_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None] = {} + detections_right_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None] = {} + + pbar = tqdm(range(num_images)) + for i in pbar: + image_data, image_data_record = provider.get_image_data_by_index( + rgb_stream_id, i + ) + undistorted_image = calibration.distort_by_calibration( + image_data.to_numpy_array(), pinhole, camera_calib + ) + + hamer_out_left, hamer_out_right = hamer_helper.look_for_hands( + undistorted_image, + focal_length=450, + ) + timestamp_ns = image_data_record.capture_timestamp_ns + + if hamer_out_left is None: + detections_left_wrt_cam[timestamp_ns] = None + else: + detections_left_wrt_cam[timestamp_ns] = { + "verts": hamer_out_left["verts"], + "keypoints_3d": hamer_out_left["keypoints_3d"], + "mano_hand_pose": hamer_out_left["mano_hand_pose"], + "mano_hand_betas": hamer_out_left["mano_hand_betas"], + "mano_hand_global_orient": hamer_out_left["mano_hand_global_orient"], + } + + if hamer_out_right is None: + detections_right_wrt_cam[timestamp_ns] = None + else: + detections_right_wrt_cam[timestamp_ns] = { + "verts": hamer_out_right["verts"], + "keypoints_3d": hamer_out_right["keypoints_3d"], + "mano_hand_pose": hamer_out_right["mano_hand_pose"], + "mano_hand_betas": hamer_out_right["mano_hand_betas"], + "mano_hand_global_orient": hamer_out_right["mano_hand_global_orient"], + } + + composited = undistorted_image + composited = hamer_helper.composite_detections( + composited, + hamer_out_left, + border_color=(255, 100, 100), + focal_length=450, + ) + composited = hamer_helper.composite_detections( + composited, + hamer_out_right, + border_color=(100, 100, 255), + focal_length=450, + ) + composited = put_text( + composited, + "L detections: " + + ( + "0" if hamer_out_left is None else str(hamer_out_left["verts"].shape[0]) + ), + 0, + color=(255, 100, 100), + font_scale=10.0 / 2880.0 * undistorted_image.shape[0], + ) + composited = put_text( + composited, + "R detections: " + + ( + "0" + if hamer_out_right is None + else str(hamer_out_right["verts"].shape[0]) + ), + 1, + color=(100, 100, 255), + font_scale=10.0 / 2880.0 * undistorted_image.shape[0], + ) + composited = put_text( + composited, + f"ns={timestamp_ns}", + 2, + color=(255, 255, 255), + font_scale=10.0 / 2880.0 * undistorted_image.shape[0], + ) + + print(f"Saving image {i:06d} to {hamer_render_out / f'{i:06d}.jpeg'}") + iio.imwrite( + str(hamer_render_out / f"{i:06d}.jpeg"), + np.concatenate( + [ + # Darken input image, just for contrast... + (undistorted_image * 0.6).astype(np.uint8), + composited, + ], + axis=1, + ), + quality=90, + ) + + outputs = SavedHamerOutputs( + mano_faces_right=hamer_helper.get_mano_faces("right"), + mano_faces_left=hamer_helper.get_mano_faces("left"), + detections_right_wrt_cam=detections_right_wrt_cam, + detections_left_wrt_cam=detections_left_wrt_cam, + T_device_cam=T_device_cam, + T_cpf_cam=T_cpf_cam, + ) + with open(pickle_out, "wb") as f: + pickle.dump(outputs, f) + def run_aria_hamer_and_save( vrs_path: Path, pickle_out: Path, hamer_render_out: Path, wrist_and_palm_poses_path: Path, online_calib_path: Path, overwrite: bool diff --git a/_hamer_helper.py b/_hamer_helper.py new file mode 100644 index 0000000..3fcea7b --- /dev/null +++ b/_hamer_helper.py @@ -0,0 +1,830 @@ +import contextlib +import os +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Generator, Literal, TypedDict + +import imageio.v3 as iio +import numpy as np +import torch +import torch.utils.data +from hamer.datasets.vitdet_dataset import DEFAULT_MEAN, DEFAULT_STD, ViTDetDataset +from hamer.utils.mesh_renderer import create_raymond_lights +from hamer.utils.renderer import Renderer, cam_crop_to_full +from jaxtyping import Float, Int +from scipy.ndimage import binary_dilation +from torch import Tensor + + +class HandOutputsWrtCamera(TypedDict): + """Hand outputs with respect to the camera frame.""" + + verts: Float[np.ndarray, "num_hands 778 3"] + keypoints_3d: Float[np.ndarray, "num_hands 21 3"] + mano_hand_pose: Float[np.ndarray, "num_hands 15 3 3"] + mano_hand_betas: Float[np.ndarray, "num_hands 10"] + mano_hand_global_orient: Float[np.ndarray, "num_hands 1 3 3"] + faces: Int[np.ndarray, "mesh_faces 3"] + + +@contextlib.contextmanager +def _stopwatch(message: str): + print("[STOPWATCH]", message) + start = time.time() + yield + print("[STOPWATCH]", message, f"finished in {time.time() - start} seconds!") + + +@dataclass(frozen=True) +class _RawHamerOutputs: + """A typed wrapper for outputs from HaMeR.""" + + # Comments here are what I got when printing out the shapes of different + # HaMeR outputs. + + # pred_cam torch.Size([1, 3]) + pred_cam: Float[Tensor, "num_hands 3"] + # pred_mano_params global_orient torch.Size([1, 1, 3, 3]) + pred_mano_global_orient: Float[Tensor, "num_hands 1 3 3"] + # pred_mano_params hand_pose torch.Size([1, 15, 3, 3]) + pred_mano_hand_pose: Float[Tensor, "num_hands 15 3 3"] + # pred_mano_params betas torch.Size([1, 10]) + pred_mano_hand_betas: Float[Tensor, "num_hands 10"] + # pred_cam_t torch.Size([1, 3]) + pred_cam_t: Float[Tensor, "num_hands 3"] + + # focal length from model is ignored + # focal_length torch.Size([1, 2]) + # focal_length: Float[Tensor, "num_hands 2"] + + # pred_keypoints_3d torch.Size([1, 21, 3]) + pred_keypoints_3d: Float[Tensor, "num_hands 21 3"] + # pred_vertices torch.Size([1, 778, 3]) + pred_vertices: Float[Tensor, "num_hands 778 3"] + # pred_keypoints_2d torch.Size([1, 21, 2]) + pred_keypoints_2d: Float[Tensor, "num_hands 21 3"] + + pred_right: Float[Tensor, "num_hands"] + """A given hand is a right hand if this value is >0.5.""" + + # These aren't technically HaMeR outputs, but putting them here for convenience. + mano_faces_right: Tensor + mano_faces_left: Tensor + + +@contextlib.contextmanager +def temporary_cwd_context(x: Path) -> Generator[None, None, None]: + """Temporarily change our working directory.""" + d = os.getcwd() + os.chdir(x) + try: + yield + finally: + os.chdir(d) + + +class HamerHelper: + """Helper class for running HaMeR. Adapted from HaMeR demo script.""" + + def __init__(self) -> None: + import hamer + from hamer.models import DEFAULT_CHECKPOINT, load_hamer + from vitpose_model import ViTPoseModel + + # HaMeR hardcodes a bunch of relative paths... + # Instead of modifying HaMeR we're going to hack this by temporarily changing our working directory :) + hamer_directory = Path(hamer.__file__).parent.parent + + with temporary_cwd_context(hamer_directory): + # Download and load checkpoints + # download_models(Path(hamer.__file__).parent.parent /CACHE_DIR_HAMER) + with _stopwatch("Loading HaMeR model..."): + model, model_cfg = load_hamer( + str(Path(hamer.__file__).parent.parent / DEFAULT_CHECKPOINT) + ) + + # Setup HaMeR model + with _stopwatch("Configuring HaMeR model..."): + device = torch.device("cuda") + model = model.to(device) + model.eval() + + # Load detector + import hamer + # from detectron2.config import LazyConfig + # from hamer.utils.utils_detectron2 import DefaultPredictor_Lazy + + # with _stopwatch("Creating Detectron2 predictor..."): + # cfg_path = ( + # Path(hamer.__file__).parent + # / "configs" + # / "cascade_mask_rcnn_vitdet_h_75ep.py" + # ) + # detectron2_cfg = LazyConfig.load(str(cfg_path)) + # detectron2_cfg.train.init_checkpoint = "https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/cascade_mask_rcnn_vitdet_h/f328730692/model_final_f05665.pkl" # type: ignore + # for i in range(3): + # detectron2_cfg.model.roi_heads.box_predictors[ # type: ignore + # i + # ].test_score_thresh = 0.25 + # detector = DefaultPredictor_Lazy(detectron2_cfg) + + # keypoint detector + with _stopwatch("Creating ViT pose model..."): + cpm = ViTPoseModel(device) + + self._model = model + self._model_cfg = model_cfg + self._detector = None + self._cpm = cpm + self.device = device + + print("#" * 80) + print("#" * 80) + print("#" * 80) + print( + "Done setting up HaMeR! There were probably lots of errors, including a scary gigantic one about state dict stuff, but it's probably fine!" + ) + print("#" * 80) + print("#" * 80) + print("#" * 80) + + def get_default_focal_length(self, h: int, w: int) -> float: + """Get the default focal length for a given image size. + + This is how the HaMeR demo script computes the focal length... I don't + have a clear sense of the significance. We could ask George. + """ + return ( + self._model_cfg.EXTRA.FOCAL_LENGTH + / self._model_cfg.MODEL.IMAGE_SIZE + * max(h, w) + ) + + def get_det_from_boxes( + self, + image: Int[np.ndarray, "height width 3"], + ldetected: bool, + rdetected: bool, + l_box: np.ndarray, + r_box: np.ndarray, + focal_length: float | None = None, + rescale_factor: float = 2.0, + render_output_dir_for_testing: Path | None = None, + render_output_prefix_for_testing: str = "", + ) -> tuple[HandOutputsWrtCamera | None, HandOutputsWrtCamera | None]: + assert image.shape[-1] == 3 + + # image must be `np.uint8`, and in range [0, 255]. + assert image.dtype == np.uint8 + + # # Detectron expects BGR image. + # det_out = self._detector(image[:, :, ::-1]) + # det_instances = det_out["instances"] + # valid_idx = (det_instances.pred_classes == 0) & (det_instances.scores > 0.5) + # pred_bboxes = det_instances.pred_boxes.tensor[valid_idx].cpu().numpy() + # pred_scores = det_instances.scores[valid_idx].cpu().numpy() + + + + # # Detect human keypoints for each person + # vitposes_out = self._cpm.predict_pose( + # image, + # [np.concatenate([pred_bboxes, pred_scores[:, None]], axis=1)], + # ) + + bboxes = [] + is_right = [] + + # # Use hands based on hand keypoint detections + # for vitposes in vitposes_out: + # left_hand_keyp = vitposes["keypoints"][-42:-21] + # right_hand_keyp = vitposes["keypoints"][-21:] + + # lbbox = None + # rbbox = None + + # # Rejecting not confident detections + # ldetect = rdetect = False + # keyp = left_hand_keyp + # valid = keyp[:, 2] > 0.5 + # if sum(valid) > 3: + # lbbox = [ + # keyp[valid, 0].min(), + # keyp[valid, 1].min(), + # keyp[valid, 0].max(), + # keyp[valid, 1].max(), + # ] + # ldetect = True + # keyp = right_hand_keyp + # valid = keyp[:, 2] > 0.5 + # if sum(valid) > 3: + # rbbox = [ + # keyp[valid, 0].min(), + # keyp[valid, 1].min(), + # keyp[valid, 0].max(), + # keyp[valid, 1].max(), + # ] + # rdetect = True + + # # suppressing + # if ldetect == True and rdetect == True: + # bboxes_dims = [ + # left_hand_keyp[:, 0].max() - left_hand_keyp[:, 0].min(), + # left_hand_keyp[:, 1].max() - left_hand_keyp[:, 1].min(), + # right_hand_keyp[:, 0].max() - right_hand_keyp[:, 0].min(), + # right_hand_keyp[:, 1].max() - right_hand_keyp[:, 1].min(), + # ] + # norm_side = max(bboxes_dims) + # keyp_dist = ( + # np.sqrt( + # np.sum( + # (right_hand_keyp[:, :2] - left_hand_keyp[:, :2]) ** 2, + # axis=1, + # ) + # ) + # / norm_side + # ) + # if np.mean(keyp_dist) < 0.5: + # if left_hand_keyp[0, 2] - right_hand_keyp[0, 2] > 0: + # assert lbbox is not None + # bboxes.append(lbbox) + # is_right.append(0) + # else: + # assert rbbox is not None + # bboxes.append(rbbox) + # is_right.append(1) + # else: + # assert lbbox is not None + # assert rbbox is not None + # bboxes.append(lbbox) + # is_right.append(0) + # bboxes.append(rbbox) + # is_right.append(1) + # elif ldetect == True: + # assert lbbox is not None + # bboxes.append(lbbox) + # is_right.append(0) + # elif rdetect == True: + # assert rbbox is not None + # bboxes.append(rbbox) + # is_right.append(1) + + # if len(bboxes) == 0: + # return None, None + + # boxes = np.stack(bboxes) + # right = np.stack(is_right) + + if ldetected: + bboxes.append(l_box) + is_right.append(0) + if rdetected: + bboxes.append(r_box) + is_right.append(1) + if len(bboxes) == 0: + return None, None + boxes = np.stack(bboxes) + right = np.stack(is_right) + + dataset = ViTDetDataset( + self._model_cfg, + # HaMeR expects BGR. + image[:, :, ::-1], + boxes, + right, + rescale_factor=rescale_factor, + ) + + # ViT detector will give us multiple detections. We want to run HaMeR + # on each. + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=8, shuffle=False, num_workers=0 + ) + outputs: list[_RawHamerOutputs] = [] + from hamer.utils import recursive_to + + for batch in dataloader: + batch: Any = recursive_to(batch, self.device) + with torch.no_grad(): + out = self._model.forward(batch) + + multiplier = 2 * batch["right"] - 1 + pred_cam = out["pred_cam"] + pred_cam[:, 1] = multiplier * pred_cam[:, 1] + box_center = batch["box_center"].float() + box_size = batch["box_size"].float() + img_size = batch["img_size"].float() + multiplier = 2 * batch["right"] - 1 + + if focal_length is None: + # All of the img_size rows should be the same. I think. + focal_length = float( + self.get_default_focal_length( + img_size[0, 0].item(), img_size[0, 1].item() + ) + ) + if isinstance(focal_length, int): + focal_length = float(focal_length) + assert isinstance(focal_length, float) + scaled_focal_length = focal_length + + pred_cam_t_full = cam_crop_to_full( + pred_cam, box_center, box_size, img_size, scaled_focal_length + ) + hamer_out = _RawHamerOutputs( + mano_faces_left=torch.from_numpy( + self._model.mano.faces[:, [0, 2, 1]].astype(np.int64) + ).to(device=self.device), + mano_faces_right=torch.from_numpy( + self._model.mano.faces.astype(np.int64) + ).to(device=self.device), + pred_cam=out["pred_cam"], + pred_mano_global_orient=out["pred_mano_params"]["global_orient"], + pred_mano_hand_pose=out["pred_mano_params"]["hand_pose"], + pred_mano_hand_betas=out["pred_mano_params"]["betas"], + pred_cam_t=pred_cam_t_full, + pred_keypoints_3d=out["pred_keypoints_3d"], + pred_vertices=out["pred_vertices"], + pred_keypoints_2d=out["pred_keypoints_2d"], + pred_right=batch["right"], + ) + + outputs.append(hamer_out) + + # Render the result. + if render_output_dir_for_testing: + renderer = Renderer(self._model_cfg, faces=self._model.mano.faces) + batch_size = batch["img"].shape[0] + for n in range(batch_size): + # Get filename from path img_path + person_id = int(batch["personid"][n]) + white_img = ( + torch.ones_like(batch["img"][n]).cpu() + - DEFAULT_MEAN[:, None, None] / 255 + ) / (DEFAULT_STD[:, None, None] / 255) + input_patch = batch["img"][n].cpu() * ( + DEFAULT_STD[:, None, None] / 255 + ) + (DEFAULT_MEAN[:, None, None] / 255) + input_patch = input_patch.permute(1, 2, 0).numpy() + + LIGHT_BLUE = (0.65098039, 0.74117647, 0.85882353) + regression_img = renderer( + out["pred_vertices"][n].detach().cpu().numpy(), + out["pred_cam_t"][n].detach().cpu().numpy(), + batch["img"][n], + mesh_base_color=LIGHT_BLUE, + scene_bg_color=(1, 1, 1), + ) + + final_img = np.concatenate([input_patch, regression_img], axis=1) + + image_path = ( + render_output_dir_for_testing + / f"{render_output_prefix_for_testing}_hamer_{person_id}.png" + ) + print(f"Writing to {image_path}") + render_output_dir_for_testing.mkdir(exist_ok=True, parents=True) + iio.imwrite(image_path, (255 * final_img).astype(np.uint8)) + + # Add all verts and cams to list + verts = out["pred_vertices"][n].detach().cpu().numpy() + is_right = batch["right"][n].cpu().numpy() + verts[:, 0] = (2 * is_right - 1) * verts[:, 0] + + assert len(outputs) > 0 + stacked_outputs = _RawHamerOutputs( + **{ + field_name: torch.cat([getattr(x, field_name) for x in outputs], dim=0) + for field_name in vars(outputs[0]).keys() + }, + ) + # begin new brent stuff + verts = stacked_outputs.pred_vertices.numpy(force=True) + keypoints_3d = stacked_outputs.pred_keypoints_3d.numpy(force=True) + pred_cam_t = stacked_outputs.pred_cam_t.numpy(force=True) + mano_hand_pose = stacked_outputs.pred_mano_hand_pose.numpy(force=True) + mano_hand_betas = stacked_outputs.pred_mano_hand_betas.numpy(force=True) + R_camera_hand = stacked_outputs.pred_mano_global_orient.squeeze(dim=1).numpy( + force=True + ) + + is_right = (stacked_outputs.pred_right > 0.5).numpy(force=True) + is_left = ~is_right + + detections_right_wrt_cam: HandOutputsWrtCamera | None + if np.sum(is_right) == 0: + detections_right_wrt_cam = None + else: + detections_right_wrt_cam = { + "verts": verts[is_right] + pred_cam_t[is_right, None, :], + "keypoints_3d": keypoints_3d[is_right] + pred_cam_t[is_right, None, :], + "mano_hand_pose": mano_hand_pose[is_right], + "mano_hand_betas": mano_hand_betas[is_right], + "mano_hand_global_orient": R_camera_hand[is_right], + "faces": self.get_mano_faces("right"), + } + + detections_left_wrt_cam: HandOutputsWrtCamera | None + if np.sum(is_left) == 0: + detections_left_wrt_cam = None + else: + + def flip_rotmats(rotmats: np.ndarray) -> np.ndarray: + assert rotmats.shape[-2:] == (3, 3) + from viser import transforms + + logspace = transforms.SO3.from_matrix(rotmats).log() + logspace[..., 1] *= -1 + logspace[..., 2] *= -1 + return transforms.SO3.exp(logspace).as_matrix() + + detections_left_wrt_cam = { + "verts": verts[is_left] * np.array([-1, 1, 1]) + + pred_cam_t[is_left, None, :], + "keypoints_3d": keypoints_3d[is_left] * np.array([-1, 1, 1]) + + pred_cam_t[is_left, None, :], + "mano_hand_pose": flip_rotmats(mano_hand_pose[is_left]), + "mano_hand_betas": mano_hand_betas[is_left], + "mano_hand_global_orient": flip_rotmats(R_camera_hand[is_left]), + "faces": self.get_mano_faces("left"), + } + # end new brent stuff + return detections_left_wrt_cam, detections_right_wrt_cam + + + def look_for_hands( + self, + image: Int[np.ndarray, "height width 3"], + focal_length: float | None = None, + rescale_factor: float = 2.0, + render_output_dir_for_testing: Path | None = None, + render_output_prefix_for_testing: str = "", + ) -> tuple[HandOutputsWrtCamera | None, HandOutputsWrtCamera | None]: + """Look for hands. + + Arguments: + image: Image to look for hands in. Expects uint8, in range [0, 255]. + focal_length: Focal length of camera, used for 3D coordinates. + rescale_factor: Rescale factor for running ViT detector. I think 2 is fine, probably. + render_output_dir: Directory to render out detections to. Mostly this is used for testing. Doesn't do any rendering + """ + assert image.shape[-1] == 3 + + # image must be `np.uint8`, and in range [0, 255]. + assert image.dtype == np.uint8 + + # Detectron expects BGR image. + det_out = self._detector(image[:, :, ::-1]) + det_instances = det_out["instances"] + valid_idx = (det_instances.pred_classes == 0) & (det_instances.scores > 0.5) + pred_bboxes = det_instances.pred_boxes.tensor[valid_idx].cpu().numpy() + pred_scores = det_instances.scores[valid_idx].cpu().numpy() + + # Detect human keypoints for each person + vitposes_out = self._cpm.predict_pose( + image, + [np.concatenate([pred_bboxes, pred_scores[:, None]], axis=1)], + ) + + bboxes = [] + is_right = [] + + # Use hands based on hand keypoint detections + for vitposes in vitposes_out: + left_hand_keyp = vitposes["keypoints"][-42:-21] + right_hand_keyp = vitposes["keypoints"][-21:] + + lbbox = None + rbbox = None + + # Rejecting not confident detections + ldetect = rdetect = False + keyp = left_hand_keyp + valid = keyp[:, 2] > 0.5 + if sum(valid) > 3: + lbbox = [ + keyp[valid, 0].min(), + keyp[valid, 1].min(), + keyp[valid, 0].max(), + keyp[valid, 1].max(), + ] + ldetect = True + keyp = right_hand_keyp + valid = keyp[:, 2] > 0.5 + if sum(valid) > 3: + rbbox = [ + keyp[valid, 0].min(), + keyp[valid, 1].min(), + keyp[valid, 0].max(), + keyp[valid, 1].max(), + ] + rdetect = True + + # suppressing + if ldetect == True and rdetect == True: + bboxes_dims = [ + left_hand_keyp[:, 0].max() - left_hand_keyp[:, 0].min(), + left_hand_keyp[:, 1].max() - left_hand_keyp[:, 1].min(), + right_hand_keyp[:, 0].max() - right_hand_keyp[:, 0].min(), + right_hand_keyp[:, 1].max() - right_hand_keyp[:, 1].min(), + ] + norm_side = max(bboxes_dims) + keyp_dist = ( + np.sqrt( + np.sum( + (right_hand_keyp[:, :2] - left_hand_keyp[:, :2]) ** 2, + axis=1, + ) + ) + / norm_side + ) + if np.mean(keyp_dist) < 0.5: + if left_hand_keyp[0, 2] - right_hand_keyp[0, 2] > 0: + assert lbbox is not None + bboxes.append(lbbox) + is_right.append(0) + else: + assert rbbox is not None + bboxes.append(rbbox) + is_right.append(1) + else: + assert lbbox is not None + assert rbbox is not None + bboxes.append(lbbox) + is_right.append(0) + bboxes.append(rbbox) + is_right.append(1) + elif ldetect == True: + assert lbbox is not None + bboxes.append(lbbox) + is_right.append(0) + elif rdetect == True: + assert rbbox is not None + bboxes.append(rbbox) + is_right.append(1) + + if len(bboxes) == 0: + return None, None + + boxes = np.stack(bboxes) + right = np.stack(is_right) + + dataset = ViTDetDataset( + self._model_cfg, + # HaMeR expects BGR. + image[:, :, ::-1], + boxes, + right, + rescale_factor=rescale_factor, + ) + + # ViT detector will give us multiple detections. We want to run HaMeR + # on each. + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=8, shuffle=False, num_workers=0 + ) + outputs: list[_RawHamerOutputs] = [] + from hamer.utils import recursive_to + + for batch in dataloader: + batch: Any = recursive_to(batch, self.device) + with torch.no_grad(): + out = self._model.forward(batch) + + multiplier = 2 * batch["right"] - 1 + pred_cam = out["pred_cam"] + pred_cam[:, 1] = multiplier * pred_cam[:, 1] + box_center = batch["box_center"].float() + box_size = batch["box_size"].float() + img_size = batch["img_size"].float() + multiplier = 2 * batch["right"] - 1 + + if focal_length is None: + # All of the img_size rows should be the same. I think. + focal_length = float( + self.get_default_focal_length( + img_size[0, 0].item(), img_size[0, 1].item() + ) + ) + if isinstance(focal_length, int): + focal_length = float(focal_length) + assert isinstance(focal_length, float) + scaled_focal_length = focal_length + + pred_cam_t_full = cam_crop_to_full( + pred_cam, box_center, box_size, img_size, scaled_focal_length + ) + hamer_out = _RawHamerOutputs( + mano_faces_left=torch.from_numpy( + self._model.mano.faces[:, [0, 2, 1]].astype(np.int64) + ).to(device=self.device), + mano_faces_right=torch.from_numpy( + self._model.mano.faces.astype(np.int64) + ).to(device=self.device), + pred_cam=out["pred_cam"], + pred_mano_global_orient=out["pred_mano_params"]["global_orient"], + pred_mano_hand_pose=out["pred_mano_params"]["hand_pose"], + pred_mano_hand_betas=out["pred_mano_params"]["betas"], + pred_cam_t=pred_cam_t_full, + pred_keypoints_3d=out["pred_keypoints_3d"], + pred_vertices=out["pred_vertices"], + pred_keypoints_2d=out["pred_keypoints_2d"], + pred_right=batch["right"], + ) + + outputs.append(hamer_out) + + # Render the result. + if render_output_dir_for_testing: + renderer = Renderer(self._model_cfg, faces=self._model.mano.faces) + batch_size = batch["img"].shape[0] + for n in range(batch_size): + # Get filename from path img_path + person_id = int(batch["personid"][n]) + white_img = ( + torch.ones_like(batch["img"][n]).cpu() + - DEFAULT_MEAN[:, None, None] / 255 + ) / (DEFAULT_STD[:, None, None] / 255) + input_patch = batch["img"][n].cpu() * ( + DEFAULT_STD[:, None, None] / 255 + ) + (DEFAULT_MEAN[:, None, None] / 255) + input_patch = input_patch.permute(1, 2, 0).numpy() + + LIGHT_BLUE = (0.65098039, 0.74117647, 0.85882353) + regression_img = renderer( + out["pred_vertices"][n].detach().cpu().numpy(), + out["pred_cam_t"][n].detach().cpu().numpy(), + batch["img"][n], + mesh_base_color=LIGHT_BLUE, + scene_bg_color=(1, 1, 1), + ) + + final_img = np.concatenate([input_patch, regression_img], axis=1) + + image_path = ( + render_output_dir_for_testing + / f"{render_output_prefix_for_testing}_hamer_{person_id}.png" + ) + print(f"Writing to {image_path}") + render_output_dir_for_testing.mkdir(exist_ok=True, parents=True) + iio.imwrite(image_path, (255 * final_img).astype(np.uint8)) + + # Add all verts and cams to list + verts = out["pred_vertices"][n].detach().cpu().numpy() + is_right = batch["right"][n].cpu().numpy() + verts[:, 0] = (2 * is_right - 1) * verts[:, 0] + + assert len(outputs) > 0 + stacked_outputs = _RawHamerOutputs( + **{ + field_name: torch.cat([getattr(x, field_name) for x in outputs], dim=0) + for field_name in vars(outputs[0]).keys() + }, + ) + # begin new brent stuff + verts = stacked_outputs.pred_vertices.numpy(force=True) + keypoints_3d = stacked_outputs.pred_keypoints_3d.numpy(force=True) + pred_cam_t = stacked_outputs.pred_cam_t.numpy(force=True) + mano_hand_pose = stacked_outputs.pred_mano_hand_pose.numpy(force=True) + mano_hand_betas = stacked_outputs.pred_mano_hand_betas.numpy(force=True) + R_camera_hand = stacked_outputs.pred_mano_global_orient.squeeze(dim=1).numpy( + force=True + ) + + is_right = (stacked_outputs.pred_right > 0.5).numpy(force=True) + is_left = ~is_right + + detections_right_wrt_cam: HandOutputsWrtCamera | None + if np.sum(is_right) == 0: + detections_right_wrt_cam = None + else: + detections_right_wrt_cam = { + "verts": verts[is_right] + pred_cam_t[is_right, None, :], + "keypoints_3d": keypoints_3d[is_right] + pred_cam_t[is_right, None, :], + "mano_hand_pose": mano_hand_pose[is_right], + "mano_hand_betas": mano_hand_betas[is_right], + "mano_hand_global_orient": R_camera_hand[is_right], + "faces": self.get_mano_faces("right"), + } + + detections_left_wrt_cam: HandOutputsWrtCamera | None + if np.sum(is_left) == 0: + detections_left_wrt_cam = None + else: + + def flip_rotmats(rotmats: np.ndarray) -> np.ndarray: + assert rotmats.shape[-2:] == (3, 3) + from viser import transforms + + logspace = transforms.SO3.from_matrix(rotmats).log() + logspace[..., 1] *= -1 + logspace[..., 2] *= -1 + return transforms.SO3.exp(logspace).as_matrix() + + detections_left_wrt_cam = { + "verts": verts[is_left] * np.array([-1, 1, 1]) + + pred_cam_t[is_left, None, :], + "keypoints_3d": keypoints_3d[is_left] * np.array([-1, 1, 1]) + + pred_cam_t[is_left, None, :], + "mano_hand_pose": flip_rotmats(mano_hand_pose[is_left]), + "mano_hand_betas": mano_hand_betas[is_left], + "mano_hand_global_orient": flip_rotmats(R_camera_hand[is_left]), + "faces": self.get_mano_faces("left"), + } + # end new brent stuff + return detections_left_wrt_cam, detections_right_wrt_cam + + def get_mano_faces(self, side: Literal["left", "right"]) -> np.ndarray: + if side == "left": + return self._model.mano.faces[:, [0, 2, 1]].copy() + else: + return self._model.mano.faces.copy() + + def render_detection( + self, + output_dict: HandOutputsWrtCamera, + hand_index: int, + h: int, + w: int, + focal_length: float | None = None, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Render to a tuple of (RGB, depth, mask). For testing.""" + import pyrender + import trimesh + + if focal_length is None: + focal_length = self.get_default_focal_length(h, w) + + render_res = (h, w) + renderer = pyrender.OffscreenRenderer( + viewport_width=render_res[1], viewport_height=render_res[0], point_size=1.0 + ) + material = pyrender.MetallicRoughnessMaterial( + metallicFactor=0.0, alphaMode="OPAQUE", baseColorFactor=(1.0, 1.0, 0.9, 1.0) + ) + + vertices = output_dict["verts"][hand_index] + faces = output_dict["faces"] + + mesh = trimesh.Trimesh(vertices.copy(), faces.copy()) + mesh = pyrender.Mesh.from_trimesh(mesh, material=material) + + scene = pyrender.Scene( + bg_color=[1.0, 1.0, 1.0, 0.0], ambient_light=(0.3, 0.3, 0.3) + ) + scene.add(mesh, "mesh") + + camera_center = [render_res[1] / 2.0, render_res[0] / 2.0] + camera = pyrender.IntrinsicsCamera( + fx=focal_length, + fy=focal_length, + cx=camera_center[0], + cy=camera_center[1], + zfar=1e12, + znear=0.001, + ) + + light_nodes = create_raymond_lights() + for node in light_nodes: + scene.add_node(node) + + # Create camera node and add it to pyRender scene + camera_pose = np.eye(4) + camera_pose[1:3, :] *= -1 # flip the y and z axes to match opengl + camera_node = pyrender.Node(camera=camera, matrix=camera_pose) + scene.add_node(camera_node) + + color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA) # type: ignore + mask = color[..., -1] > 0 + return color[..., :3], rend_depth, mask + + def composite_detections( + self, + image: np.ndarray, + detections: HandOutputsWrtCamera | None, + border_color: tuple[int, int, int], + focal_length: float | None = None, + ) -> np.ndarray: + """Render some hand detections on top of an image. Returns an updated image.""" + if detections is None: + return image + + h, w = image.shape[:2] + + for index in range(detections["verts"].shape[0]): + print(index) + render_rgb, _, render_mask = self.render_detection( + detections, hand_index=index, h=h, w=w, focal_length=focal_length + ) + border_width = 15 + image = np.where( + binary_dilation( + render_mask, np.ones((border_width, border_width), dtype=bool) + )[:, :, None], + np.zeros_like(render_rgb) + np.array(border_color, dtype=np.uint8), + image, + ) + image = np.where(render_mask[:, :, None], render_rgb, image) + + return image From 9304ef9f1d0391c2fc3a2c749f35b38eb2300b30 Mon Sep 17 00:00:00 2001 From: woodenbirds <1979309725@qq.com> Date: Mon, 28 Oct 2024 06:51:15 -0700 Subject: [PATCH 3/8] wilor --- 2_run_hamer_on_vrs.py | 121 +++++------- 3_aria_inference.py | 7 +- _hamer_helper.py | 40 ++-- _wilor_helper.py | 345 +++++++++++++++++++++++++++++++++ src/egoallo/inference_utils.py | 6 +- 5 files changed, 422 insertions(+), 97 deletions(-) create mode 100644 _wilor_helper.py diff --git a/2_run_hamer_on_vrs.py b/2_run_hamer_on_vrs.py index 0c935f2..fbc3ae5 100644 --- a/2_run_hamer_on_vrs.py +++ b/2_run_hamer_on_vrs.py @@ -12,61 +12,66 @@ SavedHamerOutputs, SingleHandHamerOutputWrtCamera, ) -from _hamer_helper import HamerHelper + from projectaria_tools.core import calibration from projectaria_tools.core.data_provider import ( VrsDataProvider, create_vrs_data_provider, ) +from _hamer_helper import HamerHelper import projectaria_tools.core.mps as mps from projectaria_tools.core.sensor_data import DEVICE_TIME from tqdm.auto import tqdm from egoallo.inference_utils import InferenceTrajectoryPaths -from aria_utils import per_image_hand_tracking,get_online_calib,x_y_around -def main(traj_root: Path, detector: str = "hamer",overwrite: bool = False) -> None: +def main(traj_root: Path, detector: str = "hamer",overwrite: bool = False, wilor_home: str = "/secondary/home/annie/code/WiLoR") -> None: """Run HaMeR for on trajectory. We'll save outputs to `traj_root/hamer_outputs.pkl` and `traj_root/hamer_outputs_render". Arguments: traj_root: The root directory of the trajectory. We assume that there's a VRS file in this directory. - detector: The detector to use. Can be "WiLoR", "aria", or "hamer". + detector: The detector to use. Can be "wilor", "aria", or "hamer". overwrite: If True, overwrite any existing HaMeR outputs. + wilor_home: The path to the WiLoR home directory. Only used if `detector` + is "wilor". """ paths = InferenceTrajectoryPaths.find(traj_root) - vrs_path = paths.vrs_file assert vrs_path.exists() - pickle_out = traj_root / "hamer_outputs.pkl" - hamer_render_out = traj_root / "hamer_outputs_render" # This is just for debugging. - wrist_and_palm_poses_path = traj_root / "hand_tracking/wrist_and_palm_poses.csv" - online_path = traj_root / "slam/online_calibration.jsonl" - # run_hamer_and_save(vrs_path, pickle_out, hamer_render_out, overwrite) - if detector == "WiLoR": - run_wilor_and_save(vrs_path, pickle_out, hamer_render_out, overwrite) + + if detector == "wilor": + render_out_path = traj_root / "wilor_outputs_unrot_render" # This is just for debugging. + pickle_out = traj_root / "wilor_outputs.pkl" + run_wilor_and_save(vrs_path, pickle_out, render_out_path, overwrite, wilor_home) elif detector == "aria": - run_aria_hamer_and_save(vrs_path, pickle_out, hamer_render_out, wrist_and_palm_poses_path, online_path, overwrite) + render_out_path = traj_root / "hamer_aria_outputs_render" # This is just for debugging. + wrist_and_palm_poses_path = traj_root / "hand_tracking/wrist_and_palm_poses.csv" + online_path = traj_root / "slam/online_calibration.jsonl" + pickle_out = traj_root / "aria_outputs.pkl" + run_aria_hamer_and_save(vrs_path, pickle_out, render_out_path, wrist_and_palm_poses_path, online_path, overwrite) elif detector == "hamer": - run_hamer_and_save(vrs_path, pickle_out, hamer_render_out, overwrite) + render_out_path = traj_root / "hamer_outputs_render" # This is just for debugging. + pickle_out = traj_root / "hamer_outputs.pkl" + run_hamer_and_save(vrs_path, pickle_out, render_out_path, overwrite) else: raise ValueError(f"Unknown detector: {detector}") def run_hamer_and_save( - vrs_path: Path, pickle_out: Path, hamer_render_out: Path, overwrite: bool + vrs_path: Path, pickle_out: Path, render_out_path: Path, overwrite: bool ) -> None: if not overwrite: assert not pickle_out.exists() - assert not hamer_render_out.exists() + assert not render_out_path.exists() else: pickle_out.unlink(missing_ok=True) - shutil.rmtree(hamer_render_out, ignore_errors=True) + shutil.rmtree(render_out_path, ignore_errors=True) - hamer_render_out.mkdir(exist_ok=True) + render_out_path.mkdir(exist_ok=True) hamer_helper = HamerHelper() # VRS data provider setup. @@ -188,9 +193,9 @@ def run_hamer_and_save( font_scale=10.0 / 2880.0 * undistorted_image.shape[0], ) - print(f"Saving image {i:06d} to {hamer_render_out / f'{i:06d}.jpeg'}") + print(f"Saving image {i:06d} to {render_out_path / f'{i:06d}.jpeg'}") iio.imwrite( - str(hamer_render_out / f"{i:06d}.jpeg"), + str(render_out_path / f"{i:06d}.jpeg"), np.concatenate( [ # Darken input image, just for contrast... @@ -214,17 +219,17 @@ def run_hamer_and_save( pickle.dump(outputs, f) def run_wilor_and_save( - vrs_path: Path, pickle_out: Path, hamer_render_out: Path, overwrite: bool + vrs_path: Path, pickle_out: Path, render_out_path: Path, overwrite: bool, wilor_home: str ) -> None: - raise NotImplementedError("WiLoR is not implemented yet.") + from _wilor_helper import WiLoRHelper if not overwrite: assert not pickle_out.exists() - assert not hamer_render_out.exists() + assert not render_out_path.exists() else: pickle_out.unlink(missing_ok=True) - shutil.rmtree(hamer_render_out, ignore_errors=True) - - hamer_render_out.mkdir(exist_ok=True) + shutil.rmtree(render_out_path, ignore_errors=True) + render_out_path.mkdir(exist_ok=True) + wilor_helper = WiLoRHelper(wilor_home) hamer_helper = HamerHelper() # VRS data provider setup. @@ -274,8 +279,8 @@ def run_wilor_and_save( undistorted_image = calibration.distort_by_calibration( image_data.to_numpy_array(), pinhole, camera_calib ) - - hamer_out_left, hamer_out_right = hamer_helper.look_for_hands( + + hamer_out_left, hamer_out_right = wilor_helper.look_for_hands( undistorted_image, focal_length=450, ) @@ -346,9 +351,9 @@ def run_wilor_and_save( font_scale=10.0 / 2880.0 * undistorted_image.shape[0], ) - print(f"Saving image {i:06d} to {hamer_render_out / f'{i:06d}.jpeg'}") + print(f"Saving image {i:06d} to {render_out_path / f'{i:06d}.jpeg'}") iio.imwrite( - str(hamer_render_out / f"{i:06d}.jpeg"), + str(render_out_path / f"{i:06d}.jpeg"), np.concatenate( [ # Darken input image, just for contrast... @@ -361,8 +366,8 @@ def run_wilor_and_save( ) outputs = SavedHamerOutputs( - mano_faces_right=hamer_helper.get_mano_faces("right"), - mano_faces_left=hamer_helper.get_mano_faces("left"), + mano_faces_right=wilor_helper.get_mano_faces("right"), + mano_faces_left=wilor_helper.get_mano_faces("left"), detections_right_wrt_cam=detections_right_wrt_cam, detections_left_wrt_cam=detections_left_wrt_cam, T_device_cam=T_device_cam, @@ -373,16 +378,17 @@ def run_wilor_and_save( def run_aria_hamer_and_save( - vrs_path: Path, pickle_out: Path, hamer_render_out: Path, wrist_and_palm_poses_path: Path, online_calib_path: Path, overwrite: bool + vrs_path: Path, pickle_out: Path, render_out_path: Path, wrist_and_palm_poses_path: Path, online_calib_path: Path, overwrite: bool ) -> None: + from aria_utils import per_image_hand_tracking,get_online_calib,x_y_around if not overwrite: assert not pickle_out.exists() - assert not hamer_render_out.exists() + assert not render_out_path.exists() else: pickle_out.unlink(missing_ok=True) - shutil.rmtree(hamer_render_out, ignore_errors=True) + shutil.rmtree(render_out_path, ignore_errors=True) - hamer_render_out.mkdir(exist_ok=True) + render_out_path.mkdir(exist_ok=True) hamer_helper = HamerHelper() # VRS data provider setup. @@ -431,10 +437,6 @@ def run_aria_hamer_and_save( wrist_and_palm_poses = mps.hand_tracking.read_wrist_and_palm_poses(wrist_and_palm_poses_path) pbar = tqdm(range(num_images)) - - l_point_queue=[] - r_point_queue=[] - queue_length=5 for i in pbar: image_data, image_data_record = provider.get_image_data_by_index( @@ -448,43 +450,14 @@ def run_aria_hamer_and_save( l_existed, r_existed, l_point, r_point = per_image_hand_tracking(timestamp_ns, wrist_and_palm_poses, pinhole, camera_calib, rgb_calib) if l_existed: l_box = x_y_around(l_point[0], l_point[1],pinhole,offset=80) - l_point_queue.append(l_point) else: - l_box=None - # for index_l1 in range(len(l_point_queue)-1,-1,-1): - # if l_point_queue[index_l1] is not None: - # for index_l2 in range(index_l1-1,-1,-1): - # if l_point_queue[index_l2] is not None: - # l_point = (len(l_point_queue)-index_l1)*(l_point_queue[index_l1]-l_point_queue[index_l2])/(index_l1-index_l2)+l_point_queue[index_l1] - # l_box = x_y_around(l_point[0], l_point[1],pinhole) - # l_existed=True - # # print("use previous l:",len(l_point_queue)-index_l1,len(l_point_queue)-index_l2) - # break - # if l_existed: - # break - l_point_queue.append(None) + l_box = None if r_existed: r_box = x_y_around(r_point[0], r_point[1],pinhole,offset=80) - r_point_queue.append(r_point) else: - r_box=None - # for index_r1 in range(len(r_point_queue)-1,-1,-1): - # if r_point_queue[index_r1] is not None: - # for index_r2 in range(index_r1-1,-1,-1): - # if r_point_queue[index_r2] is not None: - # r_point = (len(r_point_queue)-index_r1)*(r_point_queue[index_r1]-r_point_queue[index_r2])/(index_r1-index_r2)+r_point_queue[index_r1] - # r_box = x_y_around(r_point[0], r_point[1],pinhole) - # r_existed=True - # # print("use previous r:",len(r_point_queue)-index_r1,len(r_point_queue)-index_r2) - # break - # if r_existed: - # break - r_point_queue.append(None) - - if len(l_point_queue)>queue_length: - l_point_queue.pop(0) - r_point_queue.pop(0) + r_box = None + hamer_out_left, hamer_out_right = hamer_helper.get_det_from_boxes( undistorted_image, @@ -560,7 +533,7 @@ def run_aria_hamer_and_save( font_scale=10.0 / 2880.0 * undistorted_image.shape[0], ) - print(f"Saving image {i:06d} to {hamer_render_out / f'{i:06d}.jpeg'}") + print(f"Saving image {i:06d} to {render_out_path / f'{i:06d}.jpeg'}") # bbox on undistorted image if l_existed: min_l_p_x, min_l_p_y, max_l_p_x, max_l_p_y = l_box @@ -575,7 +548,7 @@ def run_aria_hamer_and_save( iio.imwrite( - str(hamer_render_out / f"{i:06d}.jpeg"), + str(render_out_path / f"{i:06d}.jpeg"), np.concatenate( [ # Darken input image, just for contrast... diff --git a/3_aria_inference.py b/3_aria_inference.py index c489011..fe00975 100644 --- a/3_aria_inference.py +++ b/3_aria_inference.py @@ -38,6 +38,8 @@ class Args: ... ... """ + detector: str = "hamer" + """for choosing pkl file. choose from ["hamer", "aria", "wilor"]""" checkpoint_dir: Path = Path("./egoallo_checkpoint_april13/checkpoints_3000000/") smplh_npz_path: Path = Path("./data/smplh/neutral/model.npz") @@ -66,7 +68,7 @@ class Args: def main(args: Args) -> None: device = torch.device("cuda") - traj_paths = InferenceTrajectoryPaths.find(args.traj_root) + traj_paths = InferenceTrajectoryPaths.find(args.traj_root, detector=args.detector) if traj_paths.splat_path is not None: print("Found splat at", traj_paths.splat_path) else: @@ -148,7 +150,8 @@ def main(args: Args) -> None: # Save outputs in case we want to visualize later. if args.save_traj: save_name = ( - time.strftime("%Y%m%d-%H%M%S") + args.detector + + time.strftime("%Y%m%d-%H%M%S") + f"_{args.start_index}-{args.start_index + args.traj_length}" ) out_path = args.traj_root / "egoallo_outputs" / (save_name + ".npz") diff --git a/_hamer_helper.py b/_hamer_helper.py index 3fcea7b..5a3c16b 100644 --- a/_hamer_helper.py +++ b/_hamer_helper.py @@ -87,7 +87,7 @@ def temporary_cwd_context(x: Path) -> Generator[None, None, None]: class HamerHelper: """Helper class for running HaMeR. Adapted from HaMeR demo script.""" - def __init__(self) -> None: + def __init__(self,other_detector=False) -> None: import hamer from hamer.models import DEFAULT_CHECKPOINT, load_hamer from vitpose_model import ViTPoseModel @@ -112,22 +112,27 @@ def __init__(self) -> None: # Load detector import hamer - # from detectron2.config import LazyConfig - # from hamer.utils.utils_detectron2 import DefaultPredictor_Lazy - - # with _stopwatch("Creating Detectron2 predictor..."): - # cfg_path = ( - # Path(hamer.__file__).parent - # / "configs" - # / "cascade_mask_rcnn_vitdet_h_75ep.py" - # ) - # detectron2_cfg = LazyConfig.load(str(cfg_path)) - # detectron2_cfg.train.init_checkpoint = "https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/cascade_mask_rcnn_vitdet_h/f328730692/model_final_f05665.pkl" # type: ignore - # for i in range(3): - # detectron2_cfg.model.roi_heads.box_predictors[ # type: ignore - # i - # ].test_score_thresh = 0.25 - # detector = DefaultPredictor_Lazy(detectron2_cfg) + + if other_detector: + self._detector = None + else: + from detectron2.config import LazyConfig + from hamer.utils.utils_detectron2 import DefaultPredictor_Lazy + + with _stopwatch("Creating Detectron2 predictor..."): + cfg_path = ( + Path(hamer.__file__).parent + / "configs" + / "cascade_mask_rcnn_vitdet_h_75ep.py" + ) + detectron2_cfg = LazyConfig.load(str(cfg_path)) + detectron2_cfg.train.init_checkpoint = "https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/cascade_mask_rcnn_vitdet_h/f328730692/model_final_f05665.pkl" # type: ignore + for i in range(3): + detectron2_cfg.model.roi_heads.box_predictors[ # type: ignore + i + ].test_score_thresh = 0.25 + detector = DefaultPredictor_Lazy(detectron2_cfg) + self._detector = detector # keypoint detector with _stopwatch("Creating ViT pose model..."): @@ -135,7 +140,6 @@ def __init__(self) -> None: self._model = model self._model_cfg = model_cfg - self._detector = None self._cpm = cpm self.device = device diff --git a/_wilor_helper.py b/_wilor_helper.py new file mode 100644 index 0000000..2ae7143 --- /dev/null +++ b/_wilor_helper.py @@ -0,0 +1,345 @@ +from pathlib import Path +import torch +import argparse +import os +import cv2 +import numpy as np +from dataclasses import dataclass +from torch import Tensor +from typing import Any, Generator, Literal, TypedDict + +from wilor.models import load_wilor +from wilor.utils import recursive_to +from wilor.datasets.vitdet_dataset import ViTDetDataset +from wilor.utils.renderer import Renderer, cam_crop_to_full +from ultralytics import YOLO +from jaxtyping import Float, Int +LIGHT_PURPLE=(0.25098039, 0.274117647, 0.65882353) + + +# same in _hamer_helper.py +class HandOutputsWrtCamera(TypedDict): + """Hand outputs with respect to the camera frame.""" + + verts: Float[np.ndarray, "num_hands 778 3"] + keypoints_3d: Float[np.ndarray, "num_hands 21 3"] + mano_hand_pose: Float[np.ndarray, "num_hands 15 3 3"] + mano_hand_betas: Float[np.ndarray, "num_hands 10"] + mano_hand_global_orient: Float[np.ndarray, "num_hands 1 3 3"] + faces: Int[np.ndarray, "mesh_faces 3"] + +# same in _hamer_helper.py +@dataclass(frozen=True) +class _RawHamerOutputs: + """A typed wrapper for outputs from HaMeR.""" + + # Comments here are what I got when printing out the shapes of different + # HaMeR outputs. + + # pred_cam torch.Size([1, 3]) + pred_cam: Float[Tensor, "num_hands 3"] + # pred_mano_params global_orient torch.Size([1, 1, 3, 3]) + pred_mano_global_orient: Float[Tensor, "num_hands 1 3 3"] + # pred_mano_params hand_pose torch.Size([1, 15, 3, 3]) + pred_mano_hand_pose: Float[Tensor, "num_hands 15 3 3"] + # pred_mano_params betas torch.Size([1, 10]) + pred_mano_hand_betas: Float[Tensor, "num_hands 10"] + # pred_cam_t torch.Size([1, 3]) + pred_cam_t: Float[Tensor, "num_hands 3"] + + # focal length from model is ignored + # focal_length torch.Size([1, 2]) + # focal_length: Float[Tensor, "num_hands 2"] + + # pred_keypoints_3d torch.Size([1, 21, 3]) + pred_keypoints_3d: Float[Tensor, "num_hands 21 3"] + # pred_vertices torch.Size([1, 778, 3]) + pred_vertices: Float[Tensor, "num_hands 778 3"] + # pred_keypoints_2d torch.Size([1, 21, 2]) + pred_keypoints_2d: Float[Tensor, "num_hands 21 3"] + + pred_right: Float[Tensor, "num_hands"] + """A given hand is a right hand if this value is >0.5.""" + + # These aren't technically HaMeR outputs, but putting them here for convenience. + mano_faces_right: Tensor + mano_faces_left: Tensor + +class WiLoRHelper: + def __init__(self, wilor_home: str="./"): + model, model_cfg = load_wilor(wilor_home, checkpoint_path = './pretrained_models/wilor_final.ckpt' , + cfg_path= './pretrained_models/model_config.yaml') + detector = YOLO(os.path.join(wilor_home, 'pretrained_models/detector.pt')) + # Setup the renderer + self.renderer = Renderer(model_cfg, faces=model.mano.faces) + + self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + self._model = model.to(self.device) + self._model_cfg = model_cfg + self._detector = detector.to(self.device) + + def look_for_hands( + self, + image: np.ndarray, + focal_length: float | None = None, + rescale_factor: float = 2.0, + ) -> tuple[HandOutputsWrtCamera | None, HandOutputsWrtCamera | None]: + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + detections = self._detector(image, conf = 0.3, verbose=False)[0] + # breakpoint() + bboxes = [] + is_right = [] + for det in detections: + Bbox = det.boxes.data.cpu().detach().squeeze().numpy() + is_right.append(det.boxes.cls.cpu().detach().squeeze().item()) + bboxes.append(Bbox[:4].tolist()) + + if len(bboxes) == 0: + return None, None + boxes = np.stack(bboxes) + right = np.stack(is_right) + dataset = ViTDetDataset(self._model_cfg, image, boxes, right, rescale_factor=rescale_factor) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0) + + outputs: list[_RawHamerOutputs] = [] + + for batch in dataloader: + batch = recursive_to(batch, self.device) + + with torch.no_grad(): + out = self._model(batch) + + multiplier = (2*batch['right']-1) + pred_cam = out['pred_cam'] + pred_cam[:,1] = multiplier*pred_cam[:,1] + box_center = batch["box_center"].float() + box_size = batch["box_size"].float() + img_size = batch["img_size"].float() + if focal_length is None: + # All of the img_size rows should be the same. I think. + focal_length = self._model_cfg.EXTRA.FOCAL_LENGTH / self._model_cfg.MODEL.IMAGE_SIZE * img_size.max() + focal_length = float(focal_length) + scaled_focal_length = focal_length + pred_cam_t_full = cam_crop_to_full(pred_cam, box_center, box_size, img_size, scaled_focal_length).detach().cpu().numpy() + pred_cam_t_full = torch.tensor(pred_cam_t_full, device=self.device) + hamer_out = _RawHamerOutputs( + mano_faces_left=torch.from_numpy( + self._model.mano.faces[:, [0, 2, 1]].astype(np.int64) + ).to(device=self.device), + mano_faces_right=torch.from_numpy( + self._model.mano.faces.astype(np.int64) + ).to(device=self.device), + pred_cam=out["pred_cam"], + pred_mano_global_orient=out["pred_mano_params"]["global_orient"], + pred_mano_hand_pose=out["pred_mano_params"]["hand_pose"], + pred_mano_hand_betas=out["pred_mano_params"]["betas"], + pred_cam_t=pred_cam_t_full, + pred_keypoints_3d=out["pred_keypoints_3d"], + pred_vertices=out["pred_vertices"], + pred_keypoints_2d=out["pred_keypoints_2d"], + pred_right=batch["right"], + ) + + outputs.append(hamer_out) + + assert len(outputs) > 0 + + stacked_outputs = _RawHamerOutputs( + **{ + field_name: torch.cat([getattr(x, field_name) for x in outputs], dim=0) + for field_name in vars(outputs[0]).keys() + }, + ) + # begin new brent stuff + verts = stacked_outputs.pred_vertices.numpy(force=True) + keypoints_3d = stacked_outputs.pred_keypoints_3d.numpy(force=True) + pred_cam_t = stacked_outputs.pred_cam_t.numpy(force=True) + mano_hand_pose = stacked_outputs.pred_mano_hand_pose.numpy(force=True) + mano_hand_betas = stacked_outputs.pred_mano_hand_betas.numpy(force=True) + R_camera_hand = stacked_outputs.pred_mano_global_orient.squeeze(dim=1).numpy( + force=True + ) + + is_right = (stacked_outputs.pred_right > 0.5).numpy(force=True) + is_left = ~is_right + + detections_right_wrt_cam: HandOutputsWrtCamera | None + if np.sum(is_right) == 0: + detections_right_wrt_cam = None + else: + detections_right_wrt_cam = { + "verts": verts[is_right] + pred_cam_t[is_right, None, :], + "keypoints_3d": keypoints_3d[is_right] + pred_cam_t[is_right, None, :], + "mano_hand_pose": mano_hand_pose[is_right], + "mano_hand_betas": mano_hand_betas[is_right], + "mano_hand_global_orient": R_camera_hand[is_right], + "faces": self.get_mano_faces("right"), + } + + detections_left_wrt_cam: HandOutputsWrtCamera | None + if np.sum(is_left) == 0: + detections_left_wrt_cam = None + else: + + def flip_rotmats(rotmats: np.ndarray) -> np.ndarray: + assert rotmats.shape[-2:] == (3, 3) + from viser import transforms + + logspace = transforms.SO3.from_matrix(rotmats).log() + logspace[..., 1] *= -1 + logspace[..., 2] *= -1 + return transforms.SO3.exp(logspace).as_matrix() + + detections_left_wrt_cam = { + "verts": verts[is_left] * np.array([-1, 1, 1]) + + pred_cam_t[is_left, None, :], + "keypoints_3d": keypoints_3d[is_left] * np.array([-1, 1, 1]) + + pred_cam_t[is_left, None, :], + "mano_hand_pose": flip_rotmats(mano_hand_pose[is_left]), + "mano_hand_betas": mano_hand_betas[is_left], + "mano_hand_global_orient": flip_rotmats(R_camera_hand[is_left]), + "faces": self.get_mano_faces("left"), + } + # end new brent stuff + return detections_left_wrt_cam, detections_right_wrt_cam + + def get_mano_faces(self, side: Literal["left", "right"]) -> np.ndarray: + if side == "left": + return self._model.mano.faces[:, [0, 2, 1]].copy() + else: + return self._model.mano.faces.copy() + +def example_demo(): + parser = argparse.ArgumentParser(description='WiLoR demo code') + parser.add_argument('--img_folder', type=str, default='images', help='Folder with input images') + parser.add_argument('--out_folder', type=str, default='out_demo', help='Output folder to save rendered results') + parser.add_argument('--save_mesh', dest='save_mesh', action='store_true', default=False, help='If set, save meshes to disk also') + parser.add_argument('--rescale_factor', type=float, default=2.0, help='Factor for padding the bbox') + parser.add_argument('--file_type', nargs='+', default=['*.jpg', '*.png', '*.jpeg'], help='List of file extensions to consider') + + args = parser.parse_args() + + # Download and load checkpoints + model, model_cfg = load_wilor(checkpoint_path = './pretrained_models/wilor_final.ckpt' , cfg_path= './pretrained_models/model_config.yaml') + detector = YOLO('./pretrained_models/detector.pt') + # Setup the renderer + renderer = Renderer(model_cfg, faces=model.mano.faces) + renderer_side = Renderer(model_cfg, faces=model.mano.faces) + + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + model = model.to(device) + detector = detector.to(device) + model.eval() + + # Make output directory if it does not exist + os.makedirs(args.out_folder, exist_ok=True) + + # Get all demo images ends with .jpg or .png + img_paths = [img for end in args.file_type for img in Path(args.img_folder).glob(end)] + # Iterate over all images in folder + for img_path in img_paths: + img_cv2 = cv2.imread(str(img_path)) + detections = detector(img_cv2, conf = 0.3, verbose=False)[0] + bboxes = [] + is_right = [] + for det in detections: + Bbox = det.boxes.data.cpu().detach().squeeze().numpy() + is_right.append(det.boxes.cls.cpu().detach().squeeze().item()) + bboxes.append(Bbox[:4].tolist()) + + if len(bboxes) == 0: + # basename = os.path.basename(img_path).split('.')[0] + # cv2.imwrite(os.path.join(args.out_folder, f'{basename}.jpg'), img_cv2) + # print(os.path.join(args.out_folder, f'{basename}.jpg')) + continue + boxes = np.stack(bboxes) + right = np.stack(is_right) + dataset = ViTDetDataset(model_cfg, img_cv2, boxes, right, rescale_factor=args.rescale_factor) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0) + + all_verts = [] + all_cam_t = [] + all_right = [] + all_joints= [] + all_kpts = [] + + for batch in dataloader: + batch = recursive_to(batch, device) + + with torch.no_grad(): + out = model(batch) + + multiplier = (2*batch['right']-1) + pred_cam = out['pred_cam'] + pred_cam[:,1] = multiplier*pred_cam[:,1] + box_center = batch["box_center"].float() + box_size = batch["box_size"].float() + img_size = batch["img_size"].float() + scaled_focal_length = model_cfg.EXTRA.FOCAL_LENGTH / model_cfg.MODEL.IMAGE_SIZE * img_size.max() + pred_cam_t_full = cam_crop_to_full(pred_cam, box_center, box_size, img_size, scaled_focal_length).detach().cpu().numpy() + + + # Render the result + batch_size = batch['img'].shape[0] + for n in range(batch_size): + # Get filename from path img_path + img_fn, _ = os.path.splitext(os.path.basename(img_path)) + + verts = out['pred_vertices'][n].detach().cpu().numpy() + joints = out['pred_keypoints_3d'][n].detach().cpu().numpy() + + is_right = batch['right'][n].cpu().numpy() + verts[:,0] = (2*is_right-1)*verts[:,0] + joints[:,0] = (2*is_right-1)*joints[:,0] + cam_t = pred_cam_t_full[n] + kpts_2d = project_full_img(verts, cam_t, scaled_focal_length, img_size[n]) + + all_verts.append(verts) + all_cam_t.append(cam_t) + all_right.append(is_right) + all_joints.append(joints) + all_kpts.append(kpts_2d) + + + # Save all meshes to disk + if args.save_mesh: + camera_translation = cam_t.copy() + tmesh = renderer.vertices_to_trimesh(verts, camera_translation, LIGHT_PURPLE, is_right=is_right) + tmesh.export(os.path.join(args.out_folder, f'{img_fn}_{n}.obj')) + + # Render front view + if len(all_verts) > 0: + misc_args = dict( + mesh_base_color=LIGHT_PURPLE, + scene_bg_color=(1, 1, 1), + focal_length=scaled_focal_length, + ) + cam_view = renderer.render_rgba_multiple(all_verts, cam_t=all_cam_t, render_res=img_size[n], is_right=all_right, **misc_args) + + # Overlay image + input_img = img_cv2.astype(np.float32)[:,:,::-1]/255.0 + input_img = np.concatenate([input_img, np.ones_like(input_img[:,:,:1])], axis=2) # Add alpha channel + input_img_overlay = input_img[:,:,:3] * (1-cam_view[:,:,3:]) + cam_view[:,:,:3] * cam_view[:,:,3:] + final_img = 255*input_img_overlay[:, :, ::-1] + else: + final_img = img_cv2 + for i in range(len(bboxes)): + if right[i] == 0: + final_img = cv2.rectangle(final_img, (int(bboxes[i][0]), int(bboxes[i][1])), (int(bboxes[i][2]), int(bboxes[i][3])), (255, 100, 100), 2) + else: + final_img = cv2.rectangle(final_img, (int(bboxes[i][0]), int(bboxes[i][1])), (int(bboxes[i][2]), int(bboxes[i][3])), (100, 100, 255), 2) + cv2.imwrite(os.path.join(args.out_folder, f'{img_fn}.jpg'), final_img) + +def project_full_img(points, cam_trans, focal_length, img_res): + camera_center = [img_res[0] / 2., img_res[1] / 2.] + K = torch.eye(3) + K[0,0] = focal_length + K[1,1] = focal_length + K[0,2] = camera_center[0] + K[1,2] = camera_center[1] + points = points + cam_trans + points = points / points[..., -1:] + + V_2d = (K @ points.T).T + return V_2d[..., :-1] + diff --git a/src/egoallo/inference_utils.py b/src/egoallo/inference_utils.py index 994a0a2..e3b69aa 100644 --- a/src/egoallo/inference_utils.py +++ b/src/egoallo/inference_utils.py @@ -62,7 +62,7 @@ class InferenceTrajectoryPaths: splat_path: Path | None @staticmethod - def find(traj_root: Path) -> InferenceTrajectoryPaths: + def find(traj_root: Path, detector: str="hamer") -> InferenceTrajectoryPaths: vrs_files = tuple(traj_root.glob("**/*.vrs")) assert len(vrs_files) == 1, f"Found {len(vrs_files)} VRS files!" @@ -71,8 +71,8 @@ def find(traj_root: Path) -> InferenceTrajectoryPaths: if len(points_paths) == 0: points_paths = tuple(traj_root.glob("**/global_points.csv.gz")) assert len(points_paths) == 1, f"Found {len(points_paths)} files!" - - hamer_outputs = traj_root / "hamer_outputs.pkl" + pickle_file_name = detector + "_outputs.pkl" + hamer_outputs = traj_root / pickle_file_name if not hamer_outputs.exists(): hamer_outputs = None From 2006d46a086aeb2b8ef334e5884117dccffecf31 Mon Sep 17 00:00:00 2001 From: woodenbirds <1979309725@qq.com> Date: Tue, 29 Oct 2024 15:32:15 -0700 Subject: [PATCH 4/8] fix HamerHelper detector --- 2_run_hamer_on_vrs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/2_run_hamer_on_vrs.py b/2_run_hamer_on_vrs.py index fbc3ae5..4d59504 100644 --- a/2_run_hamer_on_vrs.py +++ b/2_run_hamer_on_vrs.py @@ -230,7 +230,7 @@ def run_wilor_and_save( shutil.rmtree(render_out_path, ignore_errors=True) render_out_path.mkdir(exist_ok=True) wilor_helper = WiLoRHelper(wilor_home) - hamer_helper = HamerHelper() + hamer_helper = HamerHelper(other_detector=True) # VRS data provider setup. provider = create_vrs_data_provider(str(vrs_path.absolute())) @@ -389,7 +389,7 @@ def run_aria_hamer_and_save( shutil.rmtree(render_out_path, ignore_errors=True) render_out_path.mkdir(exist_ok=True) - hamer_helper = HamerHelper() + hamer_helper = HamerHelper(other_detector=True) # VRS data provider setup. provider = create_vrs_data_provider(str(vrs_path.absolute())) From c0a79429f5a9e2166277d90935a64c91d1aae925 Mon Sep 17 00:00:00 2001 From: woodenbirds <1979309725@qq.com> Date: Tue, 19 Nov 2024 11:25:05 -0800 Subject: [PATCH 5/8] wilor --- 2_run_hamer_on_vrs.py | 205 +-------------------------------- 3_aria_inference.py | 13 ++- 4_visualize_outputs.py | 11 +- src/egoallo/inference_utils.py | 1 - 4 files changed, 18 insertions(+), 212 deletions(-) diff --git a/2_run_hamer_on_vrs.py b/2_run_hamer_on_vrs.py index 4d59504..3bdba19 100644 --- a/2_run_hamer_on_vrs.py +++ b/2_run_hamer_on_vrs.py @@ -33,7 +33,7 @@ def main(traj_root: Path, detector: str = "hamer",overwrite: bool = False, wilor Arguments: traj_root: The root directory of the trajectory. We assume that there's a VRS file in this directory. - detector: The detector to use. Can be "wilor", "aria", or "hamer". + detector: The detector to use. Can be "wilor" or "hamer". overwrite: If True, overwrite any existing HaMeR outputs. wilor_home: The path to the WiLoR home directory. Only used if `detector` is "wilor". @@ -47,12 +47,6 @@ def main(traj_root: Path, detector: str = "hamer",overwrite: bool = False, wilor render_out_path = traj_root / "wilor_outputs_unrot_render" # This is just for debugging. pickle_out = traj_root / "wilor_outputs.pkl" run_wilor_and_save(vrs_path, pickle_out, render_out_path, overwrite, wilor_home) - elif detector == "aria": - render_out_path = traj_root / "hamer_aria_outputs_render" # This is just for debugging. - wrist_and_palm_poses_path = traj_root / "hand_tracking/wrist_and_palm_poses.csv" - online_path = traj_root / "slam/online_calibration.jsonl" - pickle_out = traj_root / "aria_outputs.pkl" - run_aria_hamer_and_save(vrs_path, pickle_out, render_out_path, wrist_and_palm_poses_path, online_path, overwrite) elif detector == "hamer": render_out_path = traj_root / "hamer_outputs_render" # This is just for debugging. pickle_out = traj_root / "hamer_outputs.pkl" @@ -376,203 +370,6 @@ def run_wilor_and_save( with open(pickle_out, "wb") as f: pickle.dump(outputs, f) - -def run_aria_hamer_and_save( - vrs_path: Path, pickle_out: Path, render_out_path: Path, wrist_and_palm_poses_path: Path, online_calib_path: Path, overwrite: bool -) -> None: - from aria_utils import per_image_hand_tracking,get_online_calib,x_y_around - if not overwrite: - assert not pickle_out.exists() - assert not render_out_path.exists() - else: - pickle_out.unlink(missing_ok=True) - shutil.rmtree(render_out_path, ignore_errors=True) - - render_out_path.mkdir(exist_ok=True) - hamer_helper = HamerHelper(other_detector=True) - - # VRS data provider setup. - provider = create_vrs_data_provider(str(vrs_path.absolute())) - assert isinstance(provider, VrsDataProvider) - rgb_stream_id = provider.get_stream_id_from_label("camera-rgb") - assert rgb_stream_id is not None - - num_images = provider.get_num_data(rgb_stream_id) - print(f"Found {num_images=}") - - # Get calibrations. - device_calib = provider.get_device_calibration() - assert device_calib is not None - camera_calib = device_calib.get_camera_calib("camera-rgb") - assert camera_calib is not None - pinhole = calibration.get_linear_camera_calibration(1408, 1408, 450) - - # Compute camera extrinsics! - sophus_T_device_camera = device_calib.get_transform_device_sensor("camera-rgb") - sophus_T_cpf_camera = device_calib.get_transform_cpf_sensor("camera-rgb") - assert sophus_T_device_camera is not None - assert sophus_T_cpf_camera is not None - T_device_cam = np.concatenate( - [ - sophus_T_device_camera.rotation().to_quat().squeeze(axis=0), - sophus_T_device_camera.translation().squeeze(axis=0), - ] - ) - T_cpf_cam = np.concatenate( - [ - sophus_T_cpf_camera.rotation().to_quat().squeeze(axis=0), - sophus_T_cpf_camera.translation().squeeze(axis=0), - ] - ) - assert T_device_cam.shape == T_cpf_cam.shape == (7,) - - # Dict from capture timestamp in nanoseconds to fields we care about. - detections_left_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None] = {} - detections_right_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None] = {} - - wrist_and_palm_poses_path = str(wrist_and_palm_poses_path) - online_calib_path = str(online_calib_path) - - rgb_calib = get_online_calib(online_calib_path, "camera-rgb") - wrist_and_palm_poses = mps.hand_tracking.read_wrist_and_palm_poses(wrist_and_palm_poses_path) - - pbar = tqdm(range(num_images)) - - for i in pbar: - image_data, image_data_record = provider.get_image_data_by_index( - rgb_stream_id, i - ) - undistorted_image = calibration.distort_by_calibration( - image_data.to_numpy_array(), pinhole, camera_calib - ) - - timestamp_ns = image_data_record.capture_timestamp_ns - l_existed, r_existed, l_point, r_point = per_image_hand_tracking(timestamp_ns, wrist_and_palm_poses, pinhole, camera_calib, rgb_calib) - if l_existed: - l_box = x_y_around(l_point[0], l_point[1],pinhole,offset=80) - else: - l_box = None - - if r_existed: - r_box = x_y_around(r_point[0], r_point[1],pinhole,offset=80) - else: - r_box = None - - - hamer_out_left, hamer_out_right = hamer_helper.get_det_from_boxes( - undistorted_image, - l_existed, - r_existed, - l_box, - r_box, - focal_length=450, - ) - - if hamer_out_left is None: - detections_left_wrt_cam[timestamp_ns] = None - else: - detections_left_wrt_cam[timestamp_ns] = { - "verts": hamer_out_left["verts"], - "keypoints_3d": hamer_out_left["keypoints_3d"], - "mano_hand_pose": hamer_out_left["mano_hand_pose"], - "mano_hand_betas": hamer_out_left["mano_hand_betas"], - "mano_hand_global_orient": hamer_out_left["mano_hand_global_orient"], - } - - if hamer_out_right is None: - detections_right_wrt_cam[timestamp_ns] = None - else: - detections_right_wrt_cam[timestamp_ns] = { - "verts": hamer_out_right["verts"], - "keypoints_3d": hamer_out_right["keypoints_3d"], - "mano_hand_pose": hamer_out_right["mano_hand_pose"], - "mano_hand_betas": hamer_out_right["mano_hand_betas"], - "mano_hand_global_orient": hamer_out_right["mano_hand_global_orient"], - } - - composited = undistorted_image - composited = hamer_helper.composite_detections( - composited, - hamer_out_left, - border_color=(255, 100, 100), - focal_length=450, - ) - composited = hamer_helper.composite_detections( - composited, - hamer_out_right, - border_color=(100, 100, 255), - focal_length=450, - ) - composited = put_text( - composited, - "L detections: " - + ( - "0" if hamer_out_left is None else str(hamer_out_left["verts"].shape[0]) - ), - 0, - color=(255, 100, 100), - font_scale=10.0 / 2880.0 * undistorted_image.shape[0], - ) - composited = put_text( - composited, - "R detections: " - + ( - "0" - if hamer_out_right is None - else str(hamer_out_right["verts"].shape[0]) - ), - 1, - color=(100, 100, 255), - font_scale=10.0 / 2880.0 * undistorted_image.shape[0], - ) - composited = put_text( - composited, - f"ns={timestamp_ns}", - 2, - color=(255, 255, 255), - font_scale=10.0 / 2880.0 * undistorted_image.shape[0], - ) - - print(f"Saving image {i:06d} to {render_out_path / f'{i:06d}.jpeg'}") - # bbox on undistorted image - if l_existed: - min_l_p_x, min_l_p_y, max_l_p_x, max_l_p_y = l_box - max_l_p_x, min_l_p_x, max_l_p_y, min_l_p_y = int(max_l_p_x), int(min_l_p_x), int(max_l_p_y), int(min_l_p_y) - - cv2.rectangle(composited, (max_l_p_x, max_l_p_y), (min_l_p_x, min_l_p_y), (255, 100, 100),2) - if r_existed: - min_r_p_x, min_r_p_y, max_r_p_x, max_r_p_y = r_box - max_r_p_x, min_r_p_x, max_r_p_y, min_r_p_y = int(max_r_p_x), int(min_r_p_x), int(max_r_p_y), int(min_r_p_y) - - cv2.rectangle(composited, (max_r_p_x, max_r_p_y), (min_r_p_x, min_r_p_y), (100, 100, 255),2) - - - iio.imwrite( - str(render_out_path / f"{i:06d}.jpeg"), - np.concatenate( - [ - # Darken input image, just for contrast... - (undistorted_image * 0.6).astype(np.uint8), - composited, - ], - axis=1, - ), - quality=90, - ) - - outputs = SavedHamerOutputs( - mano_faces_right=hamer_helper.get_mano_faces("right"), - mano_faces_left=hamer_helper.get_mano_faces("left"), - detections_right_wrt_cam=detections_right_wrt_cam, - detections_left_wrt_cam=detections_left_wrt_cam, - T_device_cam=T_device_cam, - T_cpf_cam=T_cpf_cam, - ) - with open(pickle_out, "wb") as f: - pickle.dump(outputs, f) - - - def put_text( image: np.ndarray, text: str, diff --git a/3_aria_inference.py b/3_aria_inference.py index fe00975..578314e 100644 --- a/3_aria_inference.py +++ b/3_aria_inference.py @@ -39,7 +39,7 @@ class Args: ... """ detector: str = "hamer" - """for choosing pkl file. choose from ["hamer", "aria", "wilor"]""" + """for choosing pkl file. choose from ["hamer", "wilor"]""" checkpoint_dir: Path = Path("./egoallo_checkpoint_april13/checkpoints_3000000/") smplh_npz_path: Path = Path("./data/smplh/neutral/model.npz") @@ -150,14 +150,17 @@ def main(args: Args) -> None: # Save outputs in case we want to visualize later. if args.save_traj: save_name = ( - args.detector - + time.strftime("%Y%m%d-%H%M%S") + time.strftime("%Y%m%d-%H%M%S") + f"_{args.start_index}-{args.start_index + args.traj_length}" ) - out_path = args.traj_root / "egoallo_outputs" / (save_name + ".npz") + if args.detector == "hamer": + egoallo_dir = "egoallo_outputs" + elif args.detector == "wilor": + egoallo_dir = "wilor_egoallo_outputs" + out_path = args.traj_root / egoallo_dir / (save_name + ".npz") out_path.parent.mkdir(parents=True, exist_ok=True) assert not out_path.exists() - (args.traj_root / "egoallo_outputs" / (save_name + "_args.yaml")).write_text( + (args.traj_root / egoallo_dir / (save_name + "_args.yaml")).write_text( yaml.dump(dataclasses.asdict(args)) ) diff --git a/4_visualize_outputs.py b/4_visualize_outputs.py index 2480e8c..d76e4da 100644 --- a/4_visualize_outputs.py +++ b/4_visualize_outputs.py @@ -31,6 +31,7 @@ def main( search_root_dir: Path, + detector: str = "hamer", smplh_npz_path: Path = Path("./data/smplh/neutral/model.npz"), ) -> None: """Visualization script for outputs from EgoAllo. @@ -48,9 +49,13 @@ def main( server.gui.configure_theme(dark_mode=True) def get_file_list(): + if detector == "hamer": + egoallo_outputs_dir = search_root_dir.glob("**/egoallo_outputs/*.npz") + elif detector == "wilor": + egoallo_outputs_dir = search_root_dir.glob("**/wilor_egoallo_outputs/*.npz") return ["None"] + sorted( str(p.relative_to(search_root_dir)) - for p in search_root_dir.glob("**/egoallo_outputs/*.npz") + for p in egoallo_outputs_dir ) options = get_file_list() @@ -86,6 +91,7 @@ def _(_) -> None: loop_cb = load_and_visualize( server, npz_path, + detector, body_model, device=device, ) @@ -100,6 +106,7 @@ def _(_) -> None: def load_and_visualize( server: viser.ViserServer, npz_path: Path, + detector: str, body_model: fncsmpl.SmplhModel, device: torch.device, ) -> Callable[[], int]: @@ -137,7 +144,7 @@ def load_and_visualize( # - outputs # - the npz file traj_dir = npz_path.resolve().parent.parent - paths = InferenceTrajectoryPaths.find(traj_dir) + paths = InferenceTrajectoryPaths.find(traj_dir,detector=detector) provider = create_vrs_data_provider(str(paths.vrs_file)) device_calib = provider.get_device_calibration() diff --git a/src/egoallo/inference_utils.py b/src/egoallo/inference_utils.py index e3b69aa..3975d26 100644 --- a/src/egoallo/inference_utils.py +++ b/src/egoallo/inference_utils.py @@ -65,7 +65,6 @@ class InferenceTrajectoryPaths: def find(traj_root: Path, detector: str="hamer") -> InferenceTrajectoryPaths: vrs_files = tuple(traj_root.glob("**/*.vrs")) assert len(vrs_files) == 1, f"Found {len(vrs_files)} VRS files!" - points_paths = tuple(traj_root.glob("**/semidense_points.csv.gz")) assert len(points_paths) <= 1, f"Found multiple points files! {points_paths}" if len(points_paths) == 0: From 33c982eaf0e2d5cd5aa5652590cc9612b75ecbc7 Mon Sep 17 00:00:00 2001 From: woodenbirds <1979309725@qq.com> Date: Tue, 19 Nov 2024 12:23:10 -0800 Subject: [PATCH 6/8] add --detector wilor to script2,3,4 --- 2_run_hamer_on_vrs.py | 22 +- _hamer_helper.py | 834 --------------------------------- src/egoallo/inference_utils.py | 1 + 3 files changed, 11 insertions(+), 846 deletions(-) delete mode 100644 _hamer_helper.py diff --git a/2_run_hamer_on_vrs.py b/2_run_hamer_on_vrs.py index 3bdba19..f03e690 100644 --- a/2_run_hamer_on_vrs.py +++ b/2_run_hamer_on_vrs.py @@ -12,15 +12,12 @@ SavedHamerOutputs, SingleHandHamerOutputWrtCamera, ) - +from hamer_helper import HamerHelper from projectaria_tools.core import calibration from projectaria_tools.core.data_provider import ( VrsDataProvider, create_vrs_data_provider, ) -from _hamer_helper import HamerHelper -import projectaria_tools.core.mps as mps -from projectaria_tools.core.sensor_data import DEVICE_TIME from tqdm.auto import tqdm from egoallo.inference_utils import InferenceTrajectoryPaths @@ -40,32 +37,33 @@ def main(traj_root: Path, detector: str = "hamer",overwrite: bool = False, wilor """ paths = InferenceTrajectoryPaths.find(traj_root) + vrs_path = paths.vrs_file assert vrs_path.exists() if detector == "wilor": - render_out_path = traj_root / "wilor_outputs_unrot_render" # This is just for debugging. pickle_out = traj_root / "wilor_outputs.pkl" + render_out_path = traj_root / "wilor_outputs_unrot_render" # This is just for debugging. run_wilor_and_save(vrs_path, pickle_out, render_out_path, overwrite, wilor_home) elif detector == "hamer": - render_out_path = traj_root / "hamer_outputs_render" # This is just for debugging. pickle_out = traj_root / "hamer_outputs.pkl" + render_out_path = traj_root / "hamer_outputs_render" # This is just for debugging. run_hamer_and_save(vrs_path, pickle_out, render_out_path, overwrite) else: raise ValueError(f"Unknown detector: {detector}") def run_hamer_and_save( - vrs_path: Path, pickle_out: Path, render_out_path: Path, overwrite: bool + vrs_path: Path, pickle_out: Path, hamer_render_out: Path, overwrite: bool ) -> None: if not overwrite: assert not pickle_out.exists() - assert not render_out_path.exists() + assert not hamer_render_out.exists() else: pickle_out.unlink(missing_ok=True) - shutil.rmtree(render_out_path, ignore_errors=True) + shutil.rmtree(hamer_render_out, ignore_errors=True) - render_out_path.mkdir(exist_ok=True) + hamer_render_out.mkdir(exist_ok=True) hamer_helper = HamerHelper() # VRS data provider setup. @@ -187,9 +185,9 @@ def run_hamer_and_save( font_scale=10.0 / 2880.0 * undistorted_image.shape[0], ) - print(f"Saving image {i:06d} to {render_out_path / f'{i:06d}.jpeg'}") + print(f"Saving image {i:06d} to {hamer_render_out / f'{i:06d}.jpeg'}") iio.imwrite( - str(render_out_path / f"{i:06d}.jpeg"), + str(hamer_render_out / f"{i:06d}.jpeg"), np.concatenate( [ # Darken input image, just for contrast... diff --git a/_hamer_helper.py b/_hamer_helper.py deleted file mode 100644 index 5a3c16b..0000000 --- a/_hamer_helper.py +++ /dev/null @@ -1,834 +0,0 @@ -import contextlib -import os -import time -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Generator, Literal, TypedDict - -import imageio.v3 as iio -import numpy as np -import torch -import torch.utils.data -from hamer.datasets.vitdet_dataset import DEFAULT_MEAN, DEFAULT_STD, ViTDetDataset -from hamer.utils.mesh_renderer import create_raymond_lights -from hamer.utils.renderer import Renderer, cam_crop_to_full -from jaxtyping import Float, Int -from scipy.ndimage import binary_dilation -from torch import Tensor - - -class HandOutputsWrtCamera(TypedDict): - """Hand outputs with respect to the camera frame.""" - - verts: Float[np.ndarray, "num_hands 778 3"] - keypoints_3d: Float[np.ndarray, "num_hands 21 3"] - mano_hand_pose: Float[np.ndarray, "num_hands 15 3 3"] - mano_hand_betas: Float[np.ndarray, "num_hands 10"] - mano_hand_global_orient: Float[np.ndarray, "num_hands 1 3 3"] - faces: Int[np.ndarray, "mesh_faces 3"] - - -@contextlib.contextmanager -def _stopwatch(message: str): - print("[STOPWATCH]", message) - start = time.time() - yield - print("[STOPWATCH]", message, f"finished in {time.time() - start} seconds!") - - -@dataclass(frozen=True) -class _RawHamerOutputs: - """A typed wrapper for outputs from HaMeR.""" - - # Comments here are what I got when printing out the shapes of different - # HaMeR outputs. - - # pred_cam torch.Size([1, 3]) - pred_cam: Float[Tensor, "num_hands 3"] - # pred_mano_params global_orient torch.Size([1, 1, 3, 3]) - pred_mano_global_orient: Float[Tensor, "num_hands 1 3 3"] - # pred_mano_params hand_pose torch.Size([1, 15, 3, 3]) - pred_mano_hand_pose: Float[Tensor, "num_hands 15 3 3"] - # pred_mano_params betas torch.Size([1, 10]) - pred_mano_hand_betas: Float[Tensor, "num_hands 10"] - # pred_cam_t torch.Size([1, 3]) - pred_cam_t: Float[Tensor, "num_hands 3"] - - # focal length from model is ignored - # focal_length torch.Size([1, 2]) - # focal_length: Float[Tensor, "num_hands 2"] - - # pred_keypoints_3d torch.Size([1, 21, 3]) - pred_keypoints_3d: Float[Tensor, "num_hands 21 3"] - # pred_vertices torch.Size([1, 778, 3]) - pred_vertices: Float[Tensor, "num_hands 778 3"] - # pred_keypoints_2d torch.Size([1, 21, 2]) - pred_keypoints_2d: Float[Tensor, "num_hands 21 3"] - - pred_right: Float[Tensor, "num_hands"] - """A given hand is a right hand if this value is >0.5.""" - - # These aren't technically HaMeR outputs, but putting them here for convenience. - mano_faces_right: Tensor - mano_faces_left: Tensor - - -@contextlib.contextmanager -def temporary_cwd_context(x: Path) -> Generator[None, None, None]: - """Temporarily change our working directory.""" - d = os.getcwd() - os.chdir(x) - try: - yield - finally: - os.chdir(d) - - -class HamerHelper: - """Helper class for running HaMeR. Adapted from HaMeR demo script.""" - - def __init__(self,other_detector=False) -> None: - import hamer - from hamer.models import DEFAULT_CHECKPOINT, load_hamer - from vitpose_model import ViTPoseModel - - # HaMeR hardcodes a bunch of relative paths... - # Instead of modifying HaMeR we're going to hack this by temporarily changing our working directory :) - hamer_directory = Path(hamer.__file__).parent.parent - - with temporary_cwd_context(hamer_directory): - # Download and load checkpoints - # download_models(Path(hamer.__file__).parent.parent /CACHE_DIR_HAMER) - with _stopwatch("Loading HaMeR model..."): - model, model_cfg = load_hamer( - str(Path(hamer.__file__).parent.parent / DEFAULT_CHECKPOINT) - ) - - # Setup HaMeR model - with _stopwatch("Configuring HaMeR model..."): - device = torch.device("cuda") - model = model.to(device) - model.eval() - - # Load detector - import hamer - - if other_detector: - self._detector = None - else: - from detectron2.config import LazyConfig - from hamer.utils.utils_detectron2 import DefaultPredictor_Lazy - - with _stopwatch("Creating Detectron2 predictor..."): - cfg_path = ( - Path(hamer.__file__).parent - / "configs" - / "cascade_mask_rcnn_vitdet_h_75ep.py" - ) - detectron2_cfg = LazyConfig.load(str(cfg_path)) - detectron2_cfg.train.init_checkpoint = "https://dl.fbaipublicfiles.com/detectron2/ViTDet/COCO/cascade_mask_rcnn_vitdet_h/f328730692/model_final_f05665.pkl" # type: ignore - for i in range(3): - detectron2_cfg.model.roi_heads.box_predictors[ # type: ignore - i - ].test_score_thresh = 0.25 - detector = DefaultPredictor_Lazy(detectron2_cfg) - self._detector = detector - - # keypoint detector - with _stopwatch("Creating ViT pose model..."): - cpm = ViTPoseModel(device) - - self._model = model - self._model_cfg = model_cfg - self._cpm = cpm - self.device = device - - print("#" * 80) - print("#" * 80) - print("#" * 80) - print( - "Done setting up HaMeR! There were probably lots of errors, including a scary gigantic one about state dict stuff, but it's probably fine!" - ) - print("#" * 80) - print("#" * 80) - print("#" * 80) - - def get_default_focal_length(self, h: int, w: int) -> float: - """Get the default focal length for a given image size. - - This is how the HaMeR demo script computes the focal length... I don't - have a clear sense of the significance. We could ask George. - """ - return ( - self._model_cfg.EXTRA.FOCAL_LENGTH - / self._model_cfg.MODEL.IMAGE_SIZE - * max(h, w) - ) - - def get_det_from_boxes( - self, - image: Int[np.ndarray, "height width 3"], - ldetected: bool, - rdetected: bool, - l_box: np.ndarray, - r_box: np.ndarray, - focal_length: float | None = None, - rescale_factor: float = 2.0, - render_output_dir_for_testing: Path | None = None, - render_output_prefix_for_testing: str = "", - ) -> tuple[HandOutputsWrtCamera | None, HandOutputsWrtCamera | None]: - assert image.shape[-1] == 3 - - # image must be `np.uint8`, and in range [0, 255]. - assert image.dtype == np.uint8 - - # # Detectron expects BGR image. - # det_out = self._detector(image[:, :, ::-1]) - # det_instances = det_out["instances"] - # valid_idx = (det_instances.pred_classes == 0) & (det_instances.scores > 0.5) - # pred_bboxes = det_instances.pred_boxes.tensor[valid_idx].cpu().numpy() - # pred_scores = det_instances.scores[valid_idx].cpu().numpy() - - - - # # Detect human keypoints for each person - # vitposes_out = self._cpm.predict_pose( - # image, - # [np.concatenate([pred_bboxes, pred_scores[:, None]], axis=1)], - # ) - - bboxes = [] - is_right = [] - - # # Use hands based on hand keypoint detections - # for vitposes in vitposes_out: - # left_hand_keyp = vitposes["keypoints"][-42:-21] - # right_hand_keyp = vitposes["keypoints"][-21:] - - # lbbox = None - # rbbox = None - - # # Rejecting not confident detections - # ldetect = rdetect = False - # keyp = left_hand_keyp - # valid = keyp[:, 2] > 0.5 - # if sum(valid) > 3: - # lbbox = [ - # keyp[valid, 0].min(), - # keyp[valid, 1].min(), - # keyp[valid, 0].max(), - # keyp[valid, 1].max(), - # ] - # ldetect = True - # keyp = right_hand_keyp - # valid = keyp[:, 2] > 0.5 - # if sum(valid) > 3: - # rbbox = [ - # keyp[valid, 0].min(), - # keyp[valid, 1].min(), - # keyp[valid, 0].max(), - # keyp[valid, 1].max(), - # ] - # rdetect = True - - # # suppressing - # if ldetect == True and rdetect == True: - # bboxes_dims = [ - # left_hand_keyp[:, 0].max() - left_hand_keyp[:, 0].min(), - # left_hand_keyp[:, 1].max() - left_hand_keyp[:, 1].min(), - # right_hand_keyp[:, 0].max() - right_hand_keyp[:, 0].min(), - # right_hand_keyp[:, 1].max() - right_hand_keyp[:, 1].min(), - # ] - # norm_side = max(bboxes_dims) - # keyp_dist = ( - # np.sqrt( - # np.sum( - # (right_hand_keyp[:, :2] - left_hand_keyp[:, :2]) ** 2, - # axis=1, - # ) - # ) - # / norm_side - # ) - # if np.mean(keyp_dist) < 0.5: - # if left_hand_keyp[0, 2] - right_hand_keyp[0, 2] > 0: - # assert lbbox is not None - # bboxes.append(lbbox) - # is_right.append(0) - # else: - # assert rbbox is not None - # bboxes.append(rbbox) - # is_right.append(1) - # else: - # assert lbbox is not None - # assert rbbox is not None - # bboxes.append(lbbox) - # is_right.append(0) - # bboxes.append(rbbox) - # is_right.append(1) - # elif ldetect == True: - # assert lbbox is not None - # bboxes.append(lbbox) - # is_right.append(0) - # elif rdetect == True: - # assert rbbox is not None - # bboxes.append(rbbox) - # is_right.append(1) - - # if len(bboxes) == 0: - # return None, None - - # boxes = np.stack(bboxes) - # right = np.stack(is_right) - - if ldetected: - bboxes.append(l_box) - is_right.append(0) - if rdetected: - bboxes.append(r_box) - is_right.append(1) - if len(bboxes) == 0: - return None, None - boxes = np.stack(bboxes) - right = np.stack(is_right) - - dataset = ViTDetDataset( - self._model_cfg, - # HaMeR expects BGR. - image[:, :, ::-1], - boxes, - right, - rescale_factor=rescale_factor, - ) - - # ViT detector will give us multiple detections. We want to run HaMeR - # on each. - dataloader = torch.utils.data.DataLoader( - dataset, batch_size=8, shuffle=False, num_workers=0 - ) - outputs: list[_RawHamerOutputs] = [] - from hamer.utils import recursive_to - - for batch in dataloader: - batch: Any = recursive_to(batch, self.device) - with torch.no_grad(): - out = self._model.forward(batch) - - multiplier = 2 * batch["right"] - 1 - pred_cam = out["pred_cam"] - pred_cam[:, 1] = multiplier * pred_cam[:, 1] - box_center = batch["box_center"].float() - box_size = batch["box_size"].float() - img_size = batch["img_size"].float() - multiplier = 2 * batch["right"] - 1 - - if focal_length is None: - # All of the img_size rows should be the same. I think. - focal_length = float( - self.get_default_focal_length( - img_size[0, 0].item(), img_size[0, 1].item() - ) - ) - if isinstance(focal_length, int): - focal_length = float(focal_length) - assert isinstance(focal_length, float) - scaled_focal_length = focal_length - - pred_cam_t_full = cam_crop_to_full( - pred_cam, box_center, box_size, img_size, scaled_focal_length - ) - hamer_out = _RawHamerOutputs( - mano_faces_left=torch.from_numpy( - self._model.mano.faces[:, [0, 2, 1]].astype(np.int64) - ).to(device=self.device), - mano_faces_right=torch.from_numpy( - self._model.mano.faces.astype(np.int64) - ).to(device=self.device), - pred_cam=out["pred_cam"], - pred_mano_global_orient=out["pred_mano_params"]["global_orient"], - pred_mano_hand_pose=out["pred_mano_params"]["hand_pose"], - pred_mano_hand_betas=out["pred_mano_params"]["betas"], - pred_cam_t=pred_cam_t_full, - pred_keypoints_3d=out["pred_keypoints_3d"], - pred_vertices=out["pred_vertices"], - pred_keypoints_2d=out["pred_keypoints_2d"], - pred_right=batch["right"], - ) - - outputs.append(hamer_out) - - # Render the result. - if render_output_dir_for_testing: - renderer = Renderer(self._model_cfg, faces=self._model.mano.faces) - batch_size = batch["img"].shape[0] - for n in range(batch_size): - # Get filename from path img_path - person_id = int(batch["personid"][n]) - white_img = ( - torch.ones_like(batch["img"][n]).cpu() - - DEFAULT_MEAN[:, None, None] / 255 - ) / (DEFAULT_STD[:, None, None] / 255) - input_patch = batch["img"][n].cpu() * ( - DEFAULT_STD[:, None, None] / 255 - ) + (DEFAULT_MEAN[:, None, None] / 255) - input_patch = input_patch.permute(1, 2, 0).numpy() - - LIGHT_BLUE = (0.65098039, 0.74117647, 0.85882353) - regression_img = renderer( - out["pred_vertices"][n].detach().cpu().numpy(), - out["pred_cam_t"][n].detach().cpu().numpy(), - batch["img"][n], - mesh_base_color=LIGHT_BLUE, - scene_bg_color=(1, 1, 1), - ) - - final_img = np.concatenate([input_patch, regression_img], axis=1) - - image_path = ( - render_output_dir_for_testing - / f"{render_output_prefix_for_testing}_hamer_{person_id}.png" - ) - print(f"Writing to {image_path}") - render_output_dir_for_testing.mkdir(exist_ok=True, parents=True) - iio.imwrite(image_path, (255 * final_img).astype(np.uint8)) - - # Add all verts and cams to list - verts = out["pred_vertices"][n].detach().cpu().numpy() - is_right = batch["right"][n].cpu().numpy() - verts[:, 0] = (2 * is_right - 1) * verts[:, 0] - - assert len(outputs) > 0 - stacked_outputs = _RawHamerOutputs( - **{ - field_name: torch.cat([getattr(x, field_name) for x in outputs], dim=0) - for field_name in vars(outputs[0]).keys() - }, - ) - # begin new brent stuff - verts = stacked_outputs.pred_vertices.numpy(force=True) - keypoints_3d = stacked_outputs.pred_keypoints_3d.numpy(force=True) - pred_cam_t = stacked_outputs.pred_cam_t.numpy(force=True) - mano_hand_pose = stacked_outputs.pred_mano_hand_pose.numpy(force=True) - mano_hand_betas = stacked_outputs.pred_mano_hand_betas.numpy(force=True) - R_camera_hand = stacked_outputs.pred_mano_global_orient.squeeze(dim=1).numpy( - force=True - ) - - is_right = (stacked_outputs.pred_right > 0.5).numpy(force=True) - is_left = ~is_right - - detections_right_wrt_cam: HandOutputsWrtCamera | None - if np.sum(is_right) == 0: - detections_right_wrt_cam = None - else: - detections_right_wrt_cam = { - "verts": verts[is_right] + pred_cam_t[is_right, None, :], - "keypoints_3d": keypoints_3d[is_right] + pred_cam_t[is_right, None, :], - "mano_hand_pose": mano_hand_pose[is_right], - "mano_hand_betas": mano_hand_betas[is_right], - "mano_hand_global_orient": R_camera_hand[is_right], - "faces": self.get_mano_faces("right"), - } - - detections_left_wrt_cam: HandOutputsWrtCamera | None - if np.sum(is_left) == 0: - detections_left_wrt_cam = None - else: - - def flip_rotmats(rotmats: np.ndarray) -> np.ndarray: - assert rotmats.shape[-2:] == (3, 3) - from viser import transforms - - logspace = transforms.SO3.from_matrix(rotmats).log() - logspace[..., 1] *= -1 - logspace[..., 2] *= -1 - return transforms.SO3.exp(logspace).as_matrix() - - detections_left_wrt_cam = { - "verts": verts[is_left] * np.array([-1, 1, 1]) - + pred_cam_t[is_left, None, :], - "keypoints_3d": keypoints_3d[is_left] * np.array([-1, 1, 1]) - + pred_cam_t[is_left, None, :], - "mano_hand_pose": flip_rotmats(mano_hand_pose[is_left]), - "mano_hand_betas": mano_hand_betas[is_left], - "mano_hand_global_orient": flip_rotmats(R_camera_hand[is_left]), - "faces": self.get_mano_faces("left"), - } - # end new brent stuff - return detections_left_wrt_cam, detections_right_wrt_cam - - - def look_for_hands( - self, - image: Int[np.ndarray, "height width 3"], - focal_length: float | None = None, - rescale_factor: float = 2.0, - render_output_dir_for_testing: Path | None = None, - render_output_prefix_for_testing: str = "", - ) -> tuple[HandOutputsWrtCamera | None, HandOutputsWrtCamera | None]: - """Look for hands. - - Arguments: - image: Image to look for hands in. Expects uint8, in range [0, 255]. - focal_length: Focal length of camera, used for 3D coordinates. - rescale_factor: Rescale factor for running ViT detector. I think 2 is fine, probably. - render_output_dir: Directory to render out detections to. Mostly this is used for testing. Doesn't do any rendering - """ - assert image.shape[-1] == 3 - - # image must be `np.uint8`, and in range [0, 255]. - assert image.dtype == np.uint8 - - # Detectron expects BGR image. - det_out = self._detector(image[:, :, ::-1]) - det_instances = det_out["instances"] - valid_idx = (det_instances.pred_classes == 0) & (det_instances.scores > 0.5) - pred_bboxes = det_instances.pred_boxes.tensor[valid_idx].cpu().numpy() - pred_scores = det_instances.scores[valid_idx].cpu().numpy() - - # Detect human keypoints for each person - vitposes_out = self._cpm.predict_pose( - image, - [np.concatenate([pred_bboxes, pred_scores[:, None]], axis=1)], - ) - - bboxes = [] - is_right = [] - - # Use hands based on hand keypoint detections - for vitposes in vitposes_out: - left_hand_keyp = vitposes["keypoints"][-42:-21] - right_hand_keyp = vitposes["keypoints"][-21:] - - lbbox = None - rbbox = None - - # Rejecting not confident detections - ldetect = rdetect = False - keyp = left_hand_keyp - valid = keyp[:, 2] > 0.5 - if sum(valid) > 3: - lbbox = [ - keyp[valid, 0].min(), - keyp[valid, 1].min(), - keyp[valid, 0].max(), - keyp[valid, 1].max(), - ] - ldetect = True - keyp = right_hand_keyp - valid = keyp[:, 2] > 0.5 - if sum(valid) > 3: - rbbox = [ - keyp[valid, 0].min(), - keyp[valid, 1].min(), - keyp[valid, 0].max(), - keyp[valid, 1].max(), - ] - rdetect = True - - # suppressing - if ldetect == True and rdetect == True: - bboxes_dims = [ - left_hand_keyp[:, 0].max() - left_hand_keyp[:, 0].min(), - left_hand_keyp[:, 1].max() - left_hand_keyp[:, 1].min(), - right_hand_keyp[:, 0].max() - right_hand_keyp[:, 0].min(), - right_hand_keyp[:, 1].max() - right_hand_keyp[:, 1].min(), - ] - norm_side = max(bboxes_dims) - keyp_dist = ( - np.sqrt( - np.sum( - (right_hand_keyp[:, :2] - left_hand_keyp[:, :2]) ** 2, - axis=1, - ) - ) - / norm_side - ) - if np.mean(keyp_dist) < 0.5: - if left_hand_keyp[0, 2] - right_hand_keyp[0, 2] > 0: - assert lbbox is not None - bboxes.append(lbbox) - is_right.append(0) - else: - assert rbbox is not None - bboxes.append(rbbox) - is_right.append(1) - else: - assert lbbox is not None - assert rbbox is not None - bboxes.append(lbbox) - is_right.append(0) - bboxes.append(rbbox) - is_right.append(1) - elif ldetect == True: - assert lbbox is not None - bboxes.append(lbbox) - is_right.append(0) - elif rdetect == True: - assert rbbox is not None - bboxes.append(rbbox) - is_right.append(1) - - if len(bboxes) == 0: - return None, None - - boxes = np.stack(bboxes) - right = np.stack(is_right) - - dataset = ViTDetDataset( - self._model_cfg, - # HaMeR expects BGR. - image[:, :, ::-1], - boxes, - right, - rescale_factor=rescale_factor, - ) - - # ViT detector will give us multiple detections. We want to run HaMeR - # on each. - dataloader = torch.utils.data.DataLoader( - dataset, batch_size=8, shuffle=False, num_workers=0 - ) - outputs: list[_RawHamerOutputs] = [] - from hamer.utils import recursive_to - - for batch in dataloader: - batch: Any = recursive_to(batch, self.device) - with torch.no_grad(): - out = self._model.forward(batch) - - multiplier = 2 * batch["right"] - 1 - pred_cam = out["pred_cam"] - pred_cam[:, 1] = multiplier * pred_cam[:, 1] - box_center = batch["box_center"].float() - box_size = batch["box_size"].float() - img_size = batch["img_size"].float() - multiplier = 2 * batch["right"] - 1 - - if focal_length is None: - # All of the img_size rows should be the same. I think. - focal_length = float( - self.get_default_focal_length( - img_size[0, 0].item(), img_size[0, 1].item() - ) - ) - if isinstance(focal_length, int): - focal_length = float(focal_length) - assert isinstance(focal_length, float) - scaled_focal_length = focal_length - - pred_cam_t_full = cam_crop_to_full( - pred_cam, box_center, box_size, img_size, scaled_focal_length - ) - hamer_out = _RawHamerOutputs( - mano_faces_left=torch.from_numpy( - self._model.mano.faces[:, [0, 2, 1]].astype(np.int64) - ).to(device=self.device), - mano_faces_right=torch.from_numpy( - self._model.mano.faces.astype(np.int64) - ).to(device=self.device), - pred_cam=out["pred_cam"], - pred_mano_global_orient=out["pred_mano_params"]["global_orient"], - pred_mano_hand_pose=out["pred_mano_params"]["hand_pose"], - pred_mano_hand_betas=out["pred_mano_params"]["betas"], - pred_cam_t=pred_cam_t_full, - pred_keypoints_3d=out["pred_keypoints_3d"], - pred_vertices=out["pred_vertices"], - pred_keypoints_2d=out["pred_keypoints_2d"], - pred_right=batch["right"], - ) - - outputs.append(hamer_out) - - # Render the result. - if render_output_dir_for_testing: - renderer = Renderer(self._model_cfg, faces=self._model.mano.faces) - batch_size = batch["img"].shape[0] - for n in range(batch_size): - # Get filename from path img_path - person_id = int(batch["personid"][n]) - white_img = ( - torch.ones_like(batch["img"][n]).cpu() - - DEFAULT_MEAN[:, None, None] / 255 - ) / (DEFAULT_STD[:, None, None] / 255) - input_patch = batch["img"][n].cpu() * ( - DEFAULT_STD[:, None, None] / 255 - ) + (DEFAULT_MEAN[:, None, None] / 255) - input_patch = input_patch.permute(1, 2, 0).numpy() - - LIGHT_BLUE = (0.65098039, 0.74117647, 0.85882353) - regression_img = renderer( - out["pred_vertices"][n].detach().cpu().numpy(), - out["pred_cam_t"][n].detach().cpu().numpy(), - batch["img"][n], - mesh_base_color=LIGHT_BLUE, - scene_bg_color=(1, 1, 1), - ) - - final_img = np.concatenate([input_patch, regression_img], axis=1) - - image_path = ( - render_output_dir_for_testing - / f"{render_output_prefix_for_testing}_hamer_{person_id}.png" - ) - print(f"Writing to {image_path}") - render_output_dir_for_testing.mkdir(exist_ok=True, parents=True) - iio.imwrite(image_path, (255 * final_img).astype(np.uint8)) - - # Add all verts and cams to list - verts = out["pred_vertices"][n].detach().cpu().numpy() - is_right = batch["right"][n].cpu().numpy() - verts[:, 0] = (2 * is_right - 1) * verts[:, 0] - - assert len(outputs) > 0 - stacked_outputs = _RawHamerOutputs( - **{ - field_name: torch.cat([getattr(x, field_name) for x in outputs], dim=0) - for field_name in vars(outputs[0]).keys() - }, - ) - # begin new brent stuff - verts = stacked_outputs.pred_vertices.numpy(force=True) - keypoints_3d = stacked_outputs.pred_keypoints_3d.numpy(force=True) - pred_cam_t = stacked_outputs.pred_cam_t.numpy(force=True) - mano_hand_pose = stacked_outputs.pred_mano_hand_pose.numpy(force=True) - mano_hand_betas = stacked_outputs.pred_mano_hand_betas.numpy(force=True) - R_camera_hand = stacked_outputs.pred_mano_global_orient.squeeze(dim=1).numpy( - force=True - ) - - is_right = (stacked_outputs.pred_right > 0.5).numpy(force=True) - is_left = ~is_right - - detections_right_wrt_cam: HandOutputsWrtCamera | None - if np.sum(is_right) == 0: - detections_right_wrt_cam = None - else: - detections_right_wrt_cam = { - "verts": verts[is_right] + pred_cam_t[is_right, None, :], - "keypoints_3d": keypoints_3d[is_right] + pred_cam_t[is_right, None, :], - "mano_hand_pose": mano_hand_pose[is_right], - "mano_hand_betas": mano_hand_betas[is_right], - "mano_hand_global_orient": R_camera_hand[is_right], - "faces": self.get_mano_faces("right"), - } - - detections_left_wrt_cam: HandOutputsWrtCamera | None - if np.sum(is_left) == 0: - detections_left_wrt_cam = None - else: - - def flip_rotmats(rotmats: np.ndarray) -> np.ndarray: - assert rotmats.shape[-2:] == (3, 3) - from viser import transforms - - logspace = transforms.SO3.from_matrix(rotmats).log() - logspace[..., 1] *= -1 - logspace[..., 2] *= -1 - return transforms.SO3.exp(logspace).as_matrix() - - detections_left_wrt_cam = { - "verts": verts[is_left] * np.array([-1, 1, 1]) - + pred_cam_t[is_left, None, :], - "keypoints_3d": keypoints_3d[is_left] * np.array([-1, 1, 1]) - + pred_cam_t[is_left, None, :], - "mano_hand_pose": flip_rotmats(mano_hand_pose[is_left]), - "mano_hand_betas": mano_hand_betas[is_left], - "mano_hand_global_orient": flip_rotmats(R_camera_hand[is_left]), - "faces": self.get_mano_faces("left"), - } - # end new brent stuff - return detections_left_wrt_cam, detections_right_wrt_cam - - def get_mano_faces(self, side: Literal["left", "right"]) -> np.ndarray: - if side == "left": - return self._model.mano.faces[:, [0, 2, 1]].copy() - else: - return self._model.mano.faces.copy() - - def render_detection( - self, - output_dict: HandOutputsWrtCamera, - hand_index: int, - h: int, - w: int, - focal_length: float | None = None, - ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """Render to a tuple of (RGB, depth, mask). For testing.""" - import pyrender - import trimesh - - if focal_length is None: - focal_length = self.get_default_focal_length(h, w) - - render_res = (h, w) - renderer = pyrender.OffscreenRenderer( - viewport_width=render_res[1], viewport_height=render_res[0], point_size=1.0 - ) - material = pyrender.MetallicRoughnessMaterial( - metallicFactor=0.0, alphaMode="OPAQUE", baseColorFactor=(1.0, 1.0, 0.9, 1.0) - ) - - vertices = output_dict["verts"][hand_index] - faces = output_dict["faces"] - - mesh = trimesh.Trimesh(vertices.copy(), faces.copy()) - mesh = pyrender.Mesh.from_trimesh(mesh, material=material) - - scene = pyrender.Scene( - bg_color=[1.0, 1.0, 1.0, 0.0], ambient_light=(0.3, 0.3, 0.3) - ) - scene.add(mesh, "mesh") - - camera_center = [render_res[1] / 2.0, render_res[0] / 2.0] - camera = pyrender.IntrinsicsCamera( - fx=focal_length, - fy=focal_length, - cx=camera_center[0], - cy=camera_center[1], - zfar=1e12, - znear=0.001, - ) - - light_nodes = create_raymond_lights() - for node in light_nodes: - scene.add_node(node) - - # Create camera node and add it to pyRender scene - camera_pose = np.eye(4) - camera_pose[1:3, :] *= -1 # flip the y and z axes to match opengl - camera_node = pyrender.Node(camera=camera, matrix=camera_pose) - scene.add_node(camera_node) - - color, rend_depth = renderer.render(scene, flags=pyrender.RenderFlags.RGBA) # type: ignore - mask = color[..., -1] > 0 - return color[..., :3], rend_depth, mask - - def composite_detections( - self, - image: np.ndarray, - detections: HandOutputsWrtCamera | None, - border_color: tuple[int, int, int], - focal_length: float | None = None, - ) -> np.ndarray: - """Render some hand detections on top of an image. Returns an updated image.""" - if detections is None: - return image - - h, w = image.shape[:2] - - for index in range(detections["verts"].shape[0]): - print(index) - render_rgb, _, render_mask = self.render_detection( - detections, hand_index=index, h=h, w=w, focal_length=focal_length - ) - border_width = 15 - image = np.where( - binary_dilation( - render_mask, np.ones((border_width, border_width), dtype=bool) - )[:, :, None], - np.zeros_like(render_rgb) + np.array(border_color, dtype=np.uint8), - image, - ) - image = np.where(render_mask[:, :, None], render_rgb, image) - - return image diff --git a/src/egoallo/inference_utils.py b/src/egoallo/inference_utils.py index 3975d26..e3b69aa 100644 --- a/src/egoallo/inference_utils.py +++ b/src/egoallo/inference_utils.py @@ -65,6 +65,7 @@ class InferenceTrajectoryPaths: def find(traj_root: Path, detector: str="hamer") -> InferenceTrajectoryPaths: vrs_files = tuple(traj_root.glob("**/*.vrs")) assert len(vrs_files) == 1, f"Found {len(vrs_files)} VRS files!" + points_paths = tuple(traj_root.glob("**/semidense_points.csv.gz")) assert len(points_paths) <= 1, f"Found multiple points files! {points_paths}" if len(points_paths) == 0: From fc0f804f42a0deb8848fc8e909517fc441b80e70 Mon Sep 17 00:00:00 2001 From: woodenbirds <1979309725@qq.com> Date: Tue, 19 Nov 2024 13:53:41 -0800 Subject: [PATCH 7/8] fix wilor_home --- _wilor_helper.py | 46 ++++++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/_wilor_helper.py b/_wilor_helper.py index 2ae7143..f0dab6c 100644 --- a/_wilor_helper.py +++ b/_wilor_helper.py @@ -6,14 +6,12 @@ import numpy as np from dataclasses import dataclass from torch import Tensor -from typing import Any, Generator, Literal, TypedDict - -from wilor.models import load_wilor -from wilor.utils import recursive_to -from wilor.datasets.vitdet_dataset import ViTDetDataset -from wilor.utils.renderer import Renderer, cam_crop_to_full +from typing import Literal, TypedDict from ultralytics import YOLO from jaxtyping import Float, Int + +import sys + LIGHT_PURPLE=(0.25098039, 0.274117647, 0.65882353) @@ -67,16 +65,32 @@ class _RawHamerOutputs: class WiLoRHelper: def __init__(self, wilor_home: str="./"): - model, model_cfg = load_wilor(wilor_home, checkpoint_path = './pretrained_models/wilor_final.ckpt' , - cfg_path= './pretrained_models/model_config.yaml') - detector = YOLO(os.path.join(wilor_home, 'pretrained_models/detector.pt')) - # Setup the renderer - self.renderer = Renderer(model_cfg, faces=model.mano.faces) - - self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') - self._model = model.to(self.device) - self._model_cfg = model_cfg - self._detector = detector.to(self.device) + sys.path.append(wilor_home) + global load_wilor + global recursive_to + global ViTDetDataset + global Renderer + global cam_crop_to_full + + from wilor.models import load_wilor + from wilor.utils import recursive_to + from wilor.datasets.vitdet_dataset import ViTDetDataset + from wilor.utils.renderer import Renderer, cam_crop_to_full + checkpoint_path = os.path.join(wilor_home, 'pretrained_models/wilor_final.ckpt') + cfg_path = os.path.join(wilor_home, 'pretrained_models/model_config.yaml') + original_dir = os.getcwd() + os.chdir(wilor_home) + model, model_cfg = load_wilor(checkpoint_path = checkpoint_path , + cfg_path= cfg_path) + os.chdir(original_dir) + detector = YOLO(os.path.join(wilor_home, 'pretrained_models/detector.pt')) + # Setup the renderer + self.renderer = Renderer(model_cfg, faces=model.mano.faces) + + self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + self._model = model.to(self.device) + self._model_cfg = model_cfg + self._detector = detector.to(self.device) def look_for_hands( self, From e6914a65b49ee6864893361e1f9f9641a59c1381 Mon Sep 17 00:00:00 2001 From: woodenbirds <1979309725@qq.com> Date: Tue, 19 Nov 2024 13:58:35 -0800 Subject: [PATCH 8/8] delete aria_utils --- aria_utils.py | 211 -------------------------------------------------- 1 file changed, 211 deletions(-) delete mode 100755 aria_utils.py diff --git a/aria_utils.py b/aria_utils.py deleted file mode 100755 index 50f827d..0000000 --- a/aria_utils.py +++ /dev/null @@ -1,211 +0,0 @@ -import pandas as pd -import numpy as np -import matplotlib.pyplot as plt -import projectaria_tools.core.mps as mps -from projectaria_tools.core import data_provider,calibration -# from projectaria_tools.core.mps.utils import get_nearest_pose -from projectaria_tools.core.stream_id import StreamId -from projectaria_tools.core.sensor_data import DEVICE_TIME,CLOSEST -from projectaria_tools.core.mps.utils import get_nearest_wrist_and_palm_pose -import os - -def x_y_rot90(x, y, w, h): - return h-y, x - -def x_y_undistort(x, y, w, h, pinhole, calib): - x, y = int(x), int(y) - # a zero numpy array of shape (w, h) with coordinate (x,y) having value 1 - point = np.zeros((w, h)) - offset = 2 - for i in range(x-offset, x+offset+1): - for j in range(y-offset, y+offset+1): - if i>=0 and i=0 and j