class Embodiment(ABC):
    """Base embodiment class.

    An embodiment is responsible for defining the transform pipeline that
    converts between the raw data in the dataset and the canonical
    representation used by the model.

    Every hook is callable on the class itself (no instance is required).
    The base implementations only raise ``NotImplementedError``; concrete
    embodiments are expected to override all of them.
    """

    @staticmethod
    def get_transform_list() -> list[Transform]:
        """Return the transforms that convert raw dataset data to the canonical model representation.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    @staticmethod
    def viz_transformed_batch(batch, mode=""):
        """Visualize a batch of transformed data.

        Args:
            batch: A batch that has already gone through the embodiment's
                transform list.
            mode: Name of the visualization to render. An empty string lets
                the concrete embodiment pick its default. The parameter is
                declared here so the base signature matches the overrides,
                which all accept ``mode``.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    @staticmethod
    def get_keymap():
        """Return a dictionary mapping raw dataset keys to the canonical keys used by the model.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError
class Eva(Embodiment):
    """EVA bimanual embodiment: key map, transform pipeline and batch visualization."""

    # Batch key of the image that visualizations are drawn on.
    VIZ_IMAGE_KEY = "observations.images.front_img_1"
    # Key into INTRINSICS used when drawing onto the image. Exposed as a class
    # attribute (instead of a local literal) so subclasses can override it.
    VIZ_INTRINSICS_KEY = "base"

    @staticmethod
    def get_transform_list() -> list[Transform]:
        """Return the default EVA bimanual transform pipeline."""
        return _build_eva_bimanual_transform_list()

    @classmethod
    def viz_transformed_batch(cls, batch, mode=""):
        """
        Visualize one transformed EVA batch sample.

        Modes:
        - palm_traj: draw left/right palm trajectories from actions_cartesian.
        - palm_axes: draw local xyz axes at each palm anchor using ypr.

        An empty/None ``mode`` falls back to ``palm_traj``; matching is
        case-insensitive.

        Raises:
            ValueError: If ``mode`` is not one of the modes listed above.
        """
        image_key = cls.VIZ_IMAGE_KEY
        action_key = "actions_cartesian"
        intrinsics_key = cls.VIZ_INTRINSICS_KEY
        mode = (mode or "palm_traj").lower()

        if mode == "palm_traj":
            return _viz_batch_palm_traj(
                batch=batch,
                image_key=image_key,
                action_key=action_key,
                intrinsics_key=intrinsics_key,
            )
        if mode == "palm_axes":
            return _viz_batch_palm_axes(
                batch=batch,
                image_key=image_key,
                action_key=action_key,
                intrinsics_key=intrinsics_key,
            )

        # Only advertise the modes this embodiment actually handles; the
        # previous message also listed 'keypoints', which has no branch here.
        raise ValueError(
            f"Unsupported mode '{mode}'. Expected one of: "
            f"('palm_traj', 'palm_axes')."
        )

    @classmethod
    def get_keymap(cls):
        """Return the EVA key map.

        Each entry maps a batch key to a dict holding its ``key_type``, the
        ``zarr_key`` it is read from and, for action keys, a ``horizon``.
        """
        return {
            cls.VIZ_IMAGE_KEY: {
                "key_type": "camera_keys",
                "zarr_key": "images.front_1",
            },
            "observations.images.right_wrist_img": {
                "key_type": "camera_keys",
                "zarr_key": "images.right_wrist",
            },
            "observations.images.left_wrist_img": {
                "key_type": "camera_keys",
                "zarr_key": "images.left_wrist",
            },
            "right.obs_ee_pose": {
                "key_type": "proprio_keys",
                "zarr_key": "right.obs_ee_pose",
            },
            # The gripper arrays are exposed twice: once as a proprio key
            # (obs_gripper) and once below as an action key with a horizon.
            "right.obs_gripper": {
                "key_type": "proprio_keys",
                "zarr_key": "right.gripper",
            },
            "left.obs_ee_pose": {
                "key_type": "proprio_keys",
                "zarr_key": "left.obs_ee_pose",
            },
            "left.obs_gripper": {
                "key_type": "proprio_keys",
                "zarr_key": "left.gripper",
            },
            "right.gripper": {
                "key_type": "action_keys",
                "zarr_key": "right.gripper",
                "horizon": 45,
            },
            "left.gripper": {
                "key_type": "action_keys",
                "zarr_key": "left.gripper",
                "horizon": 45,
            },
            "right.cmd_ee_pose": {
                "key_type": "action_keys",
                "zarr_key": "right.cmd_ee_pose",
                "horizon": 45,
            },
            "left.cmd_ee_pose": {
                "key_type": "action_keys",
                "zarr_key": "left.cmd_ee_pose",
                "horizon": 45,
            },
        }
def _build_eva_bimanual_transform_list(
    *,
    left_target_world: str = "left_extrinsics_pose",
    right_target_world: str = "right_extrinsics_pose",
    left_cmd_world: str = "left.cmd_ee_pose",
    right_cmd_world: str = "right.cmd_ee_pose",
    left_obs_pose: str = "left.obs_ee_pose",
    right_obs_pose: str = "right.obs_ee_pose",
    left_obs_gripper: str = "left.obs_gripper",
    right_obs_gripper: str = "right.obs_gripper",
    left_gripper: str = "left.gripper",
    right_gripper: str = "right.gripper",
    left_cmd_camframe: str = "left.cmd_ee_pose_camframe",
    right_cmd_camframe: str = "right.cmd_ee_pose_camframe",
    actions_key: str = "actions_cartesian",
    obs_key: str = "observations.state.ee_pose",
    chunk_length: int = 100,
    stride: int = 1,
    extrinsics_key: str = "x5Dec13_2",
    is_quat: bool = True,
) -> list[Transform]:
    """Canonical EVA bimanual transform pipeline used by tests and notebooks.

    The returned list is ordered as follows:

    1. ``ActionChunkCoordinateFrameTransform`` for the left and right
       commanded poses, written to ``*_cmd_camframe``.
    2. ``PoseCoordinateFrameTransform`` for the left and right observed
       poses, written back to the same keys.
    3. ``InterpolatePose`` on both ``*_cmd_camframe`` chunks and
       ``InterpolateLinear`` on both gripper chunks, using ``chunk_length``
       and ``stride``.
    4. ``XYZWXYZ_to_XYZYPR`` on the four pose keys (only when ``is_quat``).
    5. ``ConcatKeys`` into ``actions_key`` (left pose, left gripper, right
       pose, right gripper) and into ``obs_key`` (same order, observations).
    6. ``DeleteKeys`` for the world-frame command keys and the target keys.

    Args:
        left_target_world / right_target_world: Batch keys holding the
            per-arm target pose. The extrinsics poses are injected under
            exactly these names.
        extrinsics_key: Entry of ``EXTRINSICS`` whose ``"left"`` and
            ``"right"`` matrices are converted with ``_matrix_to_xyzwxyz``.
        chunk_length: ``new_chunk_length`` passed to the interpolators.
        stride: ``stride`` passed to the interpolators.
        is_quat: Forwarded to the pose transforms; also gates the
            quaternion-to-YPR conversion step.
        The remaining arguments are the batch key names read or written by
        the individual transforms.

    Returns:
        The list of transforms, in application order.
    """
    extrinsics = EXTRINSICS[extrinsics_key]
    left_extrinsics_pose = _matrix_to_xyzwxyz(extrinsics["left"][None, :])[0]
    right_extrinsics_pose = _matrix_to_xyzwxyz(extrinsics["right"][None, :])[0]
    # Key the injected poses by the configured target names rather than by
    # the default literals: the transforms below read (and DeleteKeys later
    # removes) `left_target_world` / `right_target_world`, so a caller that
    # overrides those names must get the extrinsics under the same names.
    # NOTE(review): assumes `extra_batch_key` entries are added to the batch
    # by the transform before `target_world` is looked up - confirm against
    # ActionChunkCoordinateFrameTransform.
    left_extra_batch_key = {left_target_world: left_extrinsics_pose}
    right_extra_batch_key = {right_target_world: right_extrinsics_pose}
    transform_list: list[Transform] = [
        ActionChunkCoordinateFrameTransform(
            target_world=left_target_world,
            chunk_world=left_cmd_world,
            transformed_key_name=left_cmd_camframe,
            extra_batch_key=left_extra_batch_key,
            is_quat=is_quat,
        ),
        ActionChunkCoordinateFrameTransform(
            target_world=right_target_world,
            chunk_world=right_cmd_world,
            transformed_key_name=right_cmd_camframe,
            extra_batch_key=right_extra_batch_key,
            is_quat=is_quat,
        ),
        # The observed poses are transformed in place (same input/output key).
        PoseCoordinateFrameTransform(
            target_world=left_target_world,
            pose_world=left_obs_pose,
            transformed_key_name=left_obs_pose,
            is_quat=is_quat,
        ),
        PoseCoordinateFrameTransform(
            target_world=right_target_world,
            pose_world=right_obs_pose,
            transformed_key_name=right_obs_pose,
            is_quat=is_quat,
        ),
        InterpolatePose(
            new_chunk_length=chunk_length,
            action_key=left_cmd_camframe,
            output_action_key=left_cmd_camframe,
            stride=stride,
            is_quat=is_quat,
        ),
        InterpolatePose(
            new_chunk_length=chunk_length,
            action_key=right_cmd_camframe,
            output_action_key=right_cmd_camframe,
            stride=stride,
            is_quat=is_quat,
        ),
        InterpolateLinear(
            new_chunk_length=chunk_length,
            action_key=left_gripper,
            output_action_key=left_gripper,
            stride=stride,
        ),
        InterpolateLinear(
            new_chunk_length=chunk_length,
            action_key=right_gripper,
            output_action_key=right_gripper,
            stride=stride,
        ),
    ]

    if is_quat:
        transform_list.append(
            XYZWXYZ_to_XYZYPR(
                keys=[
                    left_cmd_camframe,
                    right_cmd_camframe,
                    left_obs_pose,
                    right_obs_pose,
                ]
            )
        )

    transform_list.extend(
        [
            ConcatKeys(
                key_list=[
                    left_cmd_camframe,
                    left_gripper,
                    right_cmd_camframe,
                    right_gripper,
                ],
                new_key_name=actions_key,
                delete_old_keys=True,
            ),
            ConcatKeys(
                key_list=[
                    left_obs_pose,
                    left_obs_gripper,
                    right_obs_pose,
                    right_obs_gripper,
                ],
                new_key_name=obs_key,
                delete_old_keys=True,
            ),
            DeleteKeys(
                keys_to_delete=[
                    left_cmd_world,
                    right_cmd_world,
                    left_target_world,
                    right_target_world,
                ]
            ),
        ]
    )
    return transform_list
+ ConcatKeys( + key_list=[ + left_cmd_camframe, + left_gripper, + right_cmd_camframe, + right_gripper, + ], + new_key_name=actions_key, + delete_old_keys=True, + ), + ConcatKeys( + key_list=[ + left_obs_pose, + left_obs_gripper, + right_obs_pose, + right_obs_gripper, + ], + new_key_name=obs_key, + delete_old_keys=True, + ), + DeleteKeys( + keys_to_delete=[ + left_cmd_world, + right_cmd_world, + left_target_world, + right_target_world, + ] + ), + ] + ) + return transform_list + + +def _to_numpy(arr): + if hasattr(arr, "detach"): + arr = arr.detach() + if hasattr(arr, "cpu"): + arr = arr.cpu() + if hasattr(arr, "numpy"): + return arr.numpy() + return np.asarray(arr) + + +def _split_action_pose(actions): + # 14D layout: [L xyz ypr g, R xyz ypr g] + # 12D layout: [L xyz ypr, R xyz ypr] + if actions.shape[-1] == 14: + left_xyz = actions[:, :3] + left_ypr = actions[:, 3:6] + right_xyz = actions[:, 7:10] + right_ypr = actions[:, 10:13] + elif actions.shape[-1] == 12: + left_xyz = actions[:, :3] + left_ypr = actions[:, 3:6] + right_xyz = actions[:, 6:9] + right_ypr = actions[:, 9:12] + else: + raise ValueError(f"Unsupported action dim {actions.shape[-1]}") + return left_xyz, left_ypr, right_xyz, right_ypr + + +def _prepare_viz_image(batch, image_key): + img = _to_numpy(batch[image_key][0]) + if img.ndim == 3 and img.shape[0] in (1, 3): + img = np.transpose(img, (1, 2, 0)) + + if img.dtype != np.uint8: + if img.max() <= 1.0: + img = (img * 255.0).clip(0, 255).astype(np.uint8) + else: + img = img.clip(0, 255).astype(np.uint8) + + if img.ndim == 2: + img = np.repeat(img[:, :, None], 3, axis=-1) + elif img.shape[-1] == 1: + img = np.repeat(img, 3, axis=-1) + + return img + + +def _viz_batch_palm_traj(batch, image_key, action_key, intrinsics_key): + img_np = _prepare_viz_image(batch, image_key) + intrinsics = INTRINSICS[intrinsics_key] + actions = _to_numpy(batch[action_key][0]) + left_xyz, _, right_xyz, _ = _split_action_pose(actions) + + vis = draw_actions( + img_np.copy(), + 
def _viz_batch_palm_axes(batch, image_key, action_key, intrinsics_key, axis_len_m=0.04):
    """Draw local xyz axes at each palm for the first sample of ``batch``.

    For each arm, the first xyz/ypr entry of ``batch[action_key]`` is used:
    the ypr triple is turned into a rotation matrix ("ZYX" Euler, radians),
    the palm position plus the three axis tips (``axis_len_m`` long) are
    mapped to pixels with ``cam_frame_to_cam_pixels`` and drawn on a copy of
    the prepared image. A small x/y/z color legend is added in the top-right
    corner.

    Returns:
        The annotated image.
    """
    # (legend name, color) per axis; the same colors are used for the
    # legend and for the axes drawn at each palm.
    axis_palette = (
        ("x", (255, 0, 0)),
        ("y", (0, 255, 0)),
        ("z", (0, 0, 255)),
    )

    img_np = _prepare_viz_image(batch, image_key)
    intrinsics = INTRINSICS[intrinsics_key]
    actions = _to_numpy(batch[action_key][0])
    left_xyz, left_ypr, right_xyz, right_ypr = _split_action_pose(actions)

    def _draw_rotation_at_palm(frame, xyz_seq, ypr_seq, label, anchor_color):
        # Nothing to draw for an empty trajectory.
        if len(xyz_seq) == 0 or len(ypr_seq) == 0:
            return frame

        origin = xyz_seq[0]
        rot = R.from_euler("ZYX", ypr_seq[0], degrees=False).as_matrix()

        # Row 0 is the palm itself; rows 1..3 are the tips of its x/y/z axes.
        tips = [origin + rot[:, col] * axis_len_m for col in range(3)]
        axis_points_cam = np.vstack([origin] + tips)

        px = cam_frame_to_cam_pixels(axis_points_cam, intrinsics)[:, :2]
        if not np.isfinite(px).all():
            return frame
        pts = np.round(px).astype(np.int32)

        height, width = frame.shape[:2]
        x0, y0 = pts[0]
        # Skip the whole arm when the palm anchor is outside the image.
        if x0 < 0 or x0 >= width or y0 < 0 or y0 >= height:
            return frame

        cv2.circle(frame, (x0, y0), 4, anchor_color, -1)
        for (x1, y1), (_, color) in zip(pts[1:], axis_palette):
            # Axes whose tip falls outside the image are not drawn.
            if not (0 <= x1 < width and 0 <= y1 < height):
                continue
            cv2.line(frame, (x0, y0), (x1, y1), color, 2)
            cv2.circle(frame, (x1, y1), 2, color, -1)

        cv2.putText(
            frame,
            label,
            (x0 + 6, max(12, y0 - 8)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.4,
            anchor_color,
            1,
            cv2.LINE_AA,
        )
        return frame

    def _draw_axis_color_legend(frame):
        frame_width = frame.shape[1]
        x_end = frame_width - 12
        x_begin = x_end - 24
        for row, (name, color) in enumerate(axis_palette):
            y = 14 + row * 12
            cv2.line(frame, (x_begin, y), (x_end, y), color, 3)
            cv2.putText(
                frame,
                name,
                (x_begin - 12, y + 4),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.4,
                color,
                1,
                cv2.LINE_AA,
            )
        return frame

    vis = img_np.copy()
    vis = _draw_rotation_at_palm(vis, left_xyz, left_ypr, "L rot", (255, 180, 80))
    vis = _draw_rotation_at_palm(vis, right_xyz, right_ypr, "R rot", (80, 180, 255))
    return _draw_axis_color_legend(vis)
class Human(Embodiment):
    """Human embodiment: key map, transform pipeline and batch visualization.

    Subclasses customize behavior through the class attributes below.
    """

    # Key into INTRINSICS used by the visualizations.
    VIZ_INTRINSICS_KEY = "base"
    # Batch key of the image that visualizations are drawn on.
    VIZ_IMAGE_KEY = "observations.images.front_img_1"
    # Stride forwarded to the transform-list builder.
    ACTION_STRIDE = 3

    @classmethod
    def get_transform_list(cls) -> list[Transform]:
        """Return the bimanual transform pipeline built with ``cls.ACTION_STRIDE``."""
        stride = cls.ACTION_STRIDE
        return _build_aria_bimanual_transform_list(stride=stride)

    @classmethod
    def viz_transformed_batch(cls, batch, mode=""):
        """Visualize one transformed batch sample.

        Modes (case-insensitive; empty/None falls back to ``palm_traj``):
        - palm_traj: rendered by ``_viz_batch_palm_traj``.
        - palm_axes: rendered by ``_viz_batch_palm_axes``.
        - keypoints: reserved, raises ``NotImplementedError``.

        Raises:
            NotImplementedError: For ``mode="keypoints"``.
            ValueError: For any other unknown mode.
        """
        image_key = cls.VIZ_IMAGE_KEY
        action_key = "actions_cartesian"
        intrinsics_key = cls.VIZ_INTRINSICS_KEY
        mode = (mode or "palm_traj").lower()

        renderers = {
            "palm_traj": _viz_batch_palm_traj,
            "palm_axes": _viz_batch_palm_axes,
        }
        render = renderers.get(mode)
        if render is not None:
            return render(
                batch=batch,
                image_key=image_key,
                action_key=action_key,
                intrinsics_key=intrinsics_key,
            )

        if mode == "keypoints":
            raise NotImplementedError(
                "mode='keypoints' is reserved and not implemented yet."
            )

        raise ValueError(
            f"Unsupported mode '{mode}'. Expected one of: "
            f"('palm_traj', 'palm_axes', 'keypoints')."
        )

    @classmethod
    def get_keymap(cls):
        """Return the key map for human data.

        The action keys read from the same zarr arrays as the observed
        end-effector poses, with a horizon of 30.
        """

        def _action(zarr_key):
            return {"key_type": "action_keys", "zarr_key": zarr_key, "horizon": 30}

        def _proprio(zarr_key):
            return {"key_type": "proprio_keys", "zarr_key": zarr_key}

        keymap = {
            cls.VIZ_IMAGE_KEY: {
                "key_type": "camera_keys",
                "zarr_key": "images.front_1",
            },
        }
        keymap["right.action_ee_pose"] = _action("right.obs_ee_pose")
        keymap["left.action_ee_pose"] = _action("left.obs_ee_pose")
        keymap["right.obs_ee_pose"] = _proprio("right.obs_ee_pose")
        keymap["left.obs_ee_pose"] = _proprio("left.obs_ee_pose")
        keymap["obs_head_pose"] = _proprio("obs_head_pose")
        return keymap
class Aria(Human):
    """Aria variant of ``Human``; restates the ``Human`` defaults explicitly."""

    VIZ_INTRINSICS_KEY = "base"
    ACTION_STRIDE = 3


class Scale(Human):
    """``Human`` variant using the "scale" intrinsics and an action stride of 1."""

    VIZ_INTRINSICS_KEY = "scale"
    ACTION_STRIDE = 1


class Mecka(Human):
    """``Human`` variant using the "mecka" intrinsics and an action stride of 1."""

    VIZ_INTRINSICS_KEY = "mecka"
    ACTION_STRIDE = 1


def _build_aria_bimanual_transform_list(
    *,
    target_world: str = "obs_head_pose",
    target_world_ypr: str = "obs_head_pose_ypr",
    target_world_is_quat: bool = True,
    left_action_world: str = "left.action_ee_pose",
    right_action_world: str = "right.action_ee_pose",
    left_obs_pose: str = "left.obs_ee_pose",
    right_obs_pose: str = "right.obs_ee_pose",
    left_action_headframe: str = "left.action_ee_pose_headframe",
    right_action_headframe: str = "right.action_ee_pose_headframe",
    left_obs_headframe: str = "left.obs_ee_pose_headframe",
    right_obs_headframe: str = "right.obs_ee_pose_headframe",
    actions_key: str = "actions_cartesian",
    obs_key: str = "observations.state.ee_pose",
    chunk_length: int = 100,
    stride: int = 3,
    delete_target_world: bool = True,
) -> list[Transform]:
    """Canonical ARIA bimanual transform pipeline used by tests and notebooks.

    Aria human data does not have commanded ee poses; action chunks are built
    from stacked observed ee poses (typically with a horizon on
    ``left/right.action_ee_pose`` mapped from ``left/right.obs_ee_pose``).

    The returned list is ordered as follows:

    1. ``ActionChunkCoordinateFrameTransform`` for the left and right action
       chunks, written to ``*_action_headframe``.
    2. ``PoseCoordinateFrameTransform`` for the left and right observed
       poses, written to ``*_obs_headframe``.
    3. ``InterpolatePose`` on both ``*_action_headframe`` chunks using
       ``chunk_length`` and ``stride``.
    4. ``XYZWXYZ_to_XYZYPR`` on the four head-frame keys (only when
       ``target_world_is_quat``).
    5. ``ConcatKeys`` into ``actions_key`` and ``obs_key`` (left then right).
    6. ``DeleteKeys`` for the world-frame inputs, plus ``target_world`` when
       ``delete_target_world`` is set and ``target_world_ypr`` when
       ``target_world_is_quat`` is set.

    Returns:
        The list of transforms, in application order.
    """
    # De-duplicate while keeping a fixed order. A plain set of strings would
    # yield a different order from run to run (string hashing is randomized
    # per process), making the resulting pipeline configuration
    # non-reproducible.
    keys_to_delete = list(
        dict.fromkeys(
            [
                left_action_world,
                right_action_world,
                left_obs_pose,
                right_obs_pose,
            ]
        )
    )
    target_pose_key = target_world
    if delete_target_world:
        keys_to_delete.append(target_world)
    if target_world_is_quat:
        keys_to_delete.append(target_world_ypr)

    transform_list: list[Transform] = [
        ActionChunkCoordinateFrameTransform(
            target_world=target_pose_key,
            chunk_world=left_action_world,
            transformed_key_name=left_action_headframe,
            is_quat=target_world_is_quat,
        ),
        ActionChunkCoordinateFrameTransform(
            target_world=target_pose_key,
            chunk_world=right_action_world,
            transformed_key_name=right_action_headframe,
            is_quat=target_world_is_quat,
        ),
        PoseCoordinateFrameTransform(
            target_world=target_pose_key,
            pose_world=left_obs_pose,
            transformed_key_name=left_obs_headframe,
            is_quat=target_world_is_quat,
        ),
        PoseCoordinateFrameTransform(
            target_world=target_pose_key,
            pose_world=right_obs_pose,
            transformed_key_name=right_obs_headframe,
            is_quat=target_world_is_quat,
        ),
        InterpolatePose(
            new_chunk_length=chunk_length,
            action_key=left_action_headframe,
            output_action_key=left_action_headframe,
            stride=stride,
            is_quat=target_world_is_quat,
        ),
        InterpolatePose(
            new_chunk_length=chunk_length,
            action_key=right_action_headframe,
            output_action_key=right_action_headframe,
            stride=stride,
            is_quat=target_world_is_quat,
        ),
    ]

    if target_world_is_quat:
        transform_list.append(
            XYZWXYZ_to_XYZYPR(
                keys=[
                    left_action_headframe,
                    right_action_headframe,
                    left_obs_headframe,
                    right_obs_headframe,
                ]
            )
        )

    transform_list.extend(
        [
            ConcatKeys(
                key_list=[left_action_headframe, right_action_headframe],
                new_key_name=actions_key,
                delete_old_keys=True,
            ),
            ConcatKeys(
                key_list=[left_obs_headframe, right_obs_headframe],
                new_key_name=obs_key,
                delete_old_keys=True,
            ),
            DeleteKeys(keys_to_delete=keys_to_delete),
        ]
    )
    return transform_list
right_extrinsics_pose = _matrix_to_xyzwxyz(extrinsics["right"][None, :])[0] - left_extra_batch_key = {"left_extrinsics_pose": left_extrinsics_pose} - right_extra_batch_key = {"right_extrinsics_pose": right_extrinsics_pose} - transform_list = [ - ActionChunkCoordinateFrameTransform( - target_world=left_target_world, - chunk_world=left_cmd_world, - transformed_key_name=left_cmd_camframe, - extra_batch_key=left_extra_batch_key, - is_quat=is_quat, - ), - ActionChunkCoordinateFrameTransform( - target_world=right_target_world, - chunk_world=right_cmd_world, - transformed_key_name=right_cmd_camframe, - extra_batch_key=right_extra_batch_key, - is_quat=is_quat, - ), - PoseCoordinateFrameTransform( - target_world=left_target_world, - pose_world=left_obs_pose, - transformed_key_name=left_obs_pose, - is_quat=is_quat, - ), - PoseCoordinateFrameTransform( - target_world=right_target_world, - pose_world=right_obs_pose, - transformed_key_name=right_obs_pose, - is_quat=is_quat, - ), - InterpolatePose( - new_chunk_length=chunk_length, - action_key=left_cmd_camframe, - output_action_key=left_cmd_camframe, - stride=stride, - is_quat=is_quat, - ), - InterpolatePose( - new_chunk_length=chunk_length, - action_key=right_cmd_camframe, - output_action_key=right_cmd_camframe, - stride=stride, - is_quat=is_quat, - ), - InterpolateLinear( - new_chunk_length=chunk_length, - action_key=left_gripper, - output_action_key=left_gripper, - stride=stride, - ), - InterpolateLinear( - new_chunk_length=chunk_length, - action_key=right_gripper, - output_action_key=right_gripper, - stride=stride, - ), - ] - - if is_quat: - transform_list.append( - XYZWXYZ_to_XYZYPR( - keys=[ - left_cmd_camframe, - right_cmd_camframe, - left_obs_pose, - right_obs_pose, - ] - ) - ) - - transform_list.extend( - [ - ConcatKeys( - key_list=[ - left_cmd_camframe, - left_gripper, - right_cmd_camframe, - right_gripper, - ], - new_key_name=actions_key, - delete_old_keys=True, - ), - ConcatKeys( - key_list=[ - left_obs_pose, - 
left_obs_gripper, - right_obs_pose, - right_obs_gripper, - ], - new_key_name=obs_key, - delete_old_keys=True, - ), - DeleteKeys( - keys_to_delete=[ - left_cmd_world, - right_cmd_world, - left_target_world, - right_target_world, - ] - ), - ] - ) - return transform_list - - -def build_aria_bimanual_transform_list( - *, - target_world: str = "obs_head_pose", - target_world_ypr: str = "obs_head_pose_ypr", - target_world_is_quat: bool = True, - left_action_world: str = "left.action_ee_pose", - right_action_world: str = "right.action_ee_pose", - left_obs_pose: str = "left.obs_ee_pose", - right_obs_pose: str = "right.obs_ee_pose", - left_action_headframe: str = "left.action_ee_pose_headframe", - right_action_headframe: str = "right.action_ee_pose_headframe", - left_obs_headframe: str = "left.obs_ee_pose_headframe", - right_obs_headframe: str = "right.obs_ee_pose_headframe", - actions_key: str = "actions_cartesian", - obs_key: str = "observations.state.ee_pose", - chunk_length: int = 100, - stride: int = 3, - delete_target_world: bool = True, -) -> list[Transform]: - """Canonical ARIA bimanual transform pipeline used by tests and notebooks. - - Aria human data does not have commanded ee poses; action chunks are built - from stacked observed ee poses (typically with a horizon on - ``left/right.action_ee_pose`` mapped from ``left/right.obs_ee_pose``). 
- """ - keys_to_delete = list( - { - left_action_world, - right_action_world, - left_obs_pose, - right_obs_pose, - } - ) - target_pose_key = target_world - if delete_target_world: - keys_to_delete.append(target_world) - if target_world_is_quat: - keys_to_delete.append(target_world_ypr) - - transform_list: list[Transform] = [ - ActionChunkCoordinateFrameTransform( - target_world=target_pose_key, - chunk_world=left_action_world, - transformed_key_name=left_action_headframe, - is_quat=target_world_is_quat, - ), - ActionChunkCoordinateFrameTransform( - target_world=target_pose_key, - chunk_world=right_action_world, - transformed_key_name=right_action_headframe, - is_quat=target_world_is_quat, - ), - PoseCoordinateFrameTransform( - target_world=target_pose_key, - pose_world=left_obs_pose, - transformed_key_name=left_obs_headframe, - is_quat=target_world_is_quat, - ), - PoseCoordinateFrameTransform( - target_world=target_pose_key, - pose_world=right_obs_pose, - transformed_key_name=right_obs_headframe, - is_quat=target_world_is_quat, - ), - InterpolatePose( - new_chunk_length=chunk_length, - action_key=left_action_headframe, - output_action_key=left_action_headframe, - stride=stride, - is_quat=target_world_is_quat, - ), - InterpolatePose( - new_chunk_length=chunk_length, - action_key=right_action_headframe, - output_action_key=right_action_headframe, - stride=stride, - is_quat=target_world_is_quat, - ), - ] - - if target_world_is_quat: - transform_list.append( - XYZWXYZ_to_XYZYPR( - keys=[ - left_action_headframe, - right_action_headframe, - left_obs_headframe, - right_obs_headframe, - ] - ) - ) - - transform_list.extend( - [ - ConcatKeys( - key_list=[left_action_headframe, right_action_headframe], - new_key_name=actions_key, - delete_old_keys=True, - ), - ConcatKeys( - key_list=[left_obs_headframe, right_obs_headframe], - new_key_name=obs_key, - delete_old_keys=True, - ), - DeleteKeys(keys_to_delete=keys_to_delete), - ] - ) - return transform_list diff --git 
a/egomimic/scripts/zarr_data_viz.ipynb b/egomimic/scripts/zarr_data_viz.ipynb index 0e075069..87a27406 100644 --- a/egomimic/scripts/zarr_data_viz.ipynb +++ b/egomimic/scripts/zarr_data_viz.ipynb @@ -24,23 +24,17 @@ "import mediapy as mpy\n", "import numpy as np\n", "import torch\n", - "from scipy.spatial.transform import Rotation as R\n", "\n", - "from egomimic.rldb.zarr.action_chunk_transforms import (\n", - " _matrix_to_xyzwxyz,\n", - " build_aria_bimanual_transform_list,\n", - " build_eva_bimanual_transform_list,\n", - ")\n", + "from egomimic.rldb.embodiment.eva import Eva\n", + "from egomimic.rldb.embodiment.human import Aria\n", "from egomimic.rldb.zarr.zarr_dataset_multi import MultiDataset, ZarrDataset\n", + "from egomimic.rldb.zarr.zarr_dataset_multi import S3EpisodeResolver\n", "from egomimic.utils.egomimicUtils import (\n", - " EXTRINSICS,\n", " INTRINSICS,\n", " cam_frame_to_cam_pixels,\n", - " draw_actions,\n", " nds,\n", ")\n", - "from egomimic.rldb.zarr.zarr_dataset_multi import S3EpisodeResolver\n", - "from egomimic.rldb.zarr.action_chunk_transforms import build_aria_bimanual_transform_list\n", + "from egomimic.utils.aws.aws_data_utils import load_env\n", "\n", "# Ensure mediapy can find an ffmpeg executable in this environment\n", "mpy.set_ffmpeg(imageio_ffmpeg.get_ffmpeg_exe())" @@ -53,7 +47,8 @@ "metadata": {}, "outputs": [], "source": [ - "TEMP_DIR = \"/coc/flash7/scratch/egoverseDebugDatasets/egoverseS3DatasetTest\"" + "TEMP_DIR = \"/coc/flash7/scratch/egoverseDebugDatasets/egoverseS3DatasetTest\"\n", + "load_env()" ] }, { @@ -66,34 +61,8 @@ "# Point this at a single episode directory, e.g. 
/path/to/episode_hash.zarr\n", "# EPISODE_PATH = Path(\"/coc/flash7/scratch/egoverseDebugDatasets/1767495035712.zarr\")\n", "\n", - "key_map = {\n", - " \"images.front_1\": {\"zarr_key\": \"images.front_1\"},\n", - " \"images.right_wrist\": {\"zarr_key\": \"images.right_wrist\"},\n", - " \"images.left_wrist\": {\"zarr_key\": \"images.left_wrist\"},\n", - " \"right.obs_ee_pose\": {\"zarr_key\": \"right.obs_ee_pose\"},\n", - " \"right.obs_gripper\": {\"zarr_key\": \"right.gripper\"},\n", - " \"left.obs_ee_pose\": {\"zarr_key\": \"left.obs_ee_pose\"},\n", - " \"left.obs_gripper\": {\"zarr_key\": \"left.gripper\"},\n", - " \"right.gripper\": {\"zarr_key\": \"right.gripper\", \"horizon\": 45},\n", - " \"left.gripper\": {\"zarr_key\": \"left.gripper\", \"horizon\": 45},\n", - " \"right.cmd_ee_pose\": {\"zarr_key\": \"right.cmd_ee_pose\", \"horizon\": 45},\n", - " \"left.cmd_ee_pose\": {\"zarr_key\": \"left.cmd_ee_pose\", \"horizon\": 45},\n", - "}\n", - "\n", - "ACTION_CHUNK_LENGTH = 100\n", - "ACTION_STRIDE = 1 # set to 3 for Aria-style anchor sampling\n", - "\n", - "extrinsics = EXTRINSICS[\"x5Dec13_2\"]\n", - "left_extrinsics_pose = _matrix_to_xyzwxyz(extrinsics[\"left\"][None, :])[0]\n", - "right_extrinsics_pose = _matrix_to_xyzwxyz(extrinsics[\"right\"][None, :])[0]\n", - "\n", - "transform_list = build_eva_bimanual_transform_list(\n", - " chunk_length=ACTION_CHUNK_LENGTH,\n", - " stride=ACTION_STRIDE,\n", - " # left_extra_batch_key={\"left_extrinsics_pose\": left_extrinsics_pose},\n", - " # right_extra_batch_key={\"right_extrinsics_pose\": right_extrinsics_pose},\n", - ")\n", - "\n", + "key_map = Eva.get_keymap()\n", + "transform_list = Eva.get_transform_list()\n", "\n", "# Build a MultiDataset with exactly one ZarrDataset inside\n", "# single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map, transform_list=transform_list)\n", @@ -113,151 +82,6 @@ "loader = torch.utils.data.DataLoader(multi_ds, batch_size=1, shuffle=False)" ] }, - { - "cell_type": 
"code", - "execution_count": null, - "id": "0e4626d8", - "metadata": {}, - "outputs": [], - "source": [ - "def _split_action_pose(actions):\n", - " # 14D layout: [L xyz ypr g, R xyz ypr g]\n", - " # 12D layout: [L xyz ypr, R xyz ypr]\n", - " if actions.shape[-1] == 14:\n", - " left_xyz = actions[:, :3]\n", - " left_ypr = actions[:, 3:6]\n", - " right_xyz = actions[:, 7:10]\n", - " right_ypr = actions[:, 10:13]\n", - " elif actions.shape[-1] == 12:\n", - " left_xyz = actions[:, :3]\n", - " left_ypr = actions[:, 3:6]\n", - " right_xyz = actions[:, 6:9]\n", - " right_ypr = actions[:, 9:12]\n", - " else:\n", - " raise ValueError(f\"Unsupported action dim {actions.shape[-1]}\")\n", - " return left_xyz, left_ypr, right_xyz, right_ypr\n", - "\n", - "\n", - "def viz_batch(batch, image_key, action_key, intrinsics_key):\n", - " # Image: (C,H,W) -> (H,W,C)\n", - " img = batch[image_key][0].detach().cpu()\n", - " if img.shape[0] in (1, 3):\n", - " img = img.permute(1, 2, 0)\n", - " img_np = img.numpy()\n", - "\n", - " if img_np.dtype != np.uint8:\n", - " if img_np.max() <= 1.0:\n", - " img_np = (img_np * 255.0).clip(0, 255).astype(np.uint8)\n", - " else:\n", - " img_np = img_np.clip(0, 255).astype(np.uint8)\n", - " if img_np.shape[-1] == 1:\n", - " img_np = np.repeat(img_np, 3, axis=-1)\n", - "\n", - " intrinsics = INTRINSICS[intrinsics_key]\n", - " actions = batch[action_key][0].detach().cpu().numpy()\n", - " left_xyz, _, right_xyz, _ = _split_action_pose(actions)\n", - "\n", - " vis = draw_actions(\n", - " img_np.copy(), type=\"xyz\", color=\"Blues\",\n", - " actions=left_xyz, extrinsics=None, intrinsics=intrinsics, arm=\"left\"\n", - " )\n", - " vis = draw_actions(\n", - " vis, type=\"xyz\", color=\"Reds\",\n", - " actions=right_xyz, extrinsics=None, intrinsics=intrinsics, arm=\"right\"\n", - " )\n", - " return vis\n", - "\n", - "\n", - "def viz_batch_ypr(batch, image_key, action_key, intrinsics_key, axis_len_m=0.04):\n", - " img = batch[image_key][0].detach().cpu()\n", - " 
if img.shape[0] in (1, 3):\n", - " img = img.permute(1, 2, 0)\n", - " img_np = img.numpy()\n", - "\n", - " if img_np.dtype != np.uint8:\n", - " if img_np.max() <= 1.0:\n", - " img_np = (img_np * 255.0).clip(0, 255).astype(np.uint8)\n", - " else:\n", - " img_np = img_np.clip(0, 255).astype(np.uint8)\n", - " if img_np.shape[-1] == 1:\n", - " img_np = np.repeat(img_np, 3, axis=-1)\n", - "\n", - " intrinsics = INTRINSICS[intrinsics_key]\n", - " actions = batch[action_key][0].detach().cpu().numpy()\n", - " left_xyz, left_ypr, right_xyz, right_ypr = _split_action_pose(actions)\n", - "\n", - " vis = img_np.copy()\n", - "\n", - " def _draw_axis_color_legend(frame):\n", - " h, w = frame.shape[:2]\n", - " x_right = w - 12\n", - " y_start = 14\n", - " y_step = 12\n", - " line_len = 24\n", - " axis_legend = [\n", - " (\"x\", (255, 0, 0)),\n", - " (\"y\", (0, 255, 0)),\n", - " (\"z\", (0, 0, 255)),\n", - " ]\n", - " for i, (name, color) in enumerate(axis_legend):\n", - " y = y_start + i * y_step\n", - " x0 = x_right - line_len\n", - " x1 = x_right\n", - " cv2.line(frame, (x0, y), (x1, y), color, 3)\n", - " cv2.putText(\n", - " frame, name, (x0 - 12, y + 4),\n", - " cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1, cv2.LINE_AA,\n", - " )\n", - " return frame\n", - "\n", - " def _draw_rotation_at_palm(frame, xyz_seq, ypr_seq, label, anchor_color):\n", - " if len(xyz_seq) == 0 or len(ypr_seq) == 0:\n", - " return frame\n", - "\n", - " palm_xyz = xyz_seq[0]\n", - " palm_ypr = ypr_seq[0]\n", - " rot = R.from_euler(\"ZYX\", palm_ypr, degrees=False).as_matrix()\n", - " # print(np.linalg.det(rot))\n", - "\n", - " axis_points_cam = np.vstack([\n", - " palm_xyz,\n", - " palm_xyz + rot[:, 0] * axis_len_m,\n", - " palm_xyz + rot[:, 1] * axis_len_m,\n", - " palm_xyz + rot[:, 2] * axis_len_m,\n", - " ])\n", - "\n", - " px = cam_frame_to_cam_pixels(axis_points_cam, intrinsics)[:, :2]\n", - " if not np.isfinite(px).all():\n", - " return frame\n", - "\n", - " pts = np.round(px).astype(np.int32)\n", - 
"\n", - " h, w = frame.shape[:2]\n", - " x0, y0 = pts[0]\n", - " if not (0 <= x0 < w and 0 <= y0 < h):\n", - " return frame\n", - "\n", - " cv2.circle(frame, (x0, y0), 4, anchor_color, -1)\n", - " axis_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] # x,y,z as RGB\n", - "\n", - " for i, color in enumerate(axis_colors, start=1):\n", - " x1, y1 = pts[i]\n", - " if 0 <= x1 < w and 0 <= y1 < h:\n", - " cv2.line(frame, (x0, y0), (x1, y1), color, 2)\n", - " cv2.circle(frame, (x1, y1), 2, color, -1)\n", - "\n", - " cv2.putText(\n", - " frame, label, (x0 + 6, max(12, y0 - 8)),\n", - " cv2.FONT_HERSHEY_SIMPLEX, 0.4, anchor_color, 1, cv2.LINE_AA,\n", - " )\n", - " return frame\n", - "\n", - " vis = _draw_rotation_at_palm(vis, left_xyz, left_ypr, \"L rot\", (255, 180, 80))\n", - " vis = _draw_rotation_at_palm(vis, right_xyz, right_ypr, \"R rot\", (80, 180, 255))\n", - " vis = _draw_axis_color_legend(vis)\n", - " return vis" - ] - }, { "cell_type": "code", "execution_count": null, @@ -266,11 +90,8 @@ "outputs": [], "source": [ "# Separate YPR visualization preview\n", - "image_key = \"images.front_1\"\n", - "action_key = \"actions_cartesian\"\n", - "\n", "for batch in loader:\n", - " vis_ypr = viz_batch_ypr(batch, image_key=image_key, action_key=action_key, intrinsics_key=\"base\")\n", + " vis_ypr = Eva.viz_transformed_batch(batch, mode=\"palm_axes\")\n", " mpy.show_image(vis_ypr)\n", " break" ] @@ -282,12 +103,9 @@ "metadata": {}, "outputs": [], "source": [ - "image_key = \"images.front_1\"\n", - "action_key = \"actions_cartesian\"\n", - "\n", "images = []\n", "for i, batch in enumerate(loader):\n", - " vis = viz_batch(batch, image_key=image_key, action_key=action_key, intrinsics_key=\"base\")\n", + " vis = Eva.viz_transformed_batch(batch, mode=\"palm_traj\")\n", " images.append(vis)\n", " if i > 10:\n", " break\n", @@ -315,25 +133,13 @@ "\n", "intrinsics_key = \"base\"\n", "\n", - "key_map = {\n", - " \"images.front_1\": {\"zarr_key\": \"images.front_1\"},\n", - " 
\"right.obs_ee_pose\": {\"zarr_key\": \"right.obs_ee_pose\"},\n", - " \"left.obs_ee_pose\": {\"zarr_key\": \"left.obs_ee_pose\"},\n", - " \"right.action_ee_pose\": {\"zarr_key\": \"right.obs_ee_pose\", \"horizon\": 30},\n", - " \"left.action_ee_pose\": {\"zarr_key\": \"left.obs_ee_pose\", \"horizon\": 30},\n", - " \"obs_head_pose\": {\"zarr_key\": \"obs_head_pose\"},\n", - "}\n", - "\n", - "ACTION_CHUNK_LENGTH = 100\n", - "ACTION_STRIDE = 3\n", + "key_map = Aria.get_keymap()\n", + "transform_list = Aria.get_transform_list()\n", "\n", "resolver = S3EpisodeResolver(\n", " temp_dir,\n", " key_map=key_map,\n", - " transform_list=build_aria_bimanual_transform_list(\n", - " chunk_length=ACTION_CHUNK_LENGTH,\n", - " stride=ACTION_STRIDE,\n", - " )\n", + " transform_list=transform_list,\n", ")\n", "\n", "filters = {\"episode_hash\": \"2026-01-20-20-59-43-376000\"} #aria\n", @@ -353,15 +159,9 @@ "metadata": {}, "outputs": [], "source": [ - "image_key = \"images.front_1\"\n", - "action_key = \"actions_cartesian\"\n", - "\n", "ims = []\n", "for i, batch in enumerate(loader):\n", - " first_img = batch[image_key][0].detach().cpu().permute(1, 2, 0).numpy()\n", - " first_img = (first_img * 255.0).clip(0, 255).astype(np.uint8)\n", - "\n", - " vis = viz_batch(batch, image_key=image_key, action_key=action_key, intrinsics_key=intrinsics_key)\n", + " vis = Aria.viz_transformed_batch(batch, mode=\"palm_traj\")\n", " ims.append(vis)\n", " # mpy.show_image(vis)\n", "\n", @@ -382,12 +182,9 @@ "outputs": [], "source": [ "# Aria YPR video (same data loop, YPR overlay)\n", - "image_key = \"images.front_1\"\n", - "action_key = \"actions_cartesian\"\n", - "\n", "ims_ypr = []\n", "for i, batch in enumerate(loader):\n", - " vis_ypr = viz_batch_ypr(batch, image_key=image_key, action_key=action_key, intrinsics_key=\"base\")\n", + " vis_ypr = Aria.viz_transformed_batch(batch, mode=\"palm_axes\")\n", " ims_ypr.append(vis_ypr)\n", " if i > 20:\n", " break\n", @@ -442,6 +239,7 @@ "metadata": {}, 
"outputs": [], "source": [ + "# ARIA Keypoint Viz\n", "# MANO skeleton edges: (parent, child) for drawing bones\n", "MANO_EDGES = [\n", " (0, 1), (1, 2), (2, 3), (3, 4), # thumb\n", @@ -451,6 +249,15 @@ " (0, 17), (17, 18), (18, 19), (19, 20), # pinky\n", "]\n", "\n", + "# aria configuration\n", + "MANO_EDGES = [\n", + " (5, 6,), (6, 7), (7, 0), # thumb\n", + " (5, 8), (8, 9), (9, 10), (9, 1), # index\n", + " (5, 11), (11, 12), (12, 13), (13, 2), # middle\n", + " (5, 14), (14, 15), (15, 16), (16, 3), # ring\n", + " (5, 17), (17, 18), (18, 19), (19, 4), # pinky\n", + "]\n", + "\n", "FINGER_COLORS = {\n", " \"thumb\": (255, 100, 100), # red\n", " \"index\": (100, 255, 100), # green\n", @@ -459,12 +266,12 @@ " \"pinky\": (255, 100, 255), # magenta\n", "}\n", "FINGER_EDGE_RANGES = [\n", - " (\"thumb\", 0, 4), (\"index\", 4, 8), (\"middle\", 8, 12),\n", - " (\"ring\", 12, 16), (\"pinky\", 16, 20),\n", + " (\"thumb\", 0, 3), (\"index\", 3, 6), (\"middle\", 6, 9),\n", + " (\"ring\", 9, 12), (\"pinky\", 12, 15),\n", "]\n", "\n", "\n", - "def viz_keypoints(batch, image_key=\"images.front_1\"):\n", + "def viz_keypoints(batch, image_key=\"observations.images.front_img_1\"):\n", " \"\"\"Visualize all 21 MANO keypoints per hand, projected onto the image.\"\"\"\n", " # Prepare image\n", " img = batch[image_key][0].detach().cpu()\n", @@ -479,7 +286,7 @@ " if img_np.shape[-1] == 1:\n", " img_np = np.repeat(img_np, 3, axis=-1)\n", "\n", - " intrinsics = INTRINSICS[\"scale\"]\n", + " intrinsics = INTRINSICS[\"base\"]\n", " head_pose = batch[\"obs_head_pose\"][0].detach().cpu().numpy() # (6,)\n", "\n", " # T_head_world: camera pose in world (camera-to-world)\n",