diff --git a/egomimic/scripts/aria_process/aria_to_zarr.py b/egomimic/scripts/aria_process/aria_to_zarr.py index fe403a3f..3e2178a5 100644 --- a/egomimic/scripts/aria_process/aria_to_zarr.py +++ b/egomimic/scripts/aria_process/aria_to_zarr.py @@ -1,76 +1,58 @@ import argparse -from datetime import datetime, timezone +import ctypes +import gc import logging import os -from pathlib import Path +import re import shutil +import subprocess +import threading +import time import traceback -from typing import Any -from egomimic.rldb.zarr.zarr_writer import ZarrWriter -from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME +from contextlib import contextmanager +from datetime import datetime, timezone +from pathlib import Path + import cv2 -import h5py -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -import torch -import gc, ctypes -from enum import Enum +import numpy as np import projectaria_tools.core.sophus as sp - -from egomimic.utils.egomimicUtils import ( - prep_frame, - start_ffmpeg_mp4, - str2bool, - cam_frame_to_cam_pixels, - INTRINSICS, - interpolate_keys, - interpolate_arr, - interpolate_arr_euler, - transform_to_pose, - pose_to_transform, +import psutil +import torch +import torch.nn.functional as F +from aria_utils import ( + compute_orientation_rotation_matrix, + slam_to_rgb, + undistort_to_linear, + cpf_to_rgb ) - -from projectaria_tools.core.calibration import CameraCalibration, DeviceCalibration -from projectaria_tools.core.sensor_data import TimeDomain, TimeQueryOptions - +from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME from projectaria_tools.core import data_provider, mps - -from projectaria_tools.core.mps.utils import get_nearest_hand_tracking_result - from projectaria_tools.core.mps.utils import ( - filter_points_from_confidence, - get_gaze_vector_reprojection, get_nearest_eye_gaze, + get_nearest_hand_tracking_result, get_nearest_pose, ) +from projectaria_tools.core.sensor_data import TimeDomain, TimeQueryOptions from projectaria_tools.core.stream_id import StreamId +from scipy.spatial.transform import Rotation as R -from aria_utils import ( - build_camera_matrix, - compute_orientation_rotation_matrix, - undistort_to_linear, - slam_to_rgb, +from egomimic.rldb.zarr.zarr_writer import ZarrWriter +from egomimic.utils.egomimicUtils import ( + INTRINSICS, + cam_frame_to_cam_pixels, + pose_to_transform, + prep_frame, + start_ffmpeg_mp4, + str2bool, + transform_to_pose, ) -from egomimic.rldb.utils import EMBODIMENT - -import time - -import numpy as np - -import torch -import torch.nn.functional as F - -from scipy.spatial.transform import Rotation as R -import subprocess -import re -import threading -from contextlib import contextmanager -import psutil - _root = psutil.Process(os.getpid()) + def _proc_rss_mb(p: psutil.Process) -> float: - return p.memory_info().rss / (1024 ** 2) + return p.memory_info().rss / (1024**2) + def cgroup_memory_peak_mb() -> float | None: # cgroup v2 @@ -82,7 +64,7 @@ def cgroup_memory_peak_mb() -> float | None: if os.path.exists(p): try: with open(p, "r") as f: - return int(f.read().strip()) / (1024 ** 2) + return int(f.read().strip()) / (1024**2) except (OSError, ValueError): pass return None @@ -101,6 +83,7 @@ def _read_smaps_rollup_kb(pid: int) -> dict[str, int]: out[k] = int(v[0]) return out + def tree_pss_mb() -> float: procs = [_root] try: @@ -121,6 +104,7 @@ def tree_pss_mb() -> float: pass return total_kb / 1024.0 + def tree_mem_mb(include_children: bool = True, use_uss: bool = True) -> float: root = psutil.Process(os.getpid()) procs = [root] @@ -139,7 +123,8 @@ def tree_mem_mb(include_children: bool = True, use_uss: bool = True) -> float: total += p.memory_info().rss except Exception: pass - return total / (1024 ** 2) + return total / (1024**2) + class _Sampler: def __init__(self, interval_s: float = 0.025): @@ -175,7 +160,9 @@ def stop(self): @contextmanager -def mem_section(name: str, sample_interval_s: float = 0.2, plot: bool = True, enabled: bool = False): +def mem_section( + name: str, sample_interval_s: float = 0.2, plot: bool = True, enabled: bool = False +): if not enabled: yield return @@ -192,10 +179,13 @@ def mem_section(name: str, sample_interval_s: float = 0.2, plot: bool = True, en dt = time.time() - t0 peak = max(sampler.mbs) if sampler.mbs else end - print(f"[{name}] end={end:.2f} MB delta={end-start:+.2f} MB peak={peak:.2f} MB time={dt:.2f}s") + print( + f"[{name}] end={end:.2f} MB delta={end-start:+.2f} MB peak={peak:.2f} MB time={dt:.2f}s" + ) if plot and sampler.mbs and sampler.ts: import matplotlib.pyplot as plt + n = min(len(sampler.ts), len(sampler.mbs)) if n > 1: plt.plot(sampler.ts[:n], sampler.mbs[:n]) @@ -205,9 +195,11 @@ def mem_section(name: str, sample_interval_s: float = 0.2, plot: bool = True, en plt.savefig(f"{_safe_name(name)}.png", dpi=150) plt.close() + def _safe_name(s: str) -> str: return re.sub(r"[^a-zA-Z0-9._-]+", "_", s).strip("_") + ## CHANGE THIS TO YOUR DESIRED CACHE FOR HF os.environ["HF_HOME"] = "~/.cache/huggingface" @@ -231,7 +223,8 @@ def _safe_name(s: str) -> str: # actions[..., 4] *= -1 # Multiply y by -1 for second set # return actions -PERMUTE = np.array([[0,0,1], [1,0,0], [0,1,0]]) +PERMUTE = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) + def SE3_permute_rot(T: np.ndarray) -> np.ndarray: """ @@ -242,6 +235,7 @@ def SE3_permute_rot(T: np.ndarray) -> np.ndarray: T[:3, :3] = rot return T + def timestamp_ms_to_episode_hash(timestamp_ms: int) -> str: """ Convert UTC epoch milliseconds -> string like "2026-01-12-03-47-29-664000". @@ -254,7 +248,6 @@ def timestamp_ms_to_episode_hash(timestamp_ms: int) -> str: return dt.strftime("%Y-%m-%d-%H-%M-%S-%f") - def pose_tx_ty_tz_qx_qy_qz_qw_to_SE3(pose): """ pose: iterable [tx, ty, tz, qx, qy, qz, qw] (quat is x,y,z,w) @@ -270,9 +263,9 @@ def pose_tx_ty_tz_qx_qy_qz_qw_to_SE3(pose): def downsample_hwc_uint8_in_chunks( - images: np.ndarray, # (T,H,W,3) uint8 - out_hw=(240, 320), - chunk: int = 256, + images: np.ndarray, # (T,H,W,3) uint8 + out_hw=(240, 320), + chunk: int = 256, ) -> np.ndarray: assert images.dtype == np.uint8 and images.ndim == 4 and images.shape[-1] == 3 T, H, W, C = images.shape @@ -282,7 +275,9 @@ def downsample_hwc_uint8_in_chunks( for s in range(0, T, chunk): e = min(s + chunk, T) - x = torch.from_numpy(images[s:e]).permute(0, 3, 1, 2).to(torch.float32) / 255.0 # (B,3,H,W) + x = ( + torch.from_numpy(images[s:e]).permute(0, 3, 1, 2).to(torch.float32) / 255.0 + ) # (B,3,H,W) x = F.interpolate(x, size=(outH, outW), mode="bilinear", align_corners=False) x = (x * 255.0).clamp(0, 255).to(torch.uint8) # (B,3,outH,outW) out[s:e] = x.permute(0, 2, 3, 1).cpu().numpy() @@ -290,6 +285,7 @@ def downsample_hwc_uint8_in_chunks( return out + def compute_camera_relative_pose(pose, cam_t_inv, cam_offset): """ pose (6,) : np.array @@ -318,7 +314,6 @@ def compute_camera_relative_pose(pose, cam_t_inv, cam_offset): return pose_t - def quat_translation_swap(quat_translation: np.ndarray) -> np.ndarray: """ Swap the quaternion and translation in a (N, 7) array. @@ -331,7 +326,10 @@ def quat_translation_swap(quat_translation: np.ndarray) -> np.ndarray: np.ndarray: (N, 7) array of translation and quaternion """ - return np.concatenate((quat_translation[..., 4:7], quat_translation[..., 0:4]), axis=-1) + return np.concatenate( + (quat_translation[..., 4:7], quat_translation[..., 0:4]), axis=-1 + ) + def get_hand_pose_in_camera_frame(hand_data, cam_t_inv, cam_offset, transform): """ @@ -387,7 +385,7 @@ class AriaVRSExtractor: @staticmethod def process_episode(episode_path, arm: str, low_res=False, benchmark=False): - f""" + """ Extracts all feature keys from a given episode and returns as a dictionary Parameters ---------- @@ -418,11 +416,14 @@ def process_episode(episode_path, arm: str, low_res=False, benchmark=False): hand_tracking_results_path = os.path.join( mps_sample_path, "hand_tracking", "hand_tracking_results.csv" ) - + closed_loop_pose_path = os.path.join( mps_sample_path, "slam", "closed_loop_trajectory.csv" ) + eye_gaze_path = os.path.join( + mps_sample_path, "eye_gaze", "general_eye_gaze.csv" + ) vrs_reader = data_provider.create_vrs_data_provider(str(episode_path)) @@ -432,7 +433,7 @@ def process_episode(episode_path, arm: str, low_res=False, benchmark=False): closed_loop_traj = mps.read_closed_loop_trajectory(closed_loop_pose_path) - device_calibration = vrs_reader.get_device_calibration() + eye_gaze_results = mps.read_eyegaze(eye_gaze_path) time_domain: TimeDomain = TimeDomain.DEVICE_TIME time_query_closest: TimeQueryOptions = TimeQueryOptions.CLOSEST @@ -455,7 +456,7 @@ def process_episode(episode_path, arm: str, low_res=False, benchmark=False): mps_data_paths = mps_data_paths_provider.get_data_paths() mps_reader = mps.MpsDataProvider(mps_data_paths) - rgb_to_device_T = slam_to_rgb(vrs_reader) # aria sophus SE3 + rgb_to_device_T = slam_to_rgb(vrs_reader) # aria sophus SE3 # ee_pose # TODO: this will be useful for the future - when we add rotation and other state keys @@ -468,7 +469,7 @@ def process_episode(episode_path, arm: str, low_res=False, benchmark=False): hand_tracking_results=hand_tracking_results, arm=arm, ) - + hand_keypoints_pose = AriaVRSExtractor.get_hand_keypoints( world_device_T=closed_loop_traj, stream_timestamps_ns=stream_timestamps_ns, @@ -482,6 +483,10 @@ def process_episode(episode_path, arm: str, low_res=False, benchmark=False): stream_timestamps_ns=stream_timestamps_ns, ) + eye_gaze = AriaVRSExtractor.get_eye_gaze( + eye_gaze_results=eye_gaze_results, stream_timestamps_ns=stream_timestamps_ns + ) + # rgb_camera # TODO: this will be useful for the future - when we add other camera modalities camera_key = AriaVRSExtractor.get_cameras("front_img_1")[0] @@ -493,10 +498,11 @@ def process_episode(episode_path, arm: str, low_res=False, benchmark=False): benchmark=benchmark, ) - if low_res: - images = downsample_hwc_uint8_in_chunks(images, out_hw=(240, 320), chunk=256) - + images = downsample_hwc_uint8_in_chunks( + images, out_hw=(240, 320), chunk=256 + ) + # with mem_section("process_episode.torch_from_numpy_permute", sample_interval_s=0.1, plot=False): # images = torch.from_numpy(images).permute(0, 3, 1, 2).float() @@ -509,27 +515,39 @@ def process_episode(episode_path, arm: str, low_res=False, benchmark=False): # with mem_section("process_episode.byte_numpy", sample_interval_s=0.1, plot=False): # images = images.byte().numpy() - + rgb_timestamps_ns = np.array(stream_timestamps_ns["rgb"]) + print(f"[DEBUG] LENGTH BEFORE CLEANING: {len(hand_cartesian_pose)}") - [hand_cartesian_pose, hand_keypoints_pose, head_pose], images = AriaVRSExtractor.clean_data( - poses=[hand_cartesian_pose, hand_keypoints_pose, head_pose], images=images + [hand_cartesian_pose, hand_keypoints_pose, head_pose], images, eye_gaze, rgb_timestamps_ns = ( + AriaVRSExtractor.clean_data( + poses=[hand_cartesian_pose, hand_keypoints_pose, head_pose], + images=images, + eye_gaze=eye_gaze, + timestamps=rgb_timestamps_ns + ) ) # actions, pose, images = AriaVRSExtractor.clean_data_projection(actions=actions, pose=pose, images=images, arm=arm) print(f"[DEBUG] LENGTH AFTER CLEANING: {len(hand_cartesian_pose)}") episode_feats["left.obs_ee_pose"] = hand_cartesian_pose[..., :7] episode_feats["right.obs_ee_pose"] = hand_cartesian_pose[..., 7:] - episode_feats["left.obs_keypoints"] = hand_keypoints_pose[..., 7:7 + 21*3] - episode_feats["right.obs_keypoints"] = hand_keypoints_pose[..., 7 + 21*3 + 7: 7 + 21*3 + 7 + 21*3] + episode_feats["left.obs_keypoints"] = hand_keypoints_pose[..., 7 : 7 + 21 * 3] + episode_feats["right.obs_keypoints"] = hand_keypoints_pose[ + ..., 7 + 21 * 3 + 7 : 7 + 21 * 3 + 7 + 21 * 3 + ] episode_feats["left.obs_wrist_pose"] = hand_keypoints_pose[..., :7] - episode_feats["right.obs_wrist_pose"] = hand_keypoints_pose[..., 7 + 21*3: 7 + 21*3 + 7] + episode_feats["right.obs_wrist_pose"] = hand_keypoints_pose[ + ..., 7 + 21 * 3 : 7 + 21 * 3 + 7 + ] episode_feats["images.front_1"] = images episode_feats["obs_head_pose"] = head_pose + episode_feats["eye_gaze"] = eye_gaze + episode_feats["rgb_timestamps"] = rgb_timestamps_ns return episode_feats @staticmethod - def clean_data(poses, images): + def clean_data(poses, images, eye_gaze, timestamps): """ Clean data Parameters @@ -537,25 +555,28 @@ def clean_data(poses, images): actions : np.array pose : np.array images : np.array + eye_gaze: np.array + timestamps: np.array Returns ------- - actions, pose, images : tuple of np.array + actions, pose, images, eye_gaze, timestamps : tuple of np.array cleaned data """ mask_poses = np.ones(len(poses[0]), dtype=bool) for pose in poses: bad_data_mask = np.any(pose >= 1e8, axis=1) mask_poses = mask_poses & ~bad_data_mask - + for i in range(len(poses)): poses[i] = poses[i][mask_poses] clean_images = images[mask_poses] + eye_gaze = eye_gaze[mask_poses] + timestamps = timestamps[mask_poses] - return poses, clean_images + return poses, clean_images, eye_gaze, timestamps - @staticmethod - def iter_images(episode_path, chunk_length=64, height=720, width=960, focal_mult=2): + def iter_images(episode_path, chunk_length=64, height=720, width=960, focal_mult=2): """ Iterate over images from VRS Parameters @@ -584,7 +605,6 @@ def iter_images(episode_path, chunk_length=64, height=720, width=960, focal_mul images = [] frame_length = len(stream_timestamps_ns["rgb"]) num_batches = frame_length // chunk_length - for t in range(num_batches): batch_images = [] @@ -597,13 +617,17 @@ def iter_images(episode_path, chunk_length=64, height=720, width=960, focal_mul time_query_closest, ) image_t = undistort_to_linear( - vrs_reader, stream_ids, raw_image=sample_frame[0].to_numpy_array(), height=height, width=width, focal_mult=focal_mult + vrs_reader, + stream_ids, + raw_image=sample_frame[0].to_numpy_array(), + height=height, + width=width, + focal_mult=focal_mult, ) batch_images.append(image_t) batch_images = np.array(batch_images) yield batch_images - - + @staticmethod def clean_data_projection( actions, pose, images, arm, CHUNK_LENGTH=CHUNK_LENGTH_ACT @@ -715,10 +739,14 @@ def get_images( image_t = undistort_to_linear( vrs_reader, stream_ids, raw_image=sample_frame[0].to_numpy_array() ) - images.append(image_t) - with mem_section("get_images.list_to_numpy_array", sample_interval_s=0.1, plot=False, enabled=benchmark): + with mem_section( + "get_images.list_to_numpy_array", + sample_interval_s=0.1, + plot=False, + enabled=benchmark, + ): images = np.array(images) return images @@ -751,9 +779,9 @@ def get_hand_keypoints( time_query_closest = TimeQueryOptions.CLOSEST ee_pose = [] - - use_left_hand = (arm == "left" or arm == "bimanual") - use_right_hand = (arm == "right" or arm == "bimanual") + + use_left_hand = arm == "left" or arm == "bimanual" + use_right_hand = arm == "right" or arm == "bimanual" for t in range(frame_length): query_timestamp = stream_timestamps_ns["rgb"][t] hand_tracking_result_t = get_nearest_hand_tracking_result( @@ -770,31 +798,58 @@ def get_hand_keypoints( getattr(hand_tracking_result_t, "left_hand", None), "confidence", -1 ) left_obs_t = np.full(7 + 21 * 3, 1e9) - if use_left_hand and not left_confidence < 0 and world_device_T_t is not None: - left_hand_keypoints = np.stack(hand_tracking_result_t.left_hand.landmark_positions_device, axis=0) - wrist_T = hand_tracking_result_t.left_hand.transform_device_wrist # Sophus SE3 - - world_wrist_T = world_device_T_t @ wrist_T - world_keypoints = (world_device_T_t @ left_hand_keypoints.T).T # keypoints are in device frame - - world_wrist_T = sp.SE3.from_matrix(SE3_permute_rot(world_wrist_T.to_matrix())) - wrist_quat_and_translation = quat_translation_swap(world_wrist_T.to_quat_and_translation()) + if ( + use_left_hand + and not left_confidence < 0 + and world_device_T_t is not None + ): + left_hand_keypoints = np.stack( + hand_tracking_result_t.left_hand.landmark_positions_device, axis=0 + ) + wrist_T = ( + hand_tracking_result_t.left_hand.transform_device_wrist + ) # Sophus SE3 + + world_wrist_T = world_device_T_t @ wrist_T + world_keypoints = ( + world_device_T_t @ left_hand_keypoints.T + ).T # keypoints are in device frame + + world_wrist_T = sp.SE3.from_matrix( + SE3_permute_rot(world_wrist_T.to_matrix()) + ) + wrist_quat_and_translation = quat_translation_swap( + world_wrist_T.to_quat_and_translation() + ) if wrist_quat_and_translation.ndim == 2: wrist_quat_and_translation = wrist_quat_and_translation[0] left_obs_t[:7] = wrist_quat_and_translation left_obs_t[7:] = world_keypoints.flatten() - right_obs_t = np.full(7 + 21 * 3, 1e9) - if use_right_hand and not right_confidence < 0 and world_device_T_t is not None: - right_hand_keypoints = np.stack(hand_tracking_result_t.right_hand.landmark_positions_device, axis=0) - wrist_T = hand_tracking_result_t.right_hand.transform_device_wrist # Sophus SE3 + if ( + use_right_hand + and not right_confidence < 0 + and world_device_T_t is not None + ): + right_hand_keypoints = np.stack( + hand_tracking_result_t.right_hand.landmark_positions_device, axis=0 + ) + wrist_T = ( + hand_tracking_result_t.right_hand.transform_device_wrist + ) # Sophus SE3 - world_wrist_T = world_device_T_t @ wrist_T - world_keypoints = (world_device_T_t @ right_hand_keypoints.T).T # keypoints are in device frame + world_wrist_T = world_device_T_t @ wrist_T + world_keypoints = ( + world_device_T_t @ right_hand_keypoints.T + ).T # keypoints are in device frame - world_wrist_T = sp.SE3.from_matrix(SE3_permute_rot(world_wrist_T.to_matrix())) - wrist_quat_and_translation = quat_translation_swap(world_wrist_T.to_quat_and_translation()) + world_wrist_T = sp.SE3.from_matrix( + SE3_permute_rot(world_wrist_T.to_matrix()) + ) + wrist_quat_and_translation = quat_translation_swap( + world_wrist_T.to_quat_and_translation() + ) if wrist_quat_and_translation.ndim == 2: wrist_quat_and_translation = wrist_quat_and_translation[0] right_obs_t[:7] = wrist_quat_and_translation @@ -811,7 +866,7 @@ def get_hand_keypoints( ee_pose.append(np.ravel(ee_pose_obs_t)) ee_pose = np.array(ee_pose) return ee_pose - + @staticmethod def get_head_pose( world_device_T, @@ -849,14 +904,35 @@ def get_head_pose( head_pose_obs_t = np.full(7, 1e9) if world_device_T_t is not None: world_rgb_T_t = world_device_T_t @ device_rgb_T @ rgbprime_to_rgb_T - head_pose_quat_and_translation = quat_translation_swap(world_rgb_T_t.to_quat_and_translation()) + head_pose_quat_and_translation = quat_translation_swap( + world_rgb_T_t.to_quat_and_translation() + ) if head_pose_quat_and_translation.ndim == 2: head_pose_quat_and_translation = head_pose_quat_and_translation[0] head_pose_obs_t[:7] = head_pose_quat_and_translation - head_pose.append(np.ravel(head_pose_obs_t)) + head_pose.append(np.ravel(head_pose_obs_t)) head_pose = np.array(head_pose) return head_pose - + + @staticmethod + def get_eye_gaze( + eye_gaze_results, + stream_timestamps_ns: dict, + ): + gaze = [] + frame_length = len(stream_timestamps_ns["rgb"]) + + for t in range(frame_length): + query_timestamp = stream_timestamps_ns["rgb"][t] + gaze_info = get_nearest_eye_gaze(eye_gaze_results, query_timestamp) + if gaze_info is None: + gaze.append([-100, -100, -100]) + else: + gaze.append([gaze_info.yaw, gaze_info.pitch, gaze_info.depth]) + + gaze = np.array(gaze) + return gaze + @staticmethod def get_ee_pose( world_device_T, @@ -888,11 +964,9 @@ def get_ee_pose( time_domain = TimeDomain.DEVICE_TIME time_query_closest = TimeQueryOptions.CLOSEST - use_left_hand = (arm == "left" or arm == "bimanual") - use_right_hand = (arm == "right" or arm == "bimanual") + use_left_hand = arm == "left" or arm == "bimanual" + use_right_hand = arm == "right" or arm == "bimanual" - - for t in range(frame_length): query_timestamp = stream_timestamps_ns["rgb"][t] hand_tracking_result_t = get_nearest_hand_tracking_result( @@ -909,11 +983,13 @@ def get_ee_pose( left_confidence = getattr( getattr(hand_tracking_result_t, "left_hand", None), "confidence", -1 ) - - left_obs_t = np.full(7, 1e9) - if use_left_hand and not left_confidence < 0 and world_device_T_t is not None: + if ( + use_left_hand + and not left_confidence < 0 + and world_device_T_t is not None + ): left_palm_pose = ( hand_tracking_result_t.left_hand.get_palm_position_device() ) @@ -933,14 +1009,20 @@ def get_ee_pose( left_T_t = sp.SE3.from_matrix(left_T_t) left_T_t = world_device_T_t @ left_T_t left_T_t = sp.SE3.from_matrix(SE3_permute_rot(left_T_t.to_matrix())) - - left_quat_and_translation = quat_translation_swap(left_T_t.to_quat_and_translation()) + + left_quat_and_translation = quat_translation_swap( + left_T_t.to_quat_and_translation() + ) if left_quat_and_translation.ndim == 2: left_quat_and_translation = left_quat_and_translation[0] left_obs_t[:7] = left_quat_and_translation right_obs_t = np.full(7, 1e9) - if use_right_hand and not right_confidence < 0 and world_device_T_t is not None: + if ( + use_right_hand + and not right_confidence < 0 + and world_device_T_t is not None + ): right_palm_pose = ( hand_tracking_result_t.right_hand.get_palm_position_device() ) @@ -960,7 +1042,9 @@ def get_ee_pose( right_T_t = sp.SE3.from_matrix(right_T_t) right_T_t = world_device_T_t @ right_T_t right_T_t = sp.SE3.from_matrix(SE3_permute_rot(right_T_t.to_matrix())) - right_quat_and_translation = quat_translation_swap(right_T_t.to_quat_and_translation()) + right_quat_and_translation = quat_translation_swap( + right_T_t.to_quat_and_translation() + ) if right_quat_and_translation.ndim == 2: right_quat_and_translation = right_quat_and_translation[0] right_obs_t[:7] = right_quat_and_translation @@ -1026,7 +1110,7 @@ def iter_episode_frames( episode_name = episode_path.name # check if episode is timestamped - if not "-" in episode_name: + if "-" not in episode_name: episode_name = timestamp_ms_to_episode_hash(int(episode_name)) try: @@ -1037,7 +1121,9 @@ def iter_episode_frames( for feature_id, _info in features.items(): if feature_id.startswith("observations."): - key = feature_id.split(".", 1)[-1] # "images.front_img_1" / "state.ee_pose" + key = feature_id.split(".", 1)[ + -1 + ] # "images.front_img_1" / "state.ee_pose" value = episode_feats["observations"].get(key, None) else: value = episode_feats.get(feature_id, None) @@ -1051,9 +1137,7 @@ def iter_episode_frames( if image_compressed: img = cv2.imdecode(value[frame_idx], 1) # HWC BGR uint8 frame[feature_id] = ( - torch.from_numpy(img) - .permute(2, 0, 1) - .contiguous() + torch.from_numpy(img).permute(2, 0, 1).contiguous() ) # CHW uint8 else: frame[feature_id] = ( @@ -1074,7 +1158,6 @@ def iter_episode_frames( finally: del episode_feats - @staticmethod def define_features( episode_feats: dict, image_compressed: bool = True, encode_as_video: bool = True @@ -1222,7 +1305,9 @@ def __init__( self.encode_as_videos = encode_as_videos self.benchmark = benchmark if self.benchmark: - print(f"Benchmark mode enabled. This will plot the RAM usage of each section.") + print( + "Benchmark mode enabled. This will plot the RAM usage of each section." + ) self.logger = logging.getLogger(self.__class__.__name__) self.logger.setLevel(logging.INFO) @@ -1243,7 +1328,6 @@ def __init__( self.logger.info(f"#writer processes: {self.image_writer_processes}") self.logger.info(f"#writer threads: {self.image_writer_threads}") - self._mp4_path = None # set from main() if --save-mp4 self._mp4_writer = None # lazy-initialized in extract_episode() self.episode_list = list(self.raw_path.glob("*.vrs")) @@ -1260,28 +1344,36 @@ def __init__( elif self.arm == "left": self.embodiment = "aria_left_arm" - def save_preview_mp4(self, image_frames: list[dict], output_path: Path, fps: int, image_compressed: bool): + def save_preview_mp4( + self, + image_frames: list[dict], + output_path: Path, + fps: int, + image_compressed: bool, + ): """ Save a single half-resolution, web-compatible MP4 using H.264 (libx264). No fallbacks. Requires `ffmpeg` with libx264 on PATH. - + Each frame dict must contain: 'observations.images.front_img_1' -> torch.Tensor (C,H,W) uint8 """ - + imgs = image_frames - + # Compute half-res (force even dims for yuv420p) C, H, W = imgs[0].shape outW, outH = W // 2, H // 2 - if outW % 2: outW -= 1 - if outH % 2: outH -= 1 + if outW % 2: + outW -= 1 + if outH % 2: + outH -= 1 if outW <= 0 or outH <= 0: raise ValueError(f"[MP4] Invalid output size: {outW}x{outH}") output_path = Path(output_path) output_path.parent.mkdir(parents=True, exist_ok=True) - + rgb_frames = [] for chw in imgs: # chw: (C,H,W) uint8, BGR from cv2.imdecode earlier @@ -1404,7 +1496,7 @@ def save_preview_mp4(self, image_frames: list[dict], output_path: Path, fps: int print( f"[MP4] Saved web-compatible H.264 preview via ffmpeg CLI to {output_path}" ) - + def extract_episode_iterative(self, episode_path, task_description: str = ""): """ TODO: Implement the iterative approach to save memory. @@ -1425,30 +1517,45 @@ def extract_episode_iterative(self, episode_path, task_description: str = ""): embodiment=self.embodiment, enable_sharding=False, task="", - ) + ) with writer.write_incremental(total_frames=total_frames) as inc: image_frames = [] - for i, frame in enumerate(AriaVRSExtractor.iter_episode_frames(episode_path, self.features, self.image_compressed, self.arm, self.prestack, self.benchmark)): + for i, frame in enumerate( + AriaVRSExtractor.iter_episode_frames( + episode_path, + self.features, + self.image_compressed, + self.arm, + self.prestack, + self.benchmark, + ) + ): self.buffer.append(frame) if self._mp4_path is not None: image = frame["observations.images.front_img_1"] image_frames.append(image) if len(self.buffer) == EPISODE_LENGTH: - for f in self.buffer: self.dataset.add_frame(f) - - + self.logger.info(f"Saving Episode after {i + 1} frames...") self.dataset.save_episode(task=task_description) self.buffer.clear() if self._mp4_path is not None: ep_stem = Path(episode_path).stem mp4_path = self._mp4_path / f"{ep_stem}_video.mp4" - self.save_preview_mp4(image_frames, mp4_path, self.fps, self.image_compressed) + self.save_preview_mp4( + image_frames, mp4_path, self.fps, self.image_compressed + ) - def extract_episode(self, episode_path, task_description: str = "", output_dir: Path = Path("."), dataset_name: str = ""): + def extract_episode( + self, + episode_path, + task_description: str = "", + output_dir: Path = Path("."), + dataset_name: str = "", + ): """ Extracts frames from an episode and saves them to the dataset. Parameters @@ -1462,14 +1569,14 @@ def extract_episode(self, episode_path, task_description: str = "", output_dir: None """ episode_name = dataset_name - + episode_feats = AriaVRSExtractor.process_episode( episode_path=episode_path, arm=self.arm, benchmark=self.benchmark, ) numeric_data = {} - + image_data = {} for key, value in episode_feats.items(): if "images" in key: @@ -1494,7 +1601,9 @@ def extract_episode(self, episode_path, task_description: str = "", output_dir: mp4_path = output_dir / f"{episode_name}.mp4" W, H = 960, 720 p = start_ffmpeg_mp4(mp4_path, W, H, fps=30, pix_fmt="rgb24") - for video_images in AriaVRSExtractor.iter_images(episode_path, chunk_length=256, height=H, width=W, focal_mult=3): + for video_images in AriaVRSExtractor.iter_images( + episode_path, chunk_length=256, height=H, width=W, focal_mult=3 + ): for image in video_images: image = prep_frame(image, H, W) if image is None: @@ -1503,9 +1612,13 @@ def extract_episode(self, episode_path, task_description: str = "", output_dir: p.stdin.close() p.wait() return zarr_path, mp4_path - - def extract_episodes(self, episode_description: str = "", output_dir: Path = Path("."), dataset_name: str = ""): + def extract_episodes( + self, + episode_description: str = "", + output_dir: Path = Path("."), + dataset_name: str = "", + ): """ Extracts episodes from the episode list and processes them. Parameters @@ -1525,21 +1638,29 @@ def extract_episodes(self, episode_description: str = "", output_dir: Path = Pat with mem_section("extract_episodes", enabled=self.benchmark): for episode_path in self.episode_list: try: - return self.extract_episode(episode_path, task_description=episode_description, output_dir=output_dir, dataset_name=dataset_name) + return self.extract_episode( + episode_path, + task_description=episode_description, + output_dir=output_dir, + dataset_name=dataset_name, + ) except Exception as e: self.logger.error(f"Error processing episode {episode_path}: {e}") traceback.print_exc() continue - + return None - + + def argument_parse(): parser = argparse.ArgumentParser( description="Convert Aria VRS dataset to LeRobot-Robomimic hybrid and push to Hugging Face hub." ) # Required arguments - parser.add_argument("--dataset-name", type=str, required=True, help="Name for dataset") + parser.add_argument( + "--dataset-name", type=str, required=True, help="Name for dataset" + ) parser.add_argument( "--raw-path", type=Path, @@ -1643,7 +1764,12 @@ def main(args): gc.collect() ctypes.CDLL("libc.so.6").malloc_trim(0) # Extract episodes - return converter.extract_episode(episode_description=args.description, output_dir=args.output_dir, dataset_name=args.dataset_name) + return converter.extract_episodes( + episode_description=args.description, + output_dir=args.output_dir, + dataset_name=args.dataset_name, + ) + if __name__ == "__main__": args = argument_parse() diff --git a/egomimic/scripts/aria_process/aria_utils.py b/egomimic/scripts/aria_process/aria_utils.py index 03b5eb48..c91fea82 100644 --- a/egomimic/scripts/aria_process/aria_utils.py +++ b/egomimic/scripts/aria_process/aria_utils.py @@ -126,3 +126,15 @@ def coordinate_frame_to_ypr(x_axis, y_axis, z_axis): if np.isnan(euler_ypr).any(): euler_ypr = np.zeros_like(euler_ypr) return euler_ypr + +def cpf_to_rgb(provider): + """ + Get cpf (eye tracking origin) to rgb camera transform (rotated upright) + provider: vrs data provider + """ + device_calibration = provider.get_device_calibration() + rgb_calibration = device_calibration.get_camera_calib("camera-rgb") + rgbprime_calibration = calibration.rotate_camera_calib_cw90deg(rgb_calibration) + T_device_cpf = device_calibration.get_transform_device_cpf() + T_device_rgb = rgbprime_calibration.get_transform_device_camera() + return T_device_rgb.inverse() @ T_device_cpf diff --git a/egomimic/scripts/zarr_data_viz.ipynb b/egomimic/scripts/zarr_data_viz.ipynb index 87a27406..0b9645c1 100644 --- a/egomimic/scripts/zarr_data_viz.ipynb +++ b/egomimic/scripts/zarr_data_viz.ipynb @@ -1,397 +1,824 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "29aeeb40", - "metadata": {}, - "source": [ - "# Eva Data\n", - "\n", - "This notebook builds a `MultiDataset` containing exactly one `ZarrDataset`, loads one batch, visualizes one image with `mediapy`, and prints the rest of the batch." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "32d9110f", - "metadata": {}, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "\n", - "import cv2\n", - "import imageio_ffmpeg\n", - "import mediapy as mpy\n", - "import numpy as np\n", - "import torch\n", - "\n", - "from egomimic.rldb.embodiment.eva import Eva\n", - "from egomimic.rldb.embodiment.human import Aria\n", - "from egomimic.rldb.zarr.zarr_dataset_multi import MultiDataset, ZarrDataset\n", - "from egomimic.rldb.zarr.zarr_dataset_multi import S3EpisodeResolver\n", - "from egomimic.utils.egomimicUtils import (\n", - " INTRINSICS,\n", - " cam_frame_to_cam_pixels,\n", - " nds,\n", - ")\n", - "from egomimic.utils.aws.aws_data_utils import load_env\n", - "\n", - "# Ensure mediapy can find an ffmpeg executable in this environment\n", - "mpy.set_ffmpeg(imageio_ffmpeg.get_ffmpeg_exe())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cc9edba1", - "metadata": {}, - "outputs": [], - "source": [ - "TEMP_DIR = \"/coc/flash7/scratch/egoverseDebugDatasets/egoverseS3DatasetTest\"\n", - "load_env()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a4aa1a05", - "metadata": {}, - "outputs": [], - "source": [ - "# Point this at a single episode directory, e.g. /path/to/episode_hash.zarr\n", - "# EPISODE_PATH = Path(\"/coc/flash7/scratch/egoverseDebugDatasets/1767495035712.zarr\")\n", - "\n", - "key_map = Eva.get_keymap()\n", - "transform_list = Eva.get_transform_list()\n", - "\n", - "# Build a MultiDataset with exactly one ZarrDataset inside\n", - "# single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map, transform_list=transform_list)\n", - "# single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map)\n", - "\n", - "# multi_ds = MultiDataset(datasets={\"single_episode\": single_ds}, mode=\"total\")\n", - "resolver = S3EpisodeResolver(\n", - " TEMP_DIR, key_map=key_map, transform_list=transform_list\n", - ")\n", - "filters = {\n", - " \"episode_hash\": \"2025-12-26-18-07-46-296000\"\n", - "}\n", - "multi_ds = MultiDataset._from_resolver(\n", - " resolver, filters=filters, sync_from_s3=True, mode=\"total\"\n", - ")\n", - "\n", - "loader = torch.utils.data.DataLoader(multi_ds, batch_size=1, shuffle=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4b72f3bb", - "metadata": {}, - "outputs": [], - "source": [ - "# Separate YPR visualization preview\n", - "for batch in loader:\n", - " vis_ypr = Eva.viz_transformed_batch(batch, mode=\"palm_axes\")\n", - " mpy.show_image(vis_ypr)\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0d8c3da2", - "metadata": {}, - "outputs": [], - "source": [ - "images = []\n", - "for i, batch in enumerate(loader):\n", - " vis = Eva.viz_transformed_batch(batch, mode=\"palm_traj\")\n", - " images.append(vis)\n", - " if i > 10:\n", - " break\n", - "\n", - "mpy.show_video(images, fps=30)" - ] - }, - { - "cell_type": "markdown", - "id": "1a3382f1", - "metadata": {}, - "source": [ - "## Human Datasets\n", - "Mecka, Scale and Aria should all run exactly the same" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b7384468", - "metadata": {}, - "outputs": [], - "source": [ - "temp_dir = \"/coc/flash7/scratch/egoverseDebugDatasets/egoverseS3DatasetTest\"\n", - "\n", - "intrinsics_key = \"base\"\n", - "\n", - "key_map = Aria.get_keymap()\n", - "transform_list = Aria.get_transform_list()\n", - "\n", - "resolver = S3EpisodeResolver(\n", - " temp_dir,\n", - " key_map=key_map,\n", - " transform_list=transform_list,\n", - ")\n", - "\n", - "filters = {\"episode_hash\": \"2026-01-20-20-59-43-376000\"} #aria\n", - "# filters = {\"episode_hash\": \"692ee048ef7557106e6c4b8d\"} # mecka\n", - "\n", - "cloudflare_ds = MultiDataset._from_resolver(\n", - " resolver, filters=filters, sync_from_s3=True, mode=\"total\"\n", - ")\n", - "\n", - "loader = torch.utils.data.DataLoader(cloudflare_ds, batch_size=1, shuffle=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "af65095a", - "metadata": {}, - "outputs": [], - "source": [ - "ims = []\n", - "for i, batch in enumerate(loader):\n", - " vis = Aria.viz_transformed_batch(batch, mode=\"palm_traj\")\n", - " ims.append(vis)\n", - " # mpy.show_image(vis)\n", - "\n", - " # for k, v in batch.items():\n", - " # print(f\"{k}: {tuple(v.shape)}\")\n", - " \n", - " if i > 10:\n", - " break\n", - "\n", - "mpy.show_video(ims, fps=30)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e6d8d872", - "metadata": {}, - "outputs": [], - "source": [ - "# Aria YPR video (same data loop, YPR overlay)\n", - "ims_ypr = []\n", - "for i, batch in enumerate(loader):\n", - " vis_ypr = Aria.viz_transformed_batch(batch, mode=\"palm_axes\")\n", - " ims_ypr.append(vis_ypr)\n", - " if i > 20:\n", - " break\n", - "\n", - "mpy.show_video(ims_ypr, fps=30)" - ] - }, - { - "cell_type": "markdown", - "id": "efecaba7", - "metadata": {}, - "source": [ - "## Keypoint Visualization" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e39bca03", - "metadata": {}, - "outputs": [], - "source": [ - "# Load Scale episode with raw keypoints (no action chunking needed)\n", - "\n", - "from egomimic.rldb.zarr.action_chunk_transforms import _xyzwxyz_to_matrix\n", - "\n", - "key_map_kp = {\n", - " \"images.front_1\": {\"zarr_key\": \"images.front_1\"},\n", - " \"left.obs_keypoints\": {\"zarr_key\": \"left.obs_keypoints\"},\n", - " \"right.obs_keypoints\": {\"zarr_key\": \"right.obs_keypoints\"},\n", - " \"obs_head_pose\": {\"zarr_key\": \"obs_head_pose\"},\n", - "}\n", - "\n", - "filters = {\"episode_hash\": \"2026-01-20-20-59-43-376000\"}\n", - "\n", - "resolver = S3EpisodeResolver(\n", - " temp_dir,\n", - " key_map=key_map\n", - ")\n", - "\n", - "cloudflare_ds = MultiDataset._from_resolver(\n", - " resolver, filters=filters, sync_from_s3=True, mode=\"total\"\n", - ")\n", - "\n", - "loader_kp = torch.utils.data.DataLoader(cloudflare_ds, batch_size=1, shuffle=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "848c6d74", - "metadata": {}, - "outputs": [], - "source": [ - "# ARIA Keypoint Viz\n", - "# MANO skeleton edges: (parent, child) for drawing bones\n", - "MANO_EDGES = [\n", - " (0, 1), (1, 2), (2, 3), (3, 4), # thumb\n", - " (0, 5), (5, 6), (6, 7), (7, 8), # index\n", - " (0, 9), (9, 10), (10, 11), (11, 12), # middle\n", - " (0, 13), (13, 14), (14, 15), (15, 16), # ring\n", - " (0, 17), (17, 18), (18, 19), (19, 20), # pinky\n", - "]\n", - "\n", - "# aria configuration\n", - "MANO_EDGES = [\n", - " (5, 6,), (6, 7), (7, 0), # thumb\n", - " (5, 8), (8, 9), (9, 10), (9, 1), # index\n", - " (5, 11), (11, 12), (12, 13), (13, 2), # middle\n", - " (5, 14), (14, 15), (15, 16), (16, 3), # ring\n", - " (5, 17), (17, 18), (18, 19), (19, 4), # pinky\n", - "]\n", - "\n", - "FINGER_COLORS = {\n", - " \"thumb\": (255, 100, 100), # red\n", - " \"index\": (100, 255, 100), # green\n", - " \"middle\": (100, 100, 255), # blue\n", - " \"ring\": (255, 255, 100), # yellow\n", - " \"pinky\": (255, 100, 255), # magenta\n", - "}\n", - "FINGER_EDGE_RANGES = [\n", - " (\"thumb\", 0, 3), (\"index\", 3, 6), (\"middle\", 6, 9),\n", - " (\"ring\", 9, 12), (\"pinky\", 12, 15),\n", - "]\n", - "\n", - "\n", - "def viz_keypoints(batch, image_key=\"observations.images.front_img_1\"):\n", - " \"\"\"Visualize all 21 MANO keypoints per hand, projected onto the image.\"\"\"\n", - " # Prepare image\n", - " img = batch[image_key][0].detach().cpu()\n", - " if img.shape[0] in (1, 3):\n", - " img = img.permute(1, 2, 0)\n", - " img_np = img.numpy()\n", - " if img_np.dtype != np.uint8:\n", - " if img_np.max() <= 1.0:\n", - " img_np = (img_np * 255.0).clip(0, 255).astype(np.uint8)\n", - " else:\n", - " img_np = img_np.clip(0, 255).astype(np.uint8)\n", - " if img_np.shape[-1] == 1:\n", - " img_np = np.repeat(img_np, 3, axis=-1)\n", - "\n", - " intrinsics = INTRINSICS[\"base\"]\n", - " head_pose = batch[\"obs_head_pose\"][0].detach().cpu().numpy() # (6,)\n", - "\n", - " # T_head_world: camera pose in world (camera-to-world)\n", - " # We need world-to-camera = inv(T_head_world)\n", - " T_head_world = _xyzwxyz_to_matrix(head_pose[None, :])[0] # (4, 4)\n", - " T_world_to_cam = np.linalg.inv(T_head_world)\n", - "\n", - " vis = img_np.copy()\n", - " h, w = vis.shape[:2]\n", - "\n", - " for hand, dot_color in [(\"left\", (0, 120, 255)), (\"right\", (255, 80, 0))]:\n", - " kps_key = f\"{hand}.obs_keypoints\"\n", - " if kps_key not in batch:\n", - " continue\n", - " kps_flat = batch[kps_key][0].detach().cpu().numpy() # (63,)\n", - " kps_world = kps_flat.reshape(21, 3)\n", - "\n", - " # Skip if keypoints are all zero (invalid, clamped from 1e9)\n", - " if np.allclose(kps_world, 0.0, atol=1e-3):\n", - " continue\n", - "\n", - " # World -> camera frame\n", - " kps_h = np.concatenate([kps_world, np.ones((21, 1))], axis=1) # (21, 4)\n", - " kps_cam = (T_world_to_cam @ kps_h.T).T[:, :3] # (21, 3)\n", - "\n", - " # Camera frame -> pixels\n", - " kps_px = cam_frame_to_cam_pixels(kps_cam, intrinsics) # (21, 3+)\n", - "\n", - " # Identify valid keypoints (z > 0 and in image bounds)\n", - " valid = (kps_cam[:, 2] > 0.01)\n", - " valid &= (kps_px[:, 0] >= 0) & (kps_px[:, 0] < w)\n", - " valid &= (kps_px[:, 1] >= 0) & (kps_px[:, 1] < h)\n", - "\n", - " # Draw skeleton edges (colored by finger)\n", - " for finger, start, end in FINGER_EDGE_RANGES:\n", - " color = FINGER_COLORS[finger]\n", - " for edge_idx in range(start, end):\n", - " i, j = MANO_EDGES[edge_idx]\n", - " if valid[i] and valid[j]:\n", - " p1 = (int(kps_px[i, 0]), int(kps_px[i, 1]))\n", - " p2 = (int(kps_px[j, 0]), int(kps_px[j, 1]))\n", - " cv2.line(vis, p1, p2, color, 2)\n", - "\n", - " # Draw keypoint dots on top\n", - " for k in range(21):\n", - " if valid[k]:\n", - " center = (int(kps_px[k, 0]), int(kps_px[k, 1]))\n", - " cv2.circle(vis, center, 4, dot_color, -1)\n", - " cv2.circle(vis, center, 4, (255, 255, 255), 1) # white border\n", - "\n", - " # Label wrist\n", - " if valid[0]:\n", - " wrist_px = (int(kps_px[0, 0]) + 6, int(kps_px[0, 1]) - 6)\n", - " cv2.putText(vis, f\"{hand[0].upper()}\", wrist_px,\n", - " cv2.FONT_HERSHEY_SIMPLEX, 0.5, dot_color, 2)\n", - "\n", - " return vis" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "75dbfa95", - "metadata": {}, - "outputs": [], - "source": [ - "# Render keypoint video\n", - "ims_kp = []\n", - "for i, batch_kp in enumerate(loader_kp):\n", - " vis = viz_keypoints(batch_kp)\n", - " ims_kp.append(vis)\n", - " if i > 10:\n", - " break\n", - "\n", - "mpy.show_video(ims_kp, fps=30)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8f4fbaec", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "emimic (3.11.14)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.14" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "9fac2748", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "id": "29aeeb40", + "metadata": {}, + "source": [ + "# Eva Data\n", + "\n", + "This notebook builds a `MultiDataset` containing exactly one `ZarrDataset`, loads one batch, visualizes one image with `mediapy`, and prints the rest of the batch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32d9110f", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "\n", + "import cv2\n", + "import imageio_ffmpeg\n", + "import mediapy as mpy\n", + "import numpy as np\n", + "import torch\n", + "from scipy.spatial.transform import Rotation as R\n", + "\n", + "from egomimic.rldb.zarr.action_chunk_transforms import (\n", + " _matrix_to_xyzwxyz,\n", + " build_aria_transform_list,\n", + " build_eva_transform_list,\n", + ")\n", + "from egomimic.rldb.zarr.zarr_dataset_multi import MultiDataset, ZarrDataset\n", + "from egomimic.utils.egomimicUtils import (\n", + " EXTRINSICS,\n", + " INTRINSICS,\n", + " cam_frame_to_cam_pixels,\n", + " draw_actions,\n", + " nds,\n", + ")\n", + "\n", + "# Ensure mediapy can find an ffmpeg executable in this environment\n", + "mpy.set_ffmpeg(imageio_ffmpeg.get_ffmpeg_exe())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4aa1a05", + "metadata": {}, + "outputs": [], + "source": [ + "# Point this at a single episode directory, e.g. /path/to/episode_hash.zarr\n", + "EPISODE_PATH = Path(\"/home/ubuntu/download_eva/1764215784190.zarr\")\n", + "\n", + "key_map = {\n", + " \"images.front_1\": {\"zarr_key\": \"images.front_1\"},\n", + " \"images.right_wrist\": {\"zarr_key\": \"images.right_wrist\"},\n", + " \"images.left_wrist\": {\"zarr_key\": \"images.left_wrist\"},\n", + " \"right.obs_ee_pose\": {\"zarr_key\": \"right.obs_ee_pose\"},\n", + " \"right.obs_gripper\": {\"zarr_key\": \"right.gripper\"},\n", + " # \"left.obs_ee_pose\": {\"zarr_key\": \"left.obs_ee_pose\"},\n", + " # \"left.obs_gripper\": {\"zarr_key\": \"left.gripper\"},\n", + " \"right.gripper\": {\"zarr_key\": \"right.gripper\", \"horizon\": 45},\n", + " # \"left.gripper\": {\"zarr_key\": \"left.gripper\", \"horizon\": 45},\n", + " \"right.cmd_ee_pose\": {\"zarr_key\": \"right.cmd_ee_pose\", \"horizon\": 45},\n", + " # \"left.cmd_ee_pose\": {\"zarr_key\": \"left.cmd_ee_pose\", \"horizon\": 45},\n", + "}\n", + "\n", + "ACTION_CHUNK_LENGTH = 100\n", + "ACTION_STRIDE = 1 # set to 3 for Aria-style anchor sampling\n", + "\n", + "extrinsics = EXTRINSICS[\"x5Dec13_2\"]\n", + "left_extrinsics_pose = _matrix_to_xyzwxyz(extrinsics[\"left\"][None, :])[0]\n", + "right_extrinsics_pose = _matrix_to_xyzwxyz(extrinsics[\"right\"][None, :])[0]\n", + "\n", + "transform_list = build_eva_transform_list(\n", + " which=\"right\",\n", + " chunk_length=ACTION_CHUNK_LENGTH,\n", + " stride=ACTION_STRIDE,\n", + " # left_extra_batch_key={\"left_extrinsics_pose\": left_extrinsics_pose},\n", + " # right_extra_batch_key={\"right_extrinsics_pose\": right_extrinsics_pose},\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c7fbf37", + "metadata": {}, + "outputs": [], + "source": [ + "# Build a MultiDataset with exactly one ZarrDataset inside\n", + "single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map, transform_list=transform_list)\n", + "# single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map)\n", + "multi_ds = MultiDataset(datasets={\"single_episode\": single_ds}, mode=\"total\")\n", + "\n", + "print(\"len(single_ds):\", len(single_ds))\n", + "print(\"len(multi_ds):\", len(multi_ds))\n", + "\n", + "loader = torch.utils.data.DataLoader(multi_ds, batch_size=1, shuffle=False)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86338e96", + "metadata": {}, + "outputs": [], + "source": [ + "batch = next(iter(loader))\n", + "nds(batch)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e4626d8", + "metadata": {}, + "outputs": [], + "source": [ + "def _split_action_pose(actions):\n", + " # 14D layout: [L xyz ypr g, R xyz ypr g]\n", + " # 12D layout: [L xyz ypr, R xyz ypr]\n", + " if actions.shape[-1] == 14:\n", + " left_xyz = actions[:, :3]\n", + " left_ypr = actions[:, 3:6]\n", + " right_xyz = actions[:, 7:10]\n", + " right_ypr = actions[:, 10:13]\n", + " elif actions.shape[-1] == 12:\n", + " left_xyz = actions[:, :3]\n", + " left_ypr = actions[:, 3:6]\n", + " right_xyz = actions[:, 6:9]\n", + " right_ypr = actions[:, 9:12]\n", + " else:\n", + " raise ValueError(f\"Unsupported action dim {actions.shape[-1]}\")\n", + " return left_xyz, left_ypr, right_xyz, right_ypr\n", + "\n", + "\n", + "def viz_batch(batch, image_key, action_key, intrinsics_key):\n", + " # Image: (C,H,W) -> (H,W,C)\n", + " img = batch[image_key][0].detach().cpu()\n", + " if img.shape[0] in (1, 3):\n", + " img = img.permute(1, 2, 0)\n", + " img_np = img.numpy()\n", + "\n", + " if img_np.dtype != np.uint8:\n", + " if img_np.max() <= 1.0:\n", + " img_np = (img_np * 255.0).clip(0, 255).astype(np.uint8)\n", + " else:\n", + " img_np = img_np.clip(0, 255).astype(np.uint8)\n", + " if img_np.shape[-1] == 1:\n", + " img_np = np.repeat(img_np, 3, axis=-1)\n", + "\n", + " intrinsics = INTRINSICS[intrinsics_key]\n", + " actions = batch[action_key][0].detach().cpu().numpy()\n", + " left_xyz, _, right_xyz, _ = _split_action_pose(actions)\n", + "\n", + " vis = draw_actions(\n", + " img_np.copy(), type=\"xyz\", color=\"Blues\",\n", + " actions=left_xyz, extrinsics=None, intrinsics=intrinsics, arm=\"left\"\n", + " )\n", + " vis = draw_actions(\n", + " vis, type=\"xyz\", color=\"Reds\",\n", + " actions=right_xyz, extrinsics=None, intrinsics=intrinsics, arm=\"right\"\n", + " )\n", + " return vis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89be6d5b", + "metadata": {}, + "outputs": [], + "source": [ + "def viz_batch_ypr(arm, batch, image_key, action_key, intrinsics_key, axis_len_m=0.04):\n", + " img = batch[image_key][0].detach().cpu()\n", + " if img.shape[0] in (1, 3):\n", + " img = img.permute(1, 2, 0)\n", + " img_np = img.numpy()\n", + "\n", + " if img_np.dtype != np.uint8:\n", + " if img_np.max() <= 1.0:\n", + " img_np = (img_np * 255.0).clip(0, 255).astype(np.uint8)\n", + " else:\n", + " img_np = img_np.clip(0, 255).astype(np.uint8)\n", + " if img_np.shape[-1] == 1:\n", + " img_np = np.repeat(img_np, 3, axis=-1)\n", + "\n", + " intrinsics = INTRINSICS[intrinsics_key]\n", + " actions = batch[action_key][0].detach().cpu().numpy()\n", + " left_xyz, left_ypr, right_xyz, right_ypr = None, None, None, None\n", + " if arm == \"both\":\n", + " left_xyz, left_ypr, right_xyz, right_ypr = _split_action_pose(actions)\n", + " elif arm == \"left\":\n", + " left_xyz, left_ypr = actions[:, :3], actions[:, 3:6]\n", + " elif arm == \"right\":\n", + " right_xyz, right_ypr = actions[:, :3], actions[:, 3:6]\n", + "\n", + " vis = img_np.copy()\n", + "\n", + " def _draw_axis_color_legend(frame):\n", + " h, w = frame.shape[:2]\n", + " x_right = w - 12\n", + " y_start = 14\n", + " y_step = 12\n", + " line_len = 24\n", + " axis_legend = [\n", + " (\"x\", (255, 0, 0)),\n", + " (\"y\", (0, 255, 0)),\n", + " (\"z\", (0, 0, 255)),\n", + " ]\n", + " for i, (name, color) in enumerate(axis_legend):\n", + " y = y_start + i * y_step\n", + " x0 = x_right - line_len\n", + " x1 = x_right\n", + " cv2.line(frame, (x0, y), (x1, y), color, 3)\n", + " cv2.putText(\n", + " frame, name, (x0 - 12, y + 4),\n", + " cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1, cv2.LINE_AA,\n", + " )\n", + " return frame\n", + "\n", + " def _draw_rotation_at_palm(frame, xyz_seq, ypr_seq, label, anchor_color):\n", + " if len(xyz_seq) == 0 or len(ypr_seq) == 0:\n", + " return frame\n", + "\n", + " palm_xyz = xyz_seq[0]\n", + " palm_ypr = ypr_seq[0]\n", + " rot = R.from_euler(\"ZYX\", palm_ypr, degrees=False).as_matrix()\n", + " # print(np.linalg.det(rot))\n", + "\n", + " axis_points_cam = np.vstack([\n", + " palm_xyz,\n", + " palm_xyz + rot[:, 0] * axis_len_m,\n", + " palm_xyz + rot[:, 1] * axis_len_m,\n", + " palm_xyz + rot[:, 2] * axis_len_m,\n", + " ])\n", + "\n", + " px = cam_frame_to_cam_pixels(axis_points_cam, intrinsics)[:, :2]\n", + " if not np.isfinite(px).all():\n", + " return frame\n", + "\n", + " pts = np.round(px).astype(np.int32)\n", + "\n", + " h, w = frame.shape[:2]\n", + " x0, y0 = pts[0]\n", + " if not (0 <= x0 < w and 0 <= y0 < h):\n", + " return frame\n", + "\n", + " cv2.circle(frame, (x0, y0), 4, anchor_color, -1)\n", + " axis_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] # x,y,z as RGB\n", + "\n", + " for i, color in enumerate(axis_colors, start=1):\n", + " x1, y1 = pts[i]\n", + " if 0 <= x1 < w and 0 <= y1 < h:\n", + " cv2.line(frame, (x0, y0), (x1, y1), color, 2)\n", + " cv2.circle(frame, (x1, y1), 2, color, -1)\n", + "\n", + " cv2.putText(\n", + " frame, label, (x0 + 6, max(12, y0 - 8)),\n", + " cv2.FONT_HERSHEY_SIMPLEX, 0.4, anchor_color, 1, cv2.LINE_AA,\n", + " )\n", + " return frame\n", + "\n", + " if left_xyz is not None and left_ypr is not None:\n", + " vis = _draw_rotation_at_palm(vis, left_xyz, left_ypr, \"L rot\", (255, 180, 80))\n", + " if right_xyz is not None and right_ypr is not None:\n", + " vis = _draw_rotation_at_palm(vis, right_xyz, right_ypr, \"R rot\", (80, 180, 255))\n", + " vis = _draw_axis_color_legend(vis)\n", + " return vis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4b72f3bb", + "metadata": {}, + "outputs": [], + "source": [ + "# Separate YPR visualization preview\n", + "image_key = \"images.front_1\"\n", + "action_key = \"actions_cartesian\"\n", + "\n", + "for batch in loader:\n", + " vis_ypr = viz_batch_ypr(arm=\"right\",batch=batch, image_key=image_key, action_key=action_key, intrinsics_key=\"base\")\n", + " mpy.show_image(vis_ypr)\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d8c3da2", + "metadata": {}, + "outputs": [], + "source": [ + "image_key = \"images.front_1\"\n", + "action_key = \"actions_cartesian\"\n", + "\n", + "images = []\n", + "for i, batch in enumerate(loader):\n", + " vis = viz_batch(batch, image_key=image_key, action_key=action_key, intrinsics_key=\"base\")\n", + " images.append(vis)\n", + " if i > 100:\n", + " break\n", + "\n", + "mpy.show_video(images, fps=30)" + ] + }, + { + "cell_type": "markdown", + "id": "1a3382f1", + "metadata": {}, + "source": [ + "## Human Datasets\n", + "Mecka, Scale and Aria should all run exactly the same" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38100d31", + "metadata": {}, + "outputs": [], + "source": [ + "from egomimic.utils.aws.aws_sql import timestamp_ms_to_episode_hash\n", + "\n", + "timestamp_ms_to_episode_hash(1764285228498)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8693d01c", + "metadata": {}, + "outputs": [], + "source": [ + "# Aria-style chunking example: horizon=30 contiguous frames, sample anchors every 3 -> 10 points, then interpolate to 100.\n", + "\n", + "# EPISODE_PATH = Path(\"/coc/flash7/scratch/egoverseDebugDatasets/scale/697a9070da7b91acaf3f2d88_episode_000000.zarr\") # Scale\n", + "# intrinsics_key = \"scale\"\n", + "\n", + "EPISODE_PATH = Path(\"/home/ubuntu/download_aria/1764285228498.zarr\") # Aria\n", + "intrinsics_key = \"base\"\n", + "\n", + "\n", + "key_map = {\n", + " \"images.front_1\": {\"zarr_key\": \"images.front_1\"},\n", + " \"right.obs_ee_pose\": {\"zarr_key\": \"right.obs_ee_pose\"},\n", + " \"left.obs_ee_pose\": {\"zarr_key\": \"left.obs_ee_pose\"},\n", + " \"right.action_ee_pose\": {\"zarr_key\": \"right.obs_ee_pose\", \"horizon\": 30},\n", + " \"left.action_ee_pose\": {\"zarr_key\": \"left.obs_ee_pose\", \"horizon\": 30},\n", + " \"obs_head_pose\": {\"zarr_key\": \"obs_head_pose\"},\n", + "}\n", + "\n", + "ACTION_CHUNK_LENGTH = 100\n", + "ACTION_STRIDE = 3\n", + "\n", + "transform_list = build_aria_transform_list(\n", + " which=\"both\",\n", + " chunk_length=ACTION_CHUNK_LENGTH,\n", + " stride=ACTION_STRIDE,\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b4c03ca", + "metadata": {}, + "outputs": [], + "source": [ + "# Full MultiDataset via LocalEpisodeResolver (mirrors test_multi_zarr.yaml)\n", + "from egomimic.rldb.zarr.action_chunk_transforms import (\n", + " build_aria_bimanual_transform_list,\n", + ")\n", + "from egomimic.rldb.zarr.zarr_dataset_multi import LocalEpisodeResolver, MultiDataset\n", + "\n", + "SCALE_FOLDER = Path(\"/coc/flash7/scratch/egoverseDebugDatasets/scale/2026-02-24-01-49-24-166324\")\n", + "\n", + "key_map = {\n", + " \"images.front_1\": {\"zarr_key\": \"images.front_1\"},\n", + " \"right.obs_ee_pose\": {\"zarr_key\": \"right.obs_ee_pose\"},\n", + " \"left.obs_ee_pose\": {\"zarr_key\": \"left.obs_ee_pose\"},\n", + " \"right.action_ee_pose\": {\"zarr_key\": \"right.obs_ee_pose\", \"horizon\": 30},\n", + " \"left.action_ee_pose\": {\"zarr_key\": \"left.obs_ee_pose\", \"horizon\": 30},\n", + " \"obs_head_pose\": {\"zarr_key\": \"obs_head_pose\"},\n", + "}\n", + "\n", + "transform_list = build_aria_bimanual_transform_list(\n", + " stride=1,\n", + ")\n", + "\n", + "resolver = LocalEpisodeResolver(\n", + " folder_path=SCALE_FOLDER,\n", + " key_map=key_map,\n", + " transform_list=transform_list,\n", + ")\n", + "\n", + "multi_ds = MultiDataset._from_resolver(resolver, mode=\"total\")\n", + "print(f\"MultiDataset total frames: {len(multi_ds)}\")\n", + "print(f\"Underlying episodes: {list(multi_ds.datasets.keys())}\")\n", + "\n", + "loader = torch.utils.data.DataLoader(multi_ds, batch_size=1, shuffle=False)\n", + "batch = next(iter(loader))\n", + "print(\"Batch keys:\", list(batch.keys()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1da784ea", + "metadata": {}, + "outputs": [], + "source": [ + "# Build a MultiDataset with exactly one ZarrDataset inside\n", + "single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map, transform_list=transform_list)\n", + "#single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map)\n", + "multi_ds = MultiDataset(datasets={\"single_episode\": single_ds}, mode=\"total\")\n", + "\n", + "print(\"len(single_ds):\", len(single_ds))\n", + "print(\"len(multi_ds):\", len(multi_ds))\n", + "\n", + "loader = torch.utils.data.DataLoader(multi_ds, batch_size=1, shuffle=False)\n", + "# batch = next(iter(loader))\n", + "\n", + "# print(\"Batch keys:\", list(batch.keys()))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ee57d15", + "metadata": {}, + "outputs": [], + "source": [ + "batch = next(iter(loader))\n", + "print(batch.keys())\n", + "for k, v in batch.items():\n", + " print(k, v.shape)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "507a1fd6", + "metadata": {}, + "outputs": [], + "source": [ + "batch = next(iter(loader))\n", + "nds(batch)\n", + "print(\"Batch keys:\", list(batch.keys()))\n", + "print(batch[\"right.action_ee_pose\"][0, 0])\n", + "print(batch[\"left.action_ee_pose\"][0, 0])\n", + "print(batch[\"obs_head_pose\"][0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e94799c", + "metadata": {}, + "outputs": [], + "source": [ + "left_hand_pose = []\n", + "right_hand_pose = []\n", + "head_pose = []\n", + "for i, batch in enumerate(loader):\n", + " left_hand_pose.append(batch[\"left.action_ee_pose\"][0, 0])\n", + " right_hand_pose.append(batch[\"right.action_ee_pose\"][0, 0])\n", + " head_pose.append(batch[\"obs_head_pose\"][0])\n", + " \n", + " if i > 400:\n", + " break\n", + "left_hand_pose = np.array(left_hand_pose)\n", + "right_hand_pose = np.array(right_hand_pose)\n", + "head_pose = np.array(head_pose)\n", + "\n", + "# chunk the pose to actions(N, 100, 3)\n", + "left_hand_pose_actions = []\n", + "right_hand_pose_actions = []\n", + "head_pose_actions = []\n", + "for i in range(left_hand_pose.shape[0] - 100):\n", + " action_left_hand = left_hand_pose[i:i+100, :]\n", + " action_right_hand = right_hand_pose[i:i+100, :]\n", + " action_head = head_pose[i:i+100, :]\n", + " left_hand_pose_actions.append(action_left_hand)\n", + " right_hand_pose_actions.append(action_right_hand)\n", + " head_pose_actions.append(action_head)\n", + "left_hand_pose_actions = np.array(left_hand_pose_actions)\n", + "right_hand_pose_actions = np.array(right_hand_pose_actions)\n", + "head_pose_actions = np.array(head_pose_actions)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6517f061", + "metadata": {}, + "outputs": [], + "source": [ + "from egomimic.utils.egomimicUtils import render_3d_traj_frames\n", + "\n", + "frames = render_3d_traj_frames([left_hand_pose_actions, right_hand_pose_actions, head_pose_actions], labels=[\"left hand\", \"right hand\", \"head\"], stride=10)\n", + "mpy.show_video(frames, fps=30)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "af65095a", + "metadata": {}, + "outputs": [], + "source": [ + "image_key = \"images.front_1\"\n", + "action_key = \"actions_cartesian\"\n", + "\n", + "ims = []\n", + "for i, batch in enumerate(loader):\n", + " first_img = batch[image_key][0].detach().cpu().permute(1, 2, 0).numpy()\n", + " first_img = (first_img * 255.0).clip(0, 255).astype(np.uint8)\n", + "\n", + " vis = viz_batch(batch, image_key=image_key, action_key=action_key, intrinsics_key=intrinsics_key)\n", + " ims.append(vis)\n", + " # mpy.show_image(vis)\n", + "\n", + " # for k, v in batch.items():\n", + " # print(f\"{k}: {tuple(v.shape)}\")\n", + " \n", + " if i > 200:\n", + " break\n", + "\n", + "mpy.show_video(ims, fps=30)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6d8d872", + "metadata": {}, + "outputs": [], + "source": [ + "# Aria YPR video (same data loop, YPR overlay)\n", + "image_key = \"images.front_1\"\n", + "action_key = \"actions_cartesian\"\n", + "\n", + "ims_ypr = []\n", + "for i, batch in enumerate(loader):\n", + " vis_ypr = viz_batch_ypr(arm=\"both\",batch=batch, image_key=image_key, action_key=action_key, intrinsics_key=\"base\")\n", + " ims_ypr.append(vis_ypr)\n", + " if i > 600:\n", + " break\n", + "\n", + "mpy.show_video(ims_ypr, fps=30)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36f120b8", + "metadata": {}, + "outputs": [], + "source": [ + "batch[\"actions_cartesian\"][0, 0]\n" + ] + }, + { + "cell_type": "markdown", + "id": "efecaba7", + "metadata": {}, + "source": [ + "## Keypoint Visualization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e39bca03", + "metadata": {}, + "outputs": [], + "source": [ + "# Load Scale episode with raw keypoints (no action chunking needed)\n", + "\n", + "from egomimic.rldb.zarr.action_chunk_transforms import _xyzwxyz_to_matrix\n", + "\n", + "EPISODE_PATH_KP = EPISODE_PATH\n", + "\n", + "key_map_kp = {\n", + " \"images.front_1\": {\"zarr_key\": \"images.front_1\"},\n", + " \"left.obs_keypoints\": {\"zarr_key\": \"left.obs_keypoints\"},\n", + " \"right.obs_keypoints\": {\"zarr_key\": \"right.obs_keypoints\"},\n", + " \"obs_head_pose\": {\"zarr_key\": \"obs_head_pose\"},\n", + "}\n", + "\n", + "single_ds_kp = ZarrDataset(Episode_path=EPISODE_PATH_KP, key_map=key_map_kp)\n", + "multi_ds_kp = MultiDataset(datasets={\"single_episode\": single_ds_kp}, mode=\"total\")\n", + "loader_kp = torch.utils.data.DataLoader(multi_ds_kp, batch_size=1, shuffle=False)\n", + "print(f\"Keypoint dataset: {len(single_ds_kp)} frames\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "848c6d74", + "metadata": {}, + "outputs": [], + "source": [ + "# MANO skeleton edges: (parent, child) for drawing bones\n", + "MANO_EDGES = [\n", + " (0, 1), (1, 2), (2, 3), (3, 4), # thumb\n", + " (0, 5), (5, 6), (6, 7), (7, 8), # index\n", + " (0, 9), (9, 10), (10, 11), (11, 12), # middle\n", + " (0, 13), (13, 14), (14, 15), (15, 16), # ring\n", + " (0, 17), (17, 18), (18, 19), (19, 20), # pinky\n", + "]\n", + "\n", + "FINGER_COLORS = {\n", + " \"thumb\": (255, 100, 100), # red\n", + " \"index\": (100, 255, 100), # green\n", + " \"middle\": (100, 100, 255), # blue\n", + " \"ring\": (255, 255, 100), # yellow\n", + " \"pinky\": (255, 100, 255), # magenta\n", + "}\n", + "FINGER_EDGE_RANGES = [\n", + " (\"thumb\", 0, 4), (\"index\", 4, 8), (\"middle\", 8, 12),\n", + " (\"ring\", 12, 16), (\"pinky\", 16, 20),\n", + "]\n", + "\n", + "\n", + "def viz_keypoints(batch, image_key=\"images.front_1\"):\n", + " \"\"\"Visualize all 21 MANO keypoints per hand, projected onto the image.\"\"\"\n", + " # Prepare image\n", + " img = batch[image_key][0].detach().cpu()\n", + " if img.shape[0] in (1, 3):\n", + " img = img.permute(1, 2, 0)\n", + " img_np = img.numpy()\n", + " if img_np.dtype != np.uint8:\n", + " if img_np.max() <= 1.0:\n", + " img_np = (img_np * 255.0).clip(0, 255).astype(np.uint8)\n", + " else:\n", + " img_np = img_np.clip(0, 255).astype(np.uint8)\n", + " if img_np.shape[-1] == 1:\n", + " img_np = np.repeat(img_np, 3, axis=-1)\n", + "\n", + " intrinsics = INTRINSICS[\"scale\"]\n", + " head_pose = batch[\"obs_head_pose\"][0].detach().cpu().numpy() # (6,)\n", + "\n", + " # T_head_world: camera pose in world (camera-to-world)\n", + " # We need world-to-camera = inv(T_head_world)\n", + " T_head_world = _xyzwxyz_to_matrix(head_pose[None, :])[0] # (4, 4)\n", + " T_world_to_cam = np.linalg.inv(T_head_world)\n", + "\n", + " vis = img_np.copy()\n", + " h, w = vis.shape[:2]\n", + "\n", + " for hand, dot_color in [(\"left\", (0, 120, 255)), (\"right\", (255, 80, 0))]:\n", + " kps_key = f\"{hand}.obs_keypoints\"\n", + " if kps_key not in batch:\n", + " continue\n", + " kps_flat = batch[kps_key][0].detach().cpu().numpy() # (63,)\n", + " kps_world = kps_flat.reshape(21, 3)\n", + "\n", + " # Skip if keypoints are all zero (invalid, clamped from 1e9)\n", + " if np.allclose(kps_world, 0.0, atol=1e-3):\n", + " continue\n", + "\n", + " # World -> camera frame\n", + " kps_h = np.concatenate([kps_world, np.ones((21, 1))], axis=1) # (21, 4)\n", + " kps_cam = (T_world_to_cam @ kps_h.T).T[:, :3] # (21, 3)\n", + "\n", + " # Camera frame -> pixels\n", + " kps_px = cam_frame_to_cam_pixels(kps_cam, intrinsics) # (21, 3+)\n", + "\n", + " # Identify valid keypoints (z > 0 and in image bounds)\n", + " valid = (kps_cam[:, 2] > 0.01)\n", + " valid &= (kps_px[:, 0] >= 0) & (kps_px[:, 0] < w)\n", + " valid &= (kps_px[:, 1] >= 0) & (kps_px[:, 1] < h)\n", + "\n", + " # Draw skeleton edges (colored by finger)\n", + " for finger, start, end in FINGER_EDGE_RANGES:\n", + " color = FINGER_COLORS[finger]\n", + " for edge_idx in range(start, end):\n", + " i, j = MANO_EDGES[edge_idx]\n", + " if valid[i] and valid[j]:\n", + " p1 = (int(kps_px[i, 0]), int(kps_px[i, 1]))\n", + " p2 = (int(kps_px[j, 0]), int(kps_px[j, 1]))\n", + " cv2.line(vis, p1, p2, color, 2)\n", + "\n", + " # Draw keypoint dots on top\n", + " for k in range(21):\n", + " if valid[k]:\n", + " center = (int(kps_px[k, 0]), int(kps_px[k, 1]))\n", + " cv2.circle(vis, center, 4, dot_color, -1)\n", + " cv2.circle(vis, center, 4, (255, 255, 255), 1) # white border\n", + "\n", + " # Label wrist\n", + " if valid[0]:\n", + " wrist_px = (int(kps_px[0, 0]) + 6, int(kps_px[0, 1]) - 6)\n", + " cv2.putText(vis, f\"{hand[0].upper()}\", wrist_px,\n", + " cv2.FONT_HERSHEY_SIMPLEX, 0.5, dot_color, 2)\n", + "\n", + " return vis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75dbfa95", + "metadata": {}, + "outputs": [], + "source": [ + "# Render keypoint video\n", + "ims_kp = []\n", + "for i, batch_kp in enumerate(loader_kp):\n", + " vis = viz_keypoints(batch_kp)\n", + " ims_kp.append(vis)\n", + " if i > 200:\n", + " break\n", + "\n", + "mpy.show_video(ims_kp, fps=30)" + ] + }, + { + "cell_type": "markdown", + "id": "b30472c9", + "metadata": {}, + "source": [ + "## Gaze Data (Aria)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d4e7f35a", + "metadata": {}, + "outputs": [], + "source": [ + "from egomimic.utils.egomimicUtils import INTRINSICS, cam_frame_to_cam_pixels, draw_dot_on_frame, get_gaze_endpoint, ARIA_T_RGB_CPF\n", + "from egomimic.rldb.zarr.zarr_dataset_multi import MultiDataset, ZarrDataset\n", + "import mediapy as mpy\n", + "import numpy as np\n", + "import torch\n", + "from pathlib import Path\n", + "\n", + "EPISODE_PATH_GAZE = Path(\"/coc/flash9/fryan6/data/egoverse_sample/1758390474000.zarr\")\n", + "key_map_gaze = {\n", + " \"images.front_1\": {\"zarr_key\": \"images.front_1\"},\n", + " \"eye_gaze\": {\"zarr_key\": \"eye_gaze\"},\n", + "}\n", + "intrinsics = INTRINSICS[\"base\"]\n", + "\n", + "single_ds_gaze = ZarrDataset(Episode_path=EPISODE_PATH_GAZE, key_map=key_map_gaze)\n", + "multi_ds_gaze = MultiDataset(datasets={\"single episode\": single_ds_gaze}, mode=\"total\")\n", + "loader = torch.utils.data.DataLoader(multi_ds_gaze, batch_size=1, shuffle=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "69541814", + "metadata": {}, + "outputs": [], + "source": [ + "ims_gaze = []\n", + "for i, batch in enumerate(loader):\n", + " gaze_data = batch['eye_gaze'][0].numpy() # yaw (rads), pitch (rads), depth (m)\n", + " if gaze_data[0] != -100: # value for no gaze estimate\n", + " gaze_point_xyz = get_gaze_endpoint(gaze_data[0], gaze_data[1], gaze_data[2], ARIA_T_RGB_CPF)\n", + " gaze_point_xyz = np.expand_dims(gaze_point_xyz, 0)\n", + " gaze_point_pixel = cam_frame_to_cam_pixels(gaze_point_xyz, intrinsics)\n", + " frame = batch['images.front_1'][0].permute(1,2,0).numpy() * 255\n", + " vis = draw_dot_on_frame(frame, gaze_point_pixel, show=False)\n", + " ims_gaze.append(vis)\n", + " if i > 600:\n", + " break\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "23913d24", + "metadata": {}, + "outputs": [], + "source": [ + "mpy.show_video(ims_gaze, fps=30)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "emimic", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/egomimic/utils/egomimicUtils.py b/egomimic/utils/egomimicUtils.py index 670c3c43..be6134af 100644 --- a/egomimic/utils/egomimicUtils.py +++ b/egomimic/utils/egomimicUtils.py @@ -257,6 +257,15 @@ }, } +ARIA_T_RGB_CPF = np.array( + [ + [-0.99989084, 0.01251132, -0.00786028, 0.05686918], + [-0.01132842, -0.99067146, -0.13580032, 0.00922798], + [-0.009486 , -0.13569645, 0.99070505, -0.01147902], + [0.0, 0.0, 0.0, 1.0] + ] + ) + INTRINSICS = { "base": ARIA_INTRINSICS, "base_half": ARIA_INTRINSICS_HALF, @@ -1597,3 +1606,65 @@ def transform_matrix_to_pose(mat: torch.Tensor) -> torch.Tensor: # Return pose: (B, T, 6) return torch.cat([xyz, ypr], dim=-1) + + +def get_vector_from_yaw_pitch( + yaw_rads: float, + pitch_rads: float, + depth: float | None = None, +) -> np.ndarray: + """ + Convert yaw / pitch angles into a 3D gaze vector in CPF coordinates. + + Args: + yaw_rads: Yaw angle in radians. + pitch_rads: Pitch angle in radians. + depth: Optional gaze distance. If provided, returns a vector with this + magnitude. If None, returns a unit vector. + + Returns: + np.ndarray: (3,) gaze vector in CPF coordinates. + """ + z = 1.0 + x = np.tan(yaw_rads) * z + y = np.tan(pitch_rads) * z + + direction = np.array([x, y, z], dtype=np.float64) + norm = np.linalg.norm(direction) + if norm == 0: + raise ValueError("Zero-length direction vector") + + unit_dir = direction / norm + + if depth is None: + return unit_dir + else: + return unit_dir * depth + + +def get_gaze_endpoint(yaw_rads, pitch_rads, depth, T_cam_cpf): + """ + Compute the 3D gaze endpoint in camera coordinates. + + The gaze originates at the CPF origin, with direction defined by yaw/pitch, + and length set by depth. The endpoint is transformed from CPF to camera + frame using T_cam_cpf. + + Args: + yaw_rads: Yaw angle in radians. + pitch_rads: Pitch angle in radians. + depth: Gaze vector magnitude. + T_cam_cpf: (4, 4) SE(3) homogeneous transform from CPF to camera frame. + + Returns: + np.ndarray: (3,) gaze endpoint in camera coordinates. + """ + gaze_vec_cpf = get_vector_from_yaw_pitch(yaw_rads, pitch_rads, depth) + + T_cam_cpf = np.asarray(T_cam_cpf, dtype=np.float64) + if T_cam_cpf.shape != (4, 4): + raise ValueError(f"T_cam_cpf must be a 4x4 transform, got {T_cam_cpf.shape}") + + endpoint_cpf_h = np.concatenate([gaze_vec_cpf, np.array([1.0], dtype=np.float64)]) + endpoint_cam_h = T_cam_cpf @ endpoint_cpf_h + return endpoint_cam_h[:3]