diff --git a/egomimic/rldb/embodiment/human.py b/egomimic/rldb/embodiment/human.py index 1939fd91..08e0e868 100644 --- a/egomimic/rldb/embodiment/human.py +++ b/egomimic/rldb/embodiment/human.py @@ -1,19 +1,24 @@ from __future__ import annotations +from typing import Literal + from egomimic.rldb.embodiment.embodiment import Embodiment -from egomimic.rldb.embodiment.eva import ( - _viz_batch_palm_axes, - _viz_batch_palm_traj, -) from egomimic.rldb.zarr.action_chunk_transforms import ( ActionChunkCoordinateFrameTransform, ConcatKeys, DeleteKeys, InterpolatePose, PoseCoordinateFrameTransform, + Reshape, Transform, XYZWXYZ_to_XYZYPR, ) +from egomimic.utils.type_utils import _to_numpy +from egomimic.utils.viz_utils import ( + _viz_axes, + _viz_keypoints, + _viz_traj, +) class Human(Embodiment): @@ -22,75 +27,202 @@ class Human(Embodiment): ACTION_STRIDE = 3 @classmethod - def get_transform_list(cls) -> list[Transform]: - return _build_aria_bimanual_transform_list(stride=cls.ACTION_STRIDE) + def get_transform_list( + cls, mode: Literal["cartesian", "keypoints"] + ) -> list[Transform]: + if mode == "cartesian": + return _build_aria_cartesian_bimanual_transform_list( + stride=cls.ACTION_STRIDE + ) + elif mode == "keypoints": + return _build_aria_keypoints_bimanual_transform_list( + stride=cls.ACTION_STRIDE + ) + else: + raise ValueError( + f"Unsupported mode '{mode}'. Expected one of: 'cartesian', 'keypoints'." + ) @classmethod - def viz_transformed_batch(cls, batch, mode=""): - image_key = cls.VIZ_IMAGE_KEY - action_key = "actions_cartesian" + def viz_transformed_batch( + cls, + batch, + mode=Literal["traj", "axes", "keypoints"], + action_key="actions_cartesian", + image_key=None, + ): + image_key = image_key or cls.VIZ_IMAGE_KEY + action_key = action_key or "actions_cartesian" intrinsics_key = cls.VIZ_INTRINSICS_KEY - mode = (mode or "palm_traj").lower() + mode = (mode or "traj").lower() + images = _to_numpy(batch[image_key][0]) + actions = _to_numpy(batch[action_key][0]) - if mode == "palm_traj": - return _viz_batch_palm_traj( - batch=batch, - image_key=image_key, - action_key=action_key, + return cls.viz( + images=images, actions=actions, mode=mode, intrinsics_key=intrinsics_key + ) + + @classmethod + def viz( + cls, + images, + actions, + mode=Literal["traj", "axes", "keypoints"], + intrinsics_key=None, + ): + intrinsics_key = intrinsics_key or cls.VIZ_INTRINSICS_KEY + if mode == "traj": + return _viz_traj( + images=images, + actions=actions, intrinsics_key=intrinsics_key, ) - if mode == "palm_axes": - return _viz_batch_palm_axes( - batch=batch, - image_key=image_key, - action_key=action_key, + if mode == "axes": + return _viz_axes( + images=images, + actions=actions, intrinsics_key=intrinsics_key, ) if mode == "keypoints": - raise NotImplementedError( - "mode='keypoints' is reserved and not implemented yet." + return _viz_keypoints( + images=images, + actions=actions, + intrinsics_key=intrinsics_key, + edges=cls.FINGER_EDGES, + colors=cls.FINGER_COLORS, + edge_ranges=cls.FINGER_EDGE_RANGES, ) - raise ValueError( f"Unsupported mode '{mode}'. Expected one of: " - f"('palm_traj', 'palm_axes', 'keypoints')." + f"('traj', 'axes', 'keypoints')." ) @classmethod - def get_keymap(cls): - return { - cls.VIZ_IMAGE_KEY: { - "key_type": "camera_keys", - "zarr_key": "images.front_1", - }, - "right.action_ee_pose": { - "key_type": "action_keys", - "zarr_key": "right.obs_ee_pose", - "horizon": 30, - }, - "left.action_ee_pose": { - "key_type": "action_keys", - "zarr_key": "left.obs_ee_pose", - "horizon": 30, - }, - "right.obs_ee_pose": { - "key_type": "proprio_keys", - "zarr_key": "right.obs_ee_pose", - }, - "left.obs_ee_pose": { - "key_type": "proprio_keys", - "zarr_key": "left.obs_ee_pose", - }, - "obs_head_pose": { - "key_type": "proprio_keys", - "zarr_key": "obs_head_pose", - }, - } + def get_keymap(cls, mode: Literal["cartesian", "keypoints"]): + if mode == "cartesian": + key_map = { + cls.VIZ_IMAGE_KEY: { + "key_type": "camera_keys", + "zarr_key": "images.front_1", + }, + "right.action_ee_pose": { + "key_type": "action_keys", + "zarr_key": "right.obs_ee_pose", + "horizon": 30, + }, + "left.action_ee_pose": { + "key_type": "action_keys", + "zarr_key": "left.obs_ee_pose", + "horizon": 30, + }, + "right.obs_ee_pose": { + "key_type": "proprio_keys", + "zarr_key": "right.obs_ee_pose", + }, + "left.obs_ee_pose": { + "key_type": "proprio_keys", + "zarr_key": "left.obs_ee_pose", + }, + "obs_head_pose": { + "key_type": "proprio_keys", + "zarr_key": "obs_head_pose", + }, + } + elif mode == "keypoints": + key_map = { + cls.VIZ_IMAGE_KEY: { + "key_type": "camera_keys", + "zarr_key": "images.front_1", + }, + "left.action_keypoints": { + "key_type": "action_keys", + "zarr_key": "left.obs_keypoints", + "horizon": 30, + }, + "right.action_keypoints": { + "key_type": "action_keys", + "zarr_key": "right.obs_keypoints", + "horizon": 30, + }, + "left.action_wrist_pose": { + "key_type": "proprio_keys", + "zarr_key": "left.obs_wrist_pose", + "horizon": 30, + }, + "right.action_wrist_pose": { + "key_type": "proprio_keys", + "zarr_key": "right.obs_wrist_pose", + "horizon": 30, + }, + "left.obs_keypoints": { + "key_type": "proprio_keys", + "zarr_key": "left.obs_keypoints", + }, + "right.obs_keypoints": { + "key_type": "proprio_keys", + "zarr_key": "right.obs_keypoints", + }, + "left.obs_wrist_pose": { + "key_type": "proprio_keys", + "zarr_key": "left.obs_wrist_pose", + }, + "right.obs_wrist_pose": { + "key_type": "proprio_keys", + "zarr_key": "right.obs_wrist_pose", + }, + "obs_head_pose": { + "key_type": "proprio_keys", + "zarr_key": "obs_head_pose", + }, + } + else: + raise ValueError( + f"Unsupported mode '{mode}'. Expected one of: 'cartesian', 'keypoints'." + ) + return key_map class Aria(Human): VIZ_INTRINSICS_KEY = "base" ACTION_STRIDE = 3 + FINGER_EDGES = [ + ( + 5, + 6, + ), + (6, 7), + (7, 0), # thumb + (5, 8), + (8, 9), + (9, 10), + (9, 1), # index + (5, 11), + (11, 12), + (12, 13), + (13, 2), # middle + (5, 14), + (14, 15), + (15, 16), + (16, 3), # ring + (5, 17), + (17, 18), + (18, 19), + (19, 4), # pinky + ] + FINGER_COLORS = { + "thumb": (255, 100, 100), # red + "index": (100, 255, 100), # green + "middle": (100, 100, 255), # blue + "ring": (255, 255, 100), # yellow + "pinky": (255, 100, 255), # magenta + } + FINGER_EDGE_RANGES = [ + ("thumb", 0, 3), + ("index", 3, 6), + ("middle", 6, 9), + ("ring", 9, 12), + ("pinky", 12, 15), + ] class Scale(Human): @@ -103,7 +235,202 @@ class Mecka(Human): ACTION_STRIDE = 1 -def _build_aria_bimanual_transform_list( +def _build_aria_keypoints_bimanual_transform_list( + *, + target_world: str = "obs_head_pose", + target_world_ypr: str = "obs_head_pose_ypr", + target_world_is_quat: bool = True, + left_keypoints_action_world: str = "left.action_keypoints", + right_keypoints_action_world: str = "right.action_keypoints", + left_keypoints_obs_pose: str = "left.obs_keypoints", + right_keypoints_obs_pose: str = "right.obs_keypoints", + left_keypoints_action_headframe: str = "left.action_keypoints_headframe", + right_keypoints_action_headframe: str = "right.action_keypoints_headframe", + left_keypoints_obs_headframe: str = "left.obs_keypoints_headframe", + right_keypoints_obs_headframe: str = "right.obs_keypoints_headframe", + left_wrist_action_world: str = "left.action_wrist_pose", + right_wrist_action_world: str = "right.action_wrist_pose", + left_wrist_obs_pose: str = "left.obs_wrist_pose", + right_wrist_obs_pose: str = "right.obs_wrist_pose", + left_wrist_action_headframe: str = "left.action_wrist_pose_headframe", + right_wrist_action_headframe: str = "right.action_wrist_pose_headframe", + left_wrist_obs_headframe: str = "left.obs_wrist_pose_headframe", + right_wrist_obs_headframe: str = "right.obs_wrist_pose_headframe", + delete_target_world: bool = True, + chunk_length: int = 100, + stride: int = 3, +) -> list[Transform]: + keys_to_delete = list( + { + left_keypoints_action_world, + right_keypoints_action_world, + left_keypoints_obs_pose, + right_keypoints_obs_pose, + left_wrist_action_world, + right_wrist_action_world, + left_wrist_obs_pose, + right_wrist_obs_pose, + left_keypoints_action_headframe, + right_keypoints_action_headframe, + left_keypoints_obs_headframe, + right_keypoints_obs_headframe, + left_wrist_action_headframe, + right_wrist_action_headframe, + left_wrist_obs_headframe, + right_wrist_obs_headframe, + } + ) + if delete_target_world: + keys_to_delete.append(target_world) + if target_world_is_quat: + keys_to_delete.append(target_world_ypr) + transform_list: list[Transform] = [ + Reshape( + input_key=left_keypoints_action_world, + output_key=left_keypoints_action_world, + shape=(30, 21, 3), + ), + Reshape( + input_key=right_keypoints_action_world, + output_key=right_keypoints_action_world, + shape=(30, 21, 3), + ), + ActionChunkCoordinateFrameTransform( + target_world=target_world, + chunk_world=left_keypoints_action_world, + transformed_key_name=left_keypoints_action_headframe, + mode="xyz", + ), + ActionChunkCoordinateFrameTransform( + target_world=target_world, + chunk_world=right_keypoints_action_world, + transformed_key_name=right_keypoints_action_headframe, + mode="xyz", + ), + Reshape( + input_key=left_keypoints_obs_pose, + output_key=left_keypoints_obs_pose, + shape=(21, 3), + ), + Reshape( + input_key=right_keypoints_obs_pose, + output_key=right_keypoints_obs_pose, + shape=(21, 3), + ), + PoseCoordinateFrameTransform( + target_world=target_world, + pose_world=left_keypoints_obs_pose, + transformed_key_name=left_keypoints_obs_headframe, + mode="xyz", + ), + PoseCoordinateFrameTransform( + target_world=target_world, + pose_world=right_keypoints_obs_pose, + transformed_key_name=right_keypoints_obs_headframe, + mode="xyz", + ), + Reshape( + input_key=left_keypoints_obs_headframe, + output_key=left_keypoints_obs_headframe, + shape=(63,), + ), + Reshape( + input_key=right_keypoints_obs_headframe, + output_key=right_keypoints_obs_headframe, + shape=(63,), + ), + InterpolatePose( + new_chunk_length=chunk_length, + action_key=left_keypoints_action_headframe, + output_action_key=left_keypoints_action_headframe, + stride=stride, + mode="xyz", + ), + InterpolatePose( + new_chunk_length=chunk_length, + action_key=right_keypoints_action_headframe, + output_action_key=right_keypoints_action_headframe, + stride=stride, + mode="xyz", + ), + Reshape( + input_key=left_keypoints_action_headframe, + output_key=left_keypoints_action_headframe, + shape=(chunk_length, 63), + ), + Reshape( + input_key=right_keypoints_action_headframe, + output_key=right_keypoints_action_headframe, + shape=(chunk_length, 63), + ), + ActionChunkCoordinateFrameTransform( + target_world=target_world, + chunk_world=left_wrist_action_world, + transformed_key_name=left_wrist_action_headframe, + mode="xyzwxyz", + ), + ActionChunkCoordinateFrameTransform( + target_world=target_world, + chunk_world=right_wrist_action_world, + transformed_key_name=right_wrist_action_headframe, + mode="xyzwxyz", + ), + PoseCoordinateFrameTransform( + target_world=target_world, + pose_world=left_wrist_obs_pose, + transformed_key_name=left_wrist_obs_headframe, + mode="xyzwxyz", + ), + PoseCoordinateFrameTransform( + target_world=target_world, + pose_world=right_wrist_obs_pose, + transformed_key_name=right_wrist_obs_headframe, + mode="xyzwxyz", + ), + InterpolatePose( + new_chunk_length=chunk_length, + action_key=left_wrist_action_headframe, + output_action_key=left_wrist_action_headframe, + stride=stride, + mode="xyzwxyz", + ), + InterpolatePose( + new_chunk_length=chunk_length, + action_key=right_wrist_action_headframe, + output_action_key=right_wrist_action_headframe, + stride=stride, + mode="xyzwxyz", + ), + ] + transform_list.extend( + [ + ConcatKeys( + key_list=[ + left_wrist_action_headframe, + left_keypoints_action_headframe, + right_wrist_action_headframe, + right_keypoints_action_headframe, + ], + new_key_name="actions_keypoints", + delete_old_keys=True, + ), + ConcatKeys( + key_list=[ + left_wrist_obs_headframe, + left_keypoints_obs_headframe, + right_wrist_obs_headframe, + right_keypoints_obs_headframe, + ], + new_key_name="observations.state.keypoints", + delete_old_keys=True, + ), + DeleteKeys(keys_to_delete=keys_to_delete), + ] + ) + return transform_list + + +def _build_aria_cartesian_bimanual_transform_list( *, target_world: str = "obs_head_pose", target_world_ypr: str = "obs_head_pose_ypr", @@ -147,39 +474,39 @@ def _build_aria_bimanual_transform_list( target_world=target_pose_key, chunk_world=left_action_world, transformed_key_name=left_action_headframe, - is_quat=target_world_is_quat, + mode="xyzwxyz", ), ActionChunkCoordinateFrameTransform( target_world=target_pose_key, chunk_world=right_action_world, transformed_key_name=right_action_headframe, - is_quat=target_world_is_quat, + mode="xyzwxyz", ), PoseCoordinateFrameTransform( target_world=target_pose_key, pose_world=left_obs_pose, transformed_key_name=left_obs_headframe, - is_quat=target_world_is_quat, + mode="xyzwxyz", ), PoseCoordinateFrameTransform( target_world=target_pose_key, pose_world=right_obs_pose, transformed_key_name=right_obs_headframe, - is_quat=target_world_is_quat, + mode="xyzwxyz", ), InterpolatePose( new_chunk_length=chunk_length, action_key=left_action_headframe, output_action_key=left_action_headframe, stride=stride, - is_quat=target_world_is_quat, + mode="xyzwxyz", ), InterpolatePose( new_chunk_length=chunk_length, action_key=right_action_headframe, output_action_key=right_action_headframe, stride=stride, - is_quat=target_world_is_quat, + mode="xyzwxyz", ), ] diff --git a/egomimic/rldb/zarr/action_chunk_transforms.py b/egomimic/rldb/zarr/action_chunk_transforms.py index ea0ce6ad..d0a10855 100644 --- a/egomimic/rldb/zarr/action_chunk_transforms.py +++ b/egomimic/rldb/zarr/action_chunk_transforms.py @@ -13,19 +13,24 @@ from __future__ import annotations from abc import abstractmethod +from typing import Literal import numpy as np +import torch from projectaria_tools.core.sophus import SE3 from scipy.spatial.transform import Rotation as R -import torch from egomimic.utils.pose_utils import ( _interpolate_euler, _interpolate_linear, _interpolate_quat_wxyz, + _interpolate_xyz, + _matrix_to_xyz, _matrix_to_xyzwxyz, _matrix_to_xyzypr, + _xyz_to_matrix, _xyzwxyz_to_matrix, + _xyzypr_to_matrix, ) # --------------------------------------------------------------------------- @@ -56,7 +61,7 @@ def __init__( action_key: str, output_action_key: str, stride: int = 1, - is_quat: bool = False, + mode: Literal["xyzwxyz", "xyzypr"] = "xyzwxyz", ): if stride <= 0: raise ValueError(f"stride must be positive, got {stride}") @@ -64,12 +69,12 @@ def __init__( self.action_key = action_key self.output_action_key = output_action_key self.stride = int(stride) - self.is_quat = is_quat + self.mode = mode def transform(self, batch: dict) -> dict: actions = np.asarray(batch[self.action_key]) actions = actions[:: self.stride] - if self.is_quat: + if self.mode == "xyzwxyz": if actions.ndim != 2 or actions.shape[-1] != 7: raise ValueError( f"InterpolatePose expects (T, 7) when is_quat=True, got " @@ -78,7 +83,7 @@ def transform(self, batch: dict) -> dict: batch[self.output_action_key] = _interpolate_quat_wxyz( actions, self.new_chunk_length ) - else: + elif self.mode == "xyzypr": if actions.ndim != 2 or actions.shape[-1] != 6: raise ValueError( f"InterpolatePose expects (T, 6), got {actions.shape} for key " @@ -87,6 +92,15 @@ def transform(self, batch: dict) -> dict: batch[self.output_action_key] = _interpolate_euler( actions, self.new_chunk_length ) + else: + if actions.shape[-1] != 3: + raise ValueError( + f"InterpolatePose expects (T, 3) or (T, K, 3), got {actions.shape} for key " + f"'{self.action_key}'" + ) + batch[self.output_action_key] = _interpolate_xyz( + actions, self.new_chunk_length + ) return batch @@ -122,29 +136,24 @@ def transform(self, batch: dict) -> dict: # --------------------------------------------------------------------------- -# Coordinate Transforms +# Reshape Transforms # --------------------------------------------------------------------------- -def _xyzypr_to_matrix(xyzypr: np.ndarray) -> np.ndarray: - """ - args: - xyzypr: (B, 6) np.array of [[x, y, z, yaw, pitch, roll]] - returns: - (B, 4, 4) array of SE3 transformation matrices - """ - if xyzypr.ndim != 2 or xyzypr.shape[-1] != 6: - raise ValueError(f"Expected (B, 6) array, got shape {xyzypr.shape}") +class Reshape(Transform): + def __init__(self, input_key: str, output_key: str, shape: tuple): + self.input_key = input_key + self.output_key = output_key + self.shape = shape - B = xyzypr.shape[0] - dtype = xyzypr.dtype if np.issubdtype(xyzypr.dtype, np.floating) else np.float64 + def transform(self, batch: dict) -> dict: + batch[self.output_key] = batch[self.input_key].reshape(*self.shape) + return batch - mats = np.broadcast_to(np.eye(4, dtype=dtype), (B, 4, 4)).copy() - # Input is [yaw, pitch, roll], so use ZYX order (Rz @ Ry @ Rx). - mats[:, :3, :3] = R.from_euler("ZYX", xyzypr[:, 3:6], degrees=False).as_matrix() - mats[:, :3, 3] = xyzypr[:, :3] - return mats +# --------------------------------------------------------------------------- +# Coordinate Transforms +# --------------------------------------------------------------------------- class ActionChunkCoordinateFrameTransform(Transform): @@ -154,7 +163,7 @@ def __init__( chunk_world: str, transformed_key_name: str, extra_batch_key: dict = None, - is_quat: bool = False, + mode: Literal["xyz", "xyzwxyz", "xyzypr"] = "xyzwxyz", ): """ args: @@ -167,7 +176,7 @@ def __init__( self.chunk_world = chunk_world self.transformed_key_name = transformed_key_name self.extra_batch_key = extra_batch_key - self.is_quat = is_quat + self.mode = mode def transform(self, batch): """ @@ -183,13 +192,34 @@ def transform(self, batch): if is_quat=False: (T, 6) xyz + ypr if is_quat=True: (T, 7) xyz + quat(wxyz) """ + # flatten to (T, D) + # target world is head pose, chunk world is keypoints batch.update(self.extra_batch_key or {}) target_world = np.asarray(batch[self.target_world]) chunk_world = np.asarray(batch[self.chunk_world]) - to_matrix_fn = _xyzwxyz_to_matrix if self.is_quat else _xyzypr_to_matrix + chunk_world_shape = None + + if chunk_world.ndim > 2: + chunk_world_shape = chunk_world.shape + chunk_world = chunk_world.reshape(-1, chunk_world_shape[-1]) + + to_matrix_fn = None + if self.mode == "xyzwxyz": + to_matrix_fn = _xyzwxyz_to_matrix + elif self.mode == "xyzypr": + to_matrix_fn = _xyzypr_to_matrix + elif self.mode == "xyz": + to_matrix_fn = _xyz_to_matrix + else: + raise ValueError(f"Invalid mode: {self.mode}") + target_world_to_matrix_fn = ( + _xyzwxyz_to_matrix if target_world.shape[-1] == 7 else _xyzypr_to_matrix + ) # Convert to SE3 for transformation - target_se3 = SE3.from_matrix(to_matrix_fn(target_world[None, :])[0]) # (4, 4) + target_se3 = SE3.from_matrix( + target_world_to_matrix_fn(target_world[None, :])[0] + ) # (4, 4) chunk_se3 = SE3.from_matrix(to_matrix_fn(chunk_world)) # (T, 4, 4) # Compute relative transform and apply to chunk @@ -197,11 +227,18 @@ def transform(self, batch): chunk_mats = chunk_in_target_frame.to_matrix() if chunk_mats.ndim == 2: chunk_mats = chunk_mats[None, ...] - chunk_in_target_frame = ( - _matrix_to_xyzwxyz(chunk_mats) - if self.is_quat - else _matrix_to_xyzypr(chunk_mats) - ) + + if self.mode == "xyzwxyz": + chunk_in_target_frame = _matrix_to_xyzwxyz(chunk_mats) + elif self.mode == "xyzypr": + chunk_in_target_frame = _matrix_to_xyzypr(chunk_mats) + elif self.mode == "xyz": + chunk_in_target_frame = _matrix_to_xyz(chunk_mats) + else: + raise ValueError(f"Invalid mode: {self.mode}") + + if chunk_world_shape is not None: + chunk_in_target_frame = chunk_in_target_frame.reshape(*chunk_world_shape) # Store transformed chunk back in batch batch[self.transformed_key_name] = chunk_in_target_frame @@ -237,27 +274,21 @@ def __init__( target_world: str, pose_world: str, transformed_key_name: str, - is_quat: bool = False, + mode: Literal["xyzwxyz", "xyzypr", "xyz"] = "xyzwxyz", ): self.target_world = target_world self.pose_world = pose_world self.transformed_key_name = transformed_key_name - self.is_quat = is_quat + self.mode = mode self._chunk_transform = ActionChunkCoordinateFrameTransform( target_world=target_world, chunk_world=pose_world, transformed_key_name=transformed_key_name, - is_quat=is_quat, + mode=mode, ) def transform(self, batch: dict) -> dict: pose_world = np.asarray(batch[self.pose_world]) - expected_shape = (7,) if self.is_quat else (6,) - if pose_world.shape != expected_shape: - raise ValueError( - f"Expected pose_world shape {expected_shape}, got {pose_world.shape}" - ) - transformed = self._chunk_transform.transform( { self.target_world: batch[self.target_world], @@ -414,10 +445,12 @@ def transform(self, batch): return batch + # --------------------------------------------------------------------------- # Type Transforms # --------------------------------------------------------------------------- + class NumpyToTensor(Transform): def __init__(self, keys: list[str]): self.keys = keys @@ -429,5 +462,7 @@ def transform(self, batch: dict) -> dict: elif isinstance(batch[key], torch.Tensor): batch[key] = batch[key].clone() else: - raise ValueError(f"NumpyToTensor expects key '{key}' to be a numpy array or torch tensor, got {type(batch[key])}") + raise ValueError( + f"NumpyToTensor expects key '{key}' to be a numpy array or torch tensor, got {type(batch[key])}" + ) return batch diff --git a/egomimic/rldb/zarr/zarr_dataset_multi.py b/egomimic/rldb/zarr/zarr_dataset_multi.py index 1b4445ed..7d796441 100644 --- a/egomimic/rldb/zarr/zarr_dataset_multi.py +++ b/egomimic/rldb/zarr/zarr_dataset_multi.py @@ -716,12 +716,11 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: if self.transform: for transform in self.transform or []: try: + # breakpoint() data = transform.transform(data) except Exception as e: logger.error(f"Error transforming data: {e}") - logger.error(f"Data: {data}") logger.error(f"Transform: {transform}") - logger.error(f"Error: {e}") if idx == 0: logger.error("Error in first frame") raise e diff --git a/egomimic/scripts/zarr_data_viz.ipynb b/egomimic/scripts/zarr_data_viz.ipynb new file mode 100644 index 00000000..1a1e4302 --- /dev/null +++ b/egomimic/scripts/zarr_data_viz.ipynb @@ -0,0 +1,420 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "79d184b3", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "id": "29aeeb40", + "metadata": {}, + "source": [ + "# Eva Data\n", + "\n", + "This notebook builds a `MultiDataset` containing exactly one `ZarrDataset`, loads one batch, visualizes one image with `mediapy`, and prints the rest of the batch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32d9110f", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import cv2\n", + "import imageio_ffmpeg\n", + "import mediapy as mpy\n", + "import numpy as np\n", + "import torch\n", + "\n", + "from egomimic.rldb.embodiment.eva import Eva\n", + "from egomimic.rldb.embodiment.human import Aria\n", + "from egomimic.rldb.zarr.zarr_dataset_multi import MultiDataset, ZarrDataset\n", + "from egomimic.rldb.zarr.zarr_dataset_multi import S3EpisodeResolver\n", + "from egomimic.utils.egomimicUtils import (\n", + " INTRINSICS,\n", + " cam_frame_to_cam_pixels,\n", + " nds,\n", + ")\n", + "from egomimic.utils.aws.aws_data_utils import load_env\n", + "\n", + "# Ensure mediapy can find an ffmpeg executable in this environment\n", + "mpy.set_ffmpeg(imageio_ffmpeg.get_ffmpeg_exe())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc9edba1", + "metadata": {}, + "outputs": [], + "source": [ + "TEMP_DIR = \"/coc/flash7/scratch/egoverseDebugDatasets/egoverseS3DatasetTest\"\n", + "load_env()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4aa1a05", + "metadata": {}, + "outputs": [], + "source": [ + "# Point this at a single episode directory, e.g. /path/to/episode_hash.zarr\n", + "# EPISODE_PATH = Path(\"/coc/flash7/scratch/egoverseDebugDatasets/1767495035712.zarr\")\n", + "\n", + "key_map = Eva.get_keymap()\n", + "transform_list = Eva.get_transform_list()\n", + "\n", + "# Build a MultiDataset with exactly one ZarrDataset inside\n", + "# single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map, transform_list=transform_list)\n", + "# single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map)\n", + "\n", + "# multi_ds = MultiDataset(datasets={\"single_episode\": single_ds}, mode=\"total\")\n", + "resolver = S3EpisodeResolver(\n", + " TEMP_DIR, key_map=key_map, transform_list=transform_list\n", + ")\n", + "filters = {\n", + " \"episode_hash\": \"2025-12-26-18-07-46-296000\"\n", + "}\n", + "multi_ds = MultiDataset._from_resolver(\n", + " resolver, filters=filters, sync_from_s3=True, mode=\"total\"\n", + ")\n", + "\n", + "loader = torch.utils.data.DataLoader(multi_ds, batch_size=1, shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4b72f3bb", + "metadata": {}, + "outputs": [], + "source": [ + "# Separate YPR visualization preview\n", + "for batch in loader:\n", + " vis_ypr = Eva.viz_transformed_batch(batch, mode=\"axes\")\n", + " mpy.show_image(vis_ypr)\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d8c3da2", + "metadata": {}, + "outputs": [], + "source": [ + "images = []\n", + "for i, batch in enumerate(loader):\n", + " vis = Eva.viz_transformed_batch(batch, mode=\"traj\")\n", + " images.append(vis)\n", + " if i > 10:\n", + " break\n", + "\n", + "mpy.show_video(images, fps=30)" + ] + }, + { + "cell_type": "markdown", + "id": "1a3382f1", + "metadata": {}, + "source": [ + "## Human Datasets\n", + "Mecka, Scale and Aria should all run exactly the same" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7384468", + "metadata": {}, + "outputs": [], + "source": [ + "temp_dir = \"/coc/flash7/scratch/egoverseDebugDatasets/egoverseS3DatasetTest\"\n", + "\n", + "intrinsics_key = \"base\"\n", + "\n", + "key_map = Aria.get_keymap(mode=\"keypoints\")\n", + "transform_list = Aria.get_transform_list(mode=\"keypoints\")\n", + "\n", + "resolver = S3EpisodeResolver(\n", + " temp_dir,\n", + " key_map=key_map,\n", + " transform_list=transform_list,\n", + ")\n", + "\n", + "filters = {\"episode_hash\": \"2026-01-20-20-59-43-376000\"} #aria\n", + "# filters = {\"episode_hash\": \"692ee048ef7557106e6c4b8d\"} # mecka\n", + "\n", + "cloudflare_ds = MultiDataset._from_resolver(\n", + " resolver, filters=filters, sync_from_s3=True, mode=\"total\"\n", + ")\n", + "\n", + "loader = torch.utils.data.DataLoader(cloudflare_ds, batch_size=1, shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "af65095a", + "metadata": {}, + "outputs": [], + "source": [ + "ims = []\n", + "for i, batch in enumerate(loader):\n", + " vis = Aria.viz_transformed_batch(batch, mode=\"traj\")\n", + " ims.append(vis)\n", + " if i > 10:\n", + " break\n", + "\n", + "mpy.show_video(ims, fps=30)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6d8d872", + "metadata": {}, + "outputs": [], + "source": [ + "# Aria YPR video (same data loop, YPR overlay)\n", + "ims_ypr = []\n", + "for i, batch in enumerate(loader):\n", + " vis_ypr = Aria.viz_transformed_batch(batch, mode=\"axes\")\n", + " ims_ypr.append(vis_ypr)\n", + " if i > 20:\n", + " break\n", + "\n", + "mpy.show_video(ims_ypr, fps=30)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60723adf", + "metadata": {}, + "outputs": [], + "source": [ + "ims_keypoints = []\n", + "for i, batch in enumerate(loader):\n", + " vis_keypoints = Aria.viz_transformed_batch(batch, mode=\"keypoints\")\n", + " ims_keypoints.append(vis_keypoints)\n", + " if i > 360:\n", + " break\n", + "\n", + "mpy.show_video(ims_keypoints, fps=20)" + ] + }, + { + "cell_type": "markdown", + "id": "efecaba7", + "metadata": {}, + "source": [ + "## Keypoint Visualization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e39bca03", + "metadata": {}, + "outputs": [], + "source": [ + "# Load Scale episode with raw keypoints (no action chunking needed)\n", + "\n", + "from egomimic.rldb.zarr.action_chunk_transforms import _xyzwxyz_to_matrix\n", + "\n", + "key_map_kp = {\n", + " \"images.front_1\": {\"zarr_key\": \"images.front_1\"},\n", + " \"left.obs_keypoints\": {\"zarr_key\": \"left.obs_keypoints\"},\n", + " \"right.obs_keypoints\": {\"zarr_key\": \"right.obs_keypoints\"},\n", + " \"obs_head_pose\": {\"zarr_key\": \"obs_head_pose\"},\n", + "}\n", + "\n", + "filters = {\"episode_hash\": \"2026-01-20-20-59-43-376000\"}\n", + "\n", + "resolver = S3EpisodeResolver(\n", + " temp_dir,\n", + " key_map=key_map\n", + ")\n", + "\n", + "cloudflare_ds = MultiDataset._from_resolver(\n", + " resolver, filters=filters, sync_from_s3=True, mode=\"total\"\n", + ")\n", + "\n", + "loader_kp = torch.utils.data.DataLoader(cloudflare_ds, batch_size=1, shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "848c6d74", + "metadata": {}, + "outputs": [], + "source": [ + "# ARIA Keypoint Viz\n", + "# MANO skeleton edges: (parent, child) for drawing bones\n", + "MANO_EDGES = [\n", + " (0, 1), (1, 2), (2, 3), (3, 4), # thumb\n", + " (0, 5), (5, 6), (6, 7), (7, 8), # index\n", + " (0, 9), (9, 10), (10, 11), (11, 12), # middle\n", + " (0, 13), (13, 14), (14, 15), (15, 16), # ring\n", + " (0, 17), (17, 18), (18, 19), (19, 20), # pinky\n", + "]\n", + "\n", + "# aria configuration\n", + "MANO_EDGES = [\n", + " (5, 6,), (6, 7), (7, 0), # thumb\n", + " (5, 8), (8, 9), (9, 10), (9, 1), # index\n", + " (5, 11), (11, 12), (12, 13), (13, 2), # middle\n", + " (5, 14), (14, 15), (15, 16), (16, 3), # ring\n", + " (5, 17), (17, 18), (18, 19), (19, 4), # pinky\n", + "]\n", + "\n", + "FINGER_COLORS = {\n", + " \"thumb\": (255, 100, 100), # red\n", + " \"index\": (100, 255, 100), # green\n", + " \"middle\": (100, 100, 255), # blue\n", + " \"ring\": (255, 255, 100), # yellow\n", + " \"pinky\": (255, 100, 255), # magenta\n", + "}\n", + "FINGER_EDGE_RANGES = [\n", + " (\"thumb\", 0, 3), (\"index\", 3, 6), (\"middle\", 6, 9),\n", + " (\"ring\", 9, 12), (\"pinky\", 12, 15),\n", + "]\n", + "\n", + "\n", + "def viz_keypoints(batch, image_key=\"observations.images.front_img_1\"):\n", + " \"\"\"Visualize all 21 MANO keypoints per hand, projected onto the image.\"\"\"\n", + " # Prepare image\n", + " img = batch[image_key][0].detach().cpu()\n", + " if img.shape[0] in (1, 3):\n", + " img = img.permute(1, 2, 0)\n", + " img_np = img.numpy()\n", + " if img_np.dtype != np.uint8:\n", + " if img_np.max() <= 1.0:\n", + " img_np = (img_np * 255.0).clip(0, 255).astype(np.uint8)\n", + " else:\n", + " img_np = img_np.clip(0, 255).astype(np.uint8)\n", + " if img_np.shape[-1] == 1:\n", + " img_np = np.repeat(img_np, 3, axis=-1)\n", + "\n", + " intrinsics = INTRINSICS[\"base\"]\n", + " head_pose = batch[\"obs_head_pose\"][0].detach().cpu().numpy() # (6,)\n", + "\n", + " # T_head_world: camera pose in world (camera-to-world)\n", + " # We need world-to-camera = inv(T_head_world)\n", + " T_head_world = _xyzwxyz_to_matrix(head_pose[None, :])[0] # (4, 4)\n", + " T_world_to_cam = np.linalg.inv(T_head_world)\n", + "\n", + " vis = img_np.copy()\n", + " h, w = vis.shape[:2]\n", + "\n", + " for hand, dot_color in [(\"left\", (0, 120, 255)), (\"right\", (255, 80, 0))]:\n", + " kps_key = f\"{hand}.obs_keypoints\"\n", + " if kps_key not in batch:\n", + " continue\n", + " kps_flat = batch[kps_key][0].detach().cpu().numpy() # (63,)\n", + " kps_world = kps_flat.reshape(21, 3)\n", + "\n", + " # Skip if keypoints are all zero (invalid, clamped from 1e9)\n", + " if np.allclose(kps_world, 0.0, atol=1e-3):\n", + " continue\n", + "\n", + " # World -> camera frame\n", + " kps_h = np.concatenate([kps_world, np.ones((21, 1))], axis=1) # (21, 4)\n", + " kps_cam = (T_world_to_cam @ kps_h.T).T[:, :3] # (21, 3)\n", + "\n", + " # Camera frame -> pixels\n", + " kps_px = cam_frame_to_cam_pixels(kps_cam, intrinsics) # (21, 3+)\n", + "\n", + " # Identify valid keypoints (z > 0 and in image bounds)\n", + " valid = (kps_cam[:, 2] > 0.01)\n", + " valid &= (kps_px[:, 0] >= 0) & (kps_px[:, 0] < w)\n", + " valid &= (kps_px[:, 1] >= 0) & (kps_px[:, 1] < h)\n", + "\n", + " # Draw skeleton edges (colored by finger)\n", + " for finger, start, end in FINGER_EDGE_RANGES:\n", + " color = FINGER_COLORS[finger]\n", + " for edge_idx in range(start, end):\n", + " i, j = MANO_EDGES[edge_idx]\n", + " if valid[i] and valid[j]:\n", + " p1 = (int(kps_px[i, 0]), int(kps_px[i, 1]))\n", + " p2 = (int(kps_px[j, 0]), int(kps_px[j, 1]))\n", + " cv2.line(vis, p1, p2, color, 2)\n", + "\n", + " # Draw keypoint dots on top\n", + " for k in range(21):\n", + " if valid[k]:\n", + " center = (int(kps_px[k, 0]), int(kps_px[k, 1]))\n", + " cv2.circle(vis, center, 4, dot_color, -1)\n", + " cv2.circle(vis, center, 4, (255, 255, 255), 1) # white border\n", + "\n", + " # Label wrist\n", + " if valid[0]:\n", + " wrist_px = (int(kps_px[0, 0]) + 6, int(kps_px[0, 1]) - 6)\n", + " cv2.putText(vis, f\"{hand[0].upper()}\", wrist_px,\n", + " cv2.FONT_HERSHEY_SIMPLEX, 0.5, dot_color, 2)\n", + "\n", + " return vis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75dbfa95", + "metadata": {}, + "outputs": [], + "source": [ + "# Render keypoint video\n", + "ims_kp = []\n", + "for i, batch_kp in enumerate(loader_kp):\n", + " vis = viz_keypoints(batch_kp)\n", + " ims_kp.append(vis)\n", + " if i > 10:\n", + " break\n", + "\n", + "mpy.show_video(ims_kp, fps=30)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8f4fbaec", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/egomimic/utils/pose_utils.py b/egomimic/utils/pose_utils.py index 5edf9ffe..967a0a59 100644 --- a/egomimic/utils/pose_utils.py +++ b/egomimic/utils/pose_utils.py @@ -80,6 +80,14 @@ def _interpolate_quat_wxyz(seq: np.ndarray, chunk_length: int) -> np.ndarray: ) +def _interpolate_xyz(seq: np.ndarray, chunk_length: int) -> np.ndarray: + """Linear interpolation for arbitrary (T, 3) arrays or (T, K, 3) arrays.""" + T = seq.shape[0] + old_time = np.linspace(0, 1, T) + new_time = np.linspace(0, 1, chunk_length) + return interp1d(old_time, seq, axis=0, kind="linear")(new_time) + + def _matrix_to_xyzypr(mats: np.ndarray) -> np.ndarray: """ args: @@ -99,6 +107,24 @@ def _matrix_to_xyzypr(mats: np.ndarray) -> np.ndarray: return np.concatenate([xyz, ypr], axis=-1).astype(dtype, copy=False) +def _xyzypr_to_matrix(xyzypr: np.ndarray) -> np.ndarray: + """ + args: + xyzypr: (B, 6) np.array of [[x, y, z, yaw, pitch, roll]] + returns: + (B, 4, 4) array of SE3 transformation matrices + """ + if xyzypr.ndim != 2 or xyzypr.shape[-1] != 6: + raise ValueError(f"Expected (B, 6) array, got shape {xyzypr.shape}") + B = xyzypr.shape[0] + dtype = xyzypr.dtype if np.issubdtype(xyzypr.dtype, np.floating) else np.float64 + + mats = np.broadcast_to(np.eye(4, dtype=dtype), (B, 4, 4)).copy() + mats[:, :3, :3] = R.from_euler("ZYX", xyzypr[:, 3:6], degrees=False).as_matrix() + mats[:, :3, 3] = xyzypr[:, :3] + return mats + + def _matrix_to_xyzwxyz(mats: np.ndarray) -> np.ndarray: """ args: @@ -139,3 +165,61 @@ def _xyzwxyz_to_matrix(xyzwxyz: np.ndarray) -> np.ndarray: mats[:, :3, 3] = xyzwxyz[:, :3] return mats + + +def _xyz_to_matrix(xyz: np.ndarray) -> np.ndarray: + """ + args: + xyz: (B, 3) np.array of [[x, y, z]] + returns: + (B, 4, 4) array of SE3 transformation matrices + """ + if xyz.ndim != 2 or xyz.shape[-1] != 3: + raise ValueError(f"Expected (B, 3) array, got shape {xyz.shape}") + B = xyz.shape[0] + dtype = xyz.dtype if np.issubdtype(xyz.dtype, np.floating) else np.float64 + mats = np.broadcast_to(np.eye(4, dtype=dtype), (B, 4, 4)).copy() + mats[:, :3, 3] = xyz + return mats + + +def _matrix_to_xyz(mats: np.ndarray) -> np.ndarray: + """ + args: + mats: (B, 4, 4) array of SE3 transformation matrices + returns: + (B, 3) np.array of [[x, y, z]] + """ + if mats.ndim != 3 or mats.shape[-2:] != (4, 4): + raise ValueError(f"Expected (B, 4, 4) array, got shape {mats.shape}") + mats = np.asarray(mats) + dtype = mats.dtype if np.issubdtype(mats.dtype, np.floating) else np.float64 + return mats[:, :3, 3].astype(dtype, copy=False) + + +def _split_action_pose(actions): + # 14D layout: [L xyz ypr g, R xyz ypr g] + # 12D layout: [L xyz ypr, R xyz ypr] + if actions.shape[-1] == 14: + left_xyz = actions[..., :3] + left_ypr = actions[..., 3:6] + right_xyz = actions[..., 7:10] + right_ypr = actions[..., 10:13] + elif actions.shape[-1] == 12: + left_xyz = actions[..., :3] + left_ypr = actions[..., 3:6] + right_xyz = actions[..., 6:9] + right_ypr = actions[..., 9:12] + else: + raise ValueError(f"Unsupported action dim {actions.shape[-1]}") + return left_xyz, left_ypr, right_xyz, right_ypr + + +def _split_keypoints(keypoints): + left_xyz = keypoints[..., :3] + left_wxyz = keypoints[..., 3:7] + left_keypoints = keypoints[..., 7:70] + right_xyz = keypoints[..., 70:73] + right_wxyz = keypoints[..., 73:77] + right_keypoints = keypoints[..., 77:140] + return left_xyz, left_wxyz, left_keypoints, right_xyz, right_wxyz, right_keypoints diff --git a/egomimic/utils/type_utils.py b/egomimic/utils/type_utils.py new file mode 100644 index 00000000..e6608ecb --- /dev/null +++ b/egomimic/utils/type_utils.py @@ -0,0 +1,11 @@ +import numpy as np + + +def _to_numpy(arr): + if hasattr(arr, "detach"): + arr = arr.detach() + if hasattr(arr, "cpu"): + arr = arr.cpu() + if hasattr(arr, "numpy"): + return arr.numpy() + return np.asarray(arr) diff --git a/egomimic/utils/viz_utils.py b/egomimic/utils/viz_utils.py new file mode 100644 index 00000000..b6ecc408 --- /dev/null +++ b/egomimic/utils/viz_utils.py @@ -0,0 +1,200 @@ +import cv2 +import numpy as np +from scipy.spatial.transform import Rotation as R + +from egomimic.utils.egomimicUtils import ( + INTRINSICS, + cam_frame_to_cam_pixels, + draw_actions, +) +from egomimic.utils.pose_utils import _split_action_pose, _split_keypoints + + +def _prepare_viz_image(img): + if img.ndim == 3 and img.shape[0] in (1, 3): + img = np.transpose(img, (1, 2, 0)) + + if img.dtype != np.uint8: + if img.max() <= 1.0: + img = (img * 255.0).clip(0, 255).astype(np.uint8) + else: + img = img.clip(0, 255).astype(np.uint8) + + if img.ndim == 2: + img = np.repeat(img[:, :, None], 3, axis=-1) + elif img.shape[-1] == 1: + img = np.repeat(img, 3, axis=-1) + + return img + + +def _viz_traj(images, actions, intrinsics_key): + images = _prepare_viz_image(images) + intrinsics = INTRINSICS[intrinsics_key] + left_xyz, _, right_xyz, _ = _split_action_pose(actions) + + vis = draw_actions( + images.copy(), + type="xyz", + color="Blues", + actions=left_xyz, + extrinsics=None, + intrinsics=intrinsics, + arm="left", + ) + vis = draw_actions( + vis, + type="xyz", + color="Reds", + actions=right_xyz, + extrinsics=None, + intrinsics=intrinsics, + arm="right", + ) + return vis + + +def _viz_axes(images, actions, intrinsics_key, axis_len_m=0.04): + images = _prepare_viz_image(images) + intrinsics = INTRINSICS[intrinsics_key] + left_xyz, left_ypr, right_xyz, right_ypr = _split_action_pose(actions) + vis = images.copy() + + def _draw_axis_color_legend(frame): + _, w = frame.shape[:2] + x_right = w - 12 + y_start = 14 + y_step = 12 + line_len = 24 + axis_legend = [ + ("x", (255, 0, 0)), + ("y", (0, 255, 0)), + ("z", (0, 0, 255)), + ] + for i, (name, color) in enumerate(axis_legend): + y = y_start + i * y_step + x0 = x_right - line_len + x1 = x_right + cv2.line(frame, (x0, y), (x1, y), color, 3) + cv2.putText( + frame, + name, + (x0 - 12, y + 4), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + color, + 1, + cv2.LINE_AA, + ) + return frame + + def _draw_rotation_at_anchor(frame, xyz_seq, ypr_seq, label, anchor_color): + if len(xyz_seq) == 0 or len(ypr_seq) == 0: + return frame + + palm_xyz = xyz_seq[0] + palm_ypr = ypr_seq[0] + rot = R.from_euler("ZYX", palm_ypr, degrees=False).as_matrix() + + axis_points_cam = np.vstack( + [ + palm_xyz, + palm_xyz + rot[:, 0] * axis_len_m, + palm_xyz + rot[:, 1] * axis_len_m, + palm_xyz + rot[:, 2] * axis_len_m, + ] + ) + + px = cam_frame_to_cam_pixels(axis_points_cam, intrinsics)[:, :2] + if not np.isfinite(px).all(): + return frame + pts = np.round(px).astype(np.int32) + + h, w = frame.shape[:2] + x0, y0 = pts[0] + if not (0 <= x0 < w and 0 <= y0 < h): + return frame + + cv2.circle(frame, (x0, y0), 4, anchor_color, -1) + axis_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] + for i, color in enumerate(axis_colors, start=1): + x1, y1 = pts[i] + if 0 <= x1 < w and 0 <= y1 < h: + cv2.line(frame, (x0, y0), (x1, y1), color, 2) + cv2.circle(frame, (x1, y1), 2, color, -1) + + cv2.putText( + frame, + label, + (x0 + 6, max(12, y0 - 8)), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + anchor_color, + 1, + cv2.LINE_AA, + ) + return frame + + vis = _draw_rotation_at_anchor(vis, left_xyz, left_ypr, "L rot", (255, 180, 80)) + vis = _draw_rotation_at_anchor(vis, right_xyz, right_ypr, "R rot", (80, 180, 255)) + vis = _draw_axis_color_legend(vis) + return vis + + +def _viz_keypoints(images, actions, intrinsics_key, edges, colors, edge_ranges): + """Visualize all 21 MANO keypoints per hand, projected onto the image.""" + # Prepare image + images = _prepare_viz_image(images) + + intrinsics = INTRINSICS[intrinsics_key] + + vis = images.copy() + h, w = vis.shape[:2] + + left_xyz, left_wxyz, left_keypoints, right_xyz, right_wxyz, right_keypoints = ( + _split_keypoints(actions) + ) + keypoints = {} + keypoints["left"] = left_keypoints.reshape(-1, 3) + keypoints["right"] = right_keypoints.reshape(-1, 3) + for hand, dot_color in [("left", (0, 120, 255)), ("right", (255, 80, 0))]: + kps_cam = keypoints[hand] + # Camera frame -> pixels + kps_px = cam_frame_to_cam_pixels(kps_cam, intrinsics) # (42, 3+) 21 per arm + + # Identify valid keypoints (z > 0 and in image bounds) + valid = kps_cam[:, 2] > 0.01 + valid &= (kps_px[:, 0] >= 0) & (kps_px[:, 0] < w) + valid &= (kps_px[:, 1] >= 0) & (kps_px[:, 1] < h) + + # Draw skeleton edges (colored by finger) + for finger, start, end in edge_ranges: + color = colors[finger] + for edge_idx in range(start, end): + i, j = edges[edge_idx] + if valid[i] and valid[j]: + p1 = (int(kps_px[i, 0]), int(kps_px[i, 1])) + p2 = (int(kps_px[j, 0]), int(kps_px[j, 1])) + cv2.line(vis, p1, p2, color, 2) + + # Draw keypoint dots on top + for k in range(21): + if valid[k]: + center = (int(kps_px[k, 0]), int(kps_px[k, 1])) + cv2.circle(vis, center, 4, dot_color, -1) + cv2.circle(vis, center, 4, (255, 255, 255), 1) # white border + + # Label wrist + if valid[0]: + wrist_px = (int(kps_px[0, 0]) + 6, int(kps_px[0, 1]) - 6) + cv2.putText( + vis, + f"{hand[0].upper()}", + wrist_px, + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + dot_color, + 2, + ) + + return vis