diff --git a/egomimic/scripts/aria_process/aria-cluster.yaml b/egomimic/scripts/aria_process/aria-cluster.yaml index 83f871d1..0f9e62ef 100644 --- a/egomimic/scripts/aria_process/aria-cluster.yaml +++ b/egomimic/scripts/aria_process/aria-cluster.yaml @@ -11,9 +11,6 @@ auth: ssh_user: ubuntu ssh_private_key: ~/.ssh/rldb-base-pem.pem -file_mounts: - "/home/ubuntu/EgoVerse": "/home/elmo/Documents/projects/EgoVerse" - rsync_exclude: - "emimic/" - ".git/" @@ -22,7 +19,7 @@ rsync_exclude: - "logs/" - "*.parquet" -max_workers: 60 +max_workers: 140 idle_timeout_minutes: 5 available_node_types: @@ -31,10 +28,12 @@ available_node_types: InstanceType: t3a.2xlarge KeyName: rldb-base-pem ImageId: ami-08c4d70c31f91a5ac + IamInstanceProfile: + Arn: arn:aws:iam::556885871428:instance-profile/ray-autoscaler-v1 BlockDeviceMappings: - DeviceName: /dev/sda1 Ebs: { VolumeSize: 50 } - resources: { CPU: 4 } + resources: { CPU: 8 } min_workers: 0 max_workers: 0 @@ -43,24 +42,28 @@ available_node_types: InstanceType: r6a.2xlarge KeyName: rldb-base-pem ImageId: ami-08c4d70c31f91a5ac + IamInstanceProfile: + Arn: arn:aws:iam::556885871428:instance-profile/ray-autoscaler-v1 BlockDeviceMappings: - DeviceName: /dev/sda1 - Ebs: { VolumeSize: 50 } - resources: { CPU: 8, aria_small: 1 } + Ebs: { VolumeSize: 150 } + resources: { CPU: 2, aria_small: 1 } min_workers: 0 - max_workers: 50 + max_workers: 120 worker_big: node_config: InstanceType: r6a.8xlarge + IamInstanceProfile: + Arn: arn:aws:iam::556885871428:instance-profile/ray-autoscaler-v1 KeyName: rldb-base-pem ImageId: ami-08c4d70c31f91a5ac BlockDeviceMappings: - DeviceName: /dev/sda1 - Ebs: { VolumeSize: 100 } # often worth bumping - resources: { CPU: 32, aria_big: 1 } + Ebs: { VolumeSize: 200 } # often worth bumping + resources: { CPU: 8, aria_big: 1 } min_workers: 0 - max_workers: 10 + max_workers: 7 head_node_type: head_node @@ -101,8 +104,6 @@ initialization_commands: pip3 install --no-input -e . setup_commands: - - sudo mkdir -p /mnt/raw /mnt/processed - - | chmod +x ~/EgoVerse/egomimic/utils/aws/setup_secret.sh R2_SECRET_NAME=r2/rldb/credentials DB_SECRET_NAME=rds/appdb/appuser REGION=us-east-2 \ @@ -116,10 +117,10 @@ head_setup_commands: | grep -v ray_worker_gaurdrails.py \ | grep -v ray_worker_gaurdrails.lock ; \ echo 'CRON_TZ=America/New_York'; \ - echo '0 20 * * * flock -n /tmp/run_aria_conversion.lock /bin/bash -lc "set -a; . /home/ubuntu/.egoverse_env; set +a; /usr/bin/python3 ~/EgoVerse/egomimic/scripts/aria_process/run_aria_conversion.py --skip-if-done" >> ~/aria_conversion.log 2>&1'; \ + echo '0 20 * * * flock -n /tmp/run_aria_conversion.lock /bin/bash -lc "set -a; . /home/ubuntu/.egoverse_env; set +a; /usr/bin/python3 ~/EgoVerse/egomimic/scripts/aria_process/run_conversion.py --embodiment aria --skip-if-done --debug" >> ~/aria_conversion.log 2>&1'; \ echo '*/10 * * * * flock -n /tmp/ray_worker_gaurdrails.lock /usr/bin/python3 /home/ubuntu/EgoVerse/egomimic/utils/aws/budget_guardrails/ray_worker_gaurdrails.py >> /home/ubuntu/ray_worker_gaurdrails.log 2>&1') \ | crontab - || true - crontab -l || true # Ray startup commands - configure GCS to detect dead/stale nodes faster -# ray start --head --dashboard-host=0.0.0.0 --dashboard-port=8265 --include-dashboard=true +# ray start --head --dashboard-host=0.0.0.0 --dashboard-port=8265 --include-dashboard=true \ No newline at end of file diff --git a/egomimic/scripts/aria_process/aria_to_zarr.py b/egomimic/scripts/aria_process/aria_to_zarr.py new file mode 100644 index 00000000..220f46cf --- /dev/null +++ b/egomimic/scripts/aria_process/aria_to_zarr.py @@ -0,0 +1,260 @@ +import argparse +import ctypes +import gc +import logging +import traceback +from pathlib import Path + +import numpy as np + +from egomimic.rldb.zarr.zarr_writer import ZarrWriter +from egomimic.scripts.aria_process.aria_utils import AriaVRSExtractor +from egomimic.utils.aws.aws_sql import timestamp_ms_to_episode_hash +from egomimic.utils.egomimicUtils import str2bool +from egomimic.utils.video_utils import save_preview_mp4 + +logger = logging.getLogger(__name__) + + +class DatasetConverter: + """ + A class to convert Aria VRS dataset to Zarr episodes. + Parameters + ---------- + raw_path : Path or str + The path to the raw dataset. + fps : int + Frames per second for the dataset. + arm : str, optional + The arm to process (e.g., 'left', 'right', or 'bimanual'), by default "". + save_mp4 : bool, optional + Whether to save a MP4 of the episode, by default False. + image_compressed : bool, optional + Whether the images are compressed, by default True. + Methods + ------- + extract_episode(episode_path, task_name='', output_dir='.', dataset_name='', chunk_timesteps=100) + Extracts frames from a single episode and saves it with a description. + main(args) + Main function to convert the dataset. + argument_parse() + Parses the command-line arguments. + """ + + def __init__( + self, + raw_path: Path | str, + fps: int, + arm: str = "", + save_mp4: bool = False, + image_compressed: bool = True, + debug: bool = False, + height: int = 480, + width: int = 640, + ): + self.raw_path = raw_path if isinstance(raw_path, Path) else Path(raw_path) + self.fps = fps + self.arm = arm + self.image_compressed = image_compressed + self.save_mp4 = save_mp4 + self.height = height + self.width = width + + self.logger = logging.getLogger(self.__class__.__name__) + self.logger.setLevel(logging.INFO) + + # Add console handler + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + formatter = logging.Formatter("%(asctime)s - [%(name)s] - %(message)s") + console_handler.setFormatter(formatter) + self.logger.addHandler(console_handler) + + self.logger.info(f"{'-' * 10} Aria VRS -> Lerobot Converter {'-' * 10}") + self.logger.info(f"Processing Aria VRS dataset from {self.raw_path}") + self.logger.info(f"FPS: {self.fps}") + self.logger.info(f"Arm: {self.arm}") + self.logger.info(f"Image compressed: {self.image_compressed}") + self.logger.info(f"Save MP4: {self.save_mp4}") + + self._mp4_path = None # set from main() if --save-mp4 + self._mp4_writer = None # lazy-initialized in extract_episode() + self.episode_list = list(self.raw_path.glob("*.vrs")) + + self.feats_to_zarr_keys = {} + + if self.arm == "both": + self.embodiment = "aria_bimanual" + elif self.arm == "right": + self.embodiment = "aria_right_arm" + elif self.arm == "left": + self.embodiment = "aria_left_arm" + + def extract_episode( + self, + episode_path, + task_name: str = "", + task_description: str = "", + output_dir: Path = Path("."), + dataset_name: str = "", + chunk_timesteps: int = 100, + ): + """ + Extracts frames from an episode and saves them to the dataset. + Parameters + ---------- + episode_path : str + The path to the episode file. + task_description : str, optional + A description of the task associated with the episode (default is an empty string). + Returns + ------- + None + """ + episode_name = dataset_name + + episode_feats = AriaVRSExtractor.process_episode( + episode_path=episode_path, + arm=self.arm, + height=self.height, + width=self.width, + ) + numeric_data = {} + + image_data = {} + for key, value in episode_feats.items(): + if "images" in key: + if key in self.feats_to_zarr_keys: + image_data[self.feats_to_zarr_keys[key]] = value + else: + image_data[key] = value + else: + if key in self.feats_to_zarr_keys: + numeric_data[self.feats_to_zarr_keys[key]] = value + else: + numeric_data[key] = value + + zarr_path = ZarrWriter.create_and_write( + episode_path=output_dir / f"{episode_name}.zarr", + numeric_data=numeric_data if numeric_data else None, + image_data=image_data if image_data else None, + fps=self.fps, + embodiment=self.embodiment, + task_name=task_name, + task_description=task_description, + chunk_timesteps=chunk_timesteps, + ) + if self.save_mp4: + mp4_path = output_dir / f"{episode_name}.mp4" + images_tchw = np.asarray(image_data["images.front_1"]).transpose(0, 3, 1, 2) + save_preview_mp4(images_tchw, mp4_path, self.fps, half_res=False) + else: + mp4_path = None + return zarr_path, mp4_path + + +def main(args) -> None: + """Convert Eva HDF5 dataset to Zarr episodes. + + Parameters + ---------- + args : argparse.Namespace + Parsed command-line arguments (same shape as eva_to_lerobot). + """ + + try: + episode_hash = timestamp_ms_to_episode_hash(Path(args.raw_path).stem) + + converter = DatasetConverter( + raw_path=Path(args.raw_path), + fps=args.fps, + arm=args.arm, + image_compressed=args.image_compressed, + save_mp4=args.save_mp4, + debug=args.debug, + ) + + gc.collect() + ctypes.CDLL("libc.so.6").malloc_trim(0) + zarr_path, mp4_path = converter.extract_episode( + episode_path=Path(args.raw_path), + task_name=args.task_name, + task_description=args.task_description, + output_dir=Path(args.output_dir), + dataset_name=episode_hash, + ) + return zarr_path, mp4_path + except Exception: + logger.error( + "Error converting %s:\n%s", Path(args.raw_path), traceback.format_exc() + ) + return None + + +def argument_parse(): + parser = argparse.ArgumentParser( + description="Convert Aria VRS dataset to LeRobot-Robomimic hybrid and push to Hugging Face hub." + ) + + # Required arguments + parser.add_argument( + "--raw-path", + type=Path, + required=True, + help="Directory containing the vrs, vrs_json, and the processed mps folder.", + ) + parser.add_argument( + "--fps", type=int, required=True, help="Frames per second for the dataset." + ) + # Optional arguments + parser.add_argument( + "--task-name", + type=str, + default="Aria recorded dataset.", + help="Task name of the data.", + ) + parser.add_argument( + "--task-description", + type=str, + default="Aria recorded dataset.", + help="Task description of the data.", + ) + parser.add_argument( + "--arm", + type=str, + choices=["left", "right", "both"], + default="both", + help="Specify the arm for processing.", + ) + parser.add_argument( + "--image-compressed", + type=str2bool, + default=False, + help="Set to True if the images are compressed.", + ) + + parser.add_argument( + "--output-dir", + type=Path, + default=None, + help="Directory where the processed dataset will be stored. Defaults to LEROBOT_HOME.", + ) + parser.add_argument( + "--debug", action="store_true", help="Store only 2 episodes for debug purposes." + ) + + parser.add_argument( + "--save-mp4", + action="store_true", + help="If enabled, save a single half-resolution MP4 with all frames across episodes.", + ) + + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = argument_parse() + zarr_path, mp4_path = main(args) + print(zarr_path, mp4_path) diff --git a/egomimic/scripts/aria_process/aria_utils.py b/egomimic/scripts/aria_process/aria_utils.py index 599e9a47..04ee50d2 100644 --- a/egomimic/scripts/aria_process/aria_utils.py +++ b/egomimic/scripts/aria_process/aria_utils.py @@ -1,92 +1,66 @@ -import h5py +import os +from datetime import datetime, timezone +from typing import Dict, List + +import cv2 import numpy as np -from projectaria_tools.core import calibration +import projectaria_tools.core.sophus as sp +import torch +import torch.nn.functional as F +from projectaria_tools.core import calibration, data_provider, mps +from projectaria_tools.core.mps.utils import ( + get_nearest_eye_gaze, + get_nearest_hand_tracking_result, + get_nearest_pose, +) +from projectaria_tools.core.sensor_data import TimeDomain, TimeQueryOptions from projectaria_tools.core.stream_id import StreamId -from scipy.spatial.transform import Rotation as R - -ROTATION_MATRIX = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]]) +from egomimic.utils.egomimicUtils import ( + INTRINSICS, + cam_frame_to_cam_pixels, +) +from egomimic.utils.pose_utils import T_rot_orientation -def build_camera_matrix(provider, pose_t): - T_world_device = pose_t.transform_world_device - rgb_stream_id = StreamId("214-1") - rgb_stream_label = provider.get_label_from_stream_id(rgb_stream_id) - device_calibration = provider.get_device_calibration() - calib = device_calibration.get_camera_calib(rgb_stream_label) - rgb_camera_calibration = calibration.get_linear_camera_calibration( - 480, - 640, - 133.25430222 * 2, - rgb_stream_label, - calib.get_transform_device_camera(), - ) +ROTATION_MATRIX = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]]) +T_ROT_CAM = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]]) - # rgb_camera_calibration = device_calibration.get_camera_calib(rgb_stream_label) - T_device_rgb_camera = rgb_camera_calibration.get_transform_device_camera() - T_world_rgb_camera = (T_world_device @ T_device_rgb_camera).to_matrix() - return T_world_rgb_camera +STEP_DEFAULT = 3.0 +EPISODE_LENGTH = 100 +CHUNK_LENGTH_ACT = 100 -def undistort_to_linear(provider, stream_ids, raw_image, camera_label="rgb"): +def undistort_to_linear( + provider, stream_ids, raw_image, camera_label="rgb", height=480, width=640 +): camera_label = provider.get_label_from_stream_id(stream_ids[camera_label]) calib = provider.get_device_calibration().get_camera_calib(camera_label) + focal_length = 133.25430222 * (height / 240) warped = calibration.get_linear_camera_calibration( - 480, 640, 133.25430222 * 2, camera_label, calib.get_transform_device_camera() + height, width, focal_length, camera_label, calib.get_transform_device_camera() ) warped_image = calibration.distort_by_calibration(raw_image, warped, calib) warped_rot = np.rot90(warped_image, k=3) return warped_rot -def reproject_point(pose, provider): - ## cam_matrix := extrinsics - rgb_stream_id = StreamId("214-1") - rgb_stream_label = provider.get_label_from_stream_id(rgb_stream_id) - device_calibration = provider.get_device_calibration() - # point_pose_camera = cam_matrix @ pose_hom - # print(point_pose_camera) - calib = device_calibration.get_camera_calib(rgb_stream_label) - T_device_sensor = device_calibration.get_transform_device_sensor(rgb_stream_label) - point_position_camera = T_device_sensor.inverse() @ pose - - warped = calibration.get_linear_camera_calibration( - 480, 640, 133.25430222 * 2, "rgb", calib.get_transform_device_camera() - ) - point_position_pixel = warped.project(point_position_camera) - return point_position_pixel - - -def split_train_val_from_hdf5(hdf5_path, val_ratio): - with h5py.File(hdf5_path, "a") as file: - demo_keys = [key for key in file["data"].keys() if "demo" in key] - num_demos = len(demo_keys) - num_val = int(np.ceil(num_demos * val_ratio)) - - indices = np.arange(num_demos) - np.random.shuffle(indices) - - val_indices = indices[:num_val] - train_indices = indices[num_val:] - - train_mask = [f"demo_{i}" for i in train_indices] - val_mask = [f"demo_{i}" for i in val_indices] - - file.create_dataset("mask/train", data=np.array(train_mask, dtype="S")) - file.create_dataset("mask/valid", data=np.array(val_mask, dtype="S")) - - -def slam_to_rgb(provider): +def slam_to_rgb(provider, height=480, width=640): """ Get slam camera to rgb camera transform provider: vrs data provider """ + focal_length = 133.25430222 * (height / 240) device_calibration = provider.get_device_calibration() slam_id = StreamId("1201-1") slam_label = provider.get_label_from_stream_id(slam_id) slam_calib = device_calibration.get_camera_calib(slam_label) slam_camera = calibration.get_linear_camera_calibration( - 480, 640, 133.24530222 * 2, slam_label, slam_calib.get_transform_device_camera() + height, + width, + focal_length, + slam_label, + slam_calib.get_transform_device_camera(), ) T_device_slam_camera = ( slam_camera.get_transform_device_camera() @@ -96,7 +70,7 @@ def slam_to_rgb(provider): rgb_label = provider.get_label_from_stream_id(rgb_id) rgb_calib = device_calibration.get_camera_calib(rgb_label) rgb_camera = calibration.get_linear_camera_calibration( - 480, 640, 133.24530222 * 2, rgb_label, rgb_calib.get_transform_device_camera() + height, width, focal_length, rgb_label, rgb_calib.get_transform_device_camera() ) T_device_rgb_camera = ( rgb_camera.get_transform_device_camera() @@ -120,32 +94,820 @@ def compute_coordinate_frame(palm_pose, wrist_pose, palm_normal): return -1 * x_axis, y_axis, z_axis -def transform_coordinates(palm_pose, x_axis, y_axis, z_axis, transform): - palm_pose_h = np.append(palm_pose, 1) - x_axis_h = np.append(x_axis, 0) - y_axis_h = np.append(y_axis, 0) - z_axis_h = np.append(z_axis, 0) +def compute_orientation_rotation_matrix(palm_pose, wrist_pose, palm_normal): + x_axis = wrist_pose - palm_pose + x_axis = np.ravel(x_axis) / np.linalg.norm(x_axis) + z_axis = np.ravel(palm_normal) / np.linalg.norm(palm_normal) + y_axis = np.cross(x_axis, z_axis) + y_axis = np.ravel(y_axis) / np.linalg.norm(y_axis) + + x_axis = np.cross(z_axis, y_axis) + x_axis = np.ravel(x_axis) / np.linalg.norm(x_axis) + + rot_matrix = np.column_stack([-1 * x_axis, y_axis, z_axis]) + return rot_matrix + - # Apply SLAM-to-RGB transformation - transformed_palm_pose = (transform @ palm_pose_h)[:3] - transformed_x_axis = (transform @ x_axis_h)[:3] - transformed_y_axis = (transform @ y_axis_h)[:3] - transformed_z_axis = (transform @ z_axis_h)[:3] +def timestamp_ms_to_episode_hash(timestamp_ms: int) -> str: + """ + Convert UTC epoch milliseconds -> string like "2026-01-12-03-47-29-664000". + (microseconds are always 6 digits; last 3 digits will be 000 because input is ms) + """ + if not isinstance(timestamp_ms, int): + raise TypeError("timestamp_ms must be an int (UTC epoch milliseconds).") - # Apply additional rotation transpose - rot_T = ROTATION_MATRIX.T # Compute the transpose - final_palm_pose = rot_T @ transformed_palm_pose - final_x_axis = rot_T @ transformed_x_axis - final_y_axis = rot_T @ transformed_y_axis - final_z_axis = rot_T @ transformed_z_axis + dt = datetime.fromtimestamp(timestamp_ms / 1000.0, tz=timezone.utc) + return dt.strftime("%Y-%m-%d-%H-%M-%S-%f") - return final_palm_pose, final_x_axis, final_y_axis, final_z_axis + +def downsample_hwc_uint8_in_chunks( + images: np.ndarray, # (T,H,W,3) uint8 + out_hw=(240, 320), + chunk: int = 256, +) -> np.ndarray: + assert images.dtype == np.uint8 and images.ndim == 4 and images.shape[-1] == 3 + T, H, W, C = images.shape + outH, outW = out_hw + + out = np.empty((T, outH, outW, 3), dtype=np.uint8) + + for s in range(0, T, chunk): + e = min(s + chunk, T) + x = ( + torch.from_numpy(images[s:e]).permute(0, 3, 1, 2).to(torch.float32) / 255.0 + ) # (B,3,H,W) + x = F.interpolate(x, size=(outH, outW), mode="bilinear", align_corners=False) + x = (x * 255.0).clamp(0, 255).to(torch.uint8) # (B,3,outH,outW) + out[s:e] = x.permute(0, 2, 3, 1).cpu().numpy() + del x + + return out + + +def quat_translation_swap(quat_translation: np.ndarray) -> np.ndarray: + """ + Swap the quaternion and translation in a (N, 7) array. + Parameters + ---------- + quat_translation : np.ndarray + (N, 7) array of quaternion and translation + Returns + ------- + np.ndarray: + (N, 7) array of translation and quaternion + """ + return np.concatenate( + (quat_translation[..., 4:7], quat_translation[..., 0:4]), axis=-1 + ) -def coordinate_frame_to_ypr(x_axis, y_axis, z_axis): - rot_matrix = np.column_stack([x_axis, y_axis, z_axis]) - rotation = R.from_matrix(rot_matrix) - euler_ypr = rotation.as_euler("ZYX", degrees=False) - if np.isnan(euler_ypr).any(): - euler_ypr = np.zeros_like(euler_ypr) - return euler_ypr +class AriaVRSExtractor: + TAGS = ["aria", "robotics", "vrs"] + + @staticmethod + def process_episode(episode_path, arm: str, low_res=False, height=480, width=640): + """ + Extracts all feature keys from a given episode and returns as a dictionary + Parameters + ---------- + episode_path : str or Path + Path to the VRS file containing the episode data. + arm : str + String for which arm to add data for + Returns + ------- + episode_feats : dict + Dictionary mapping keys in the episode to episode features, for example: + hand. : (world frame) (6D per arm) + hand. : (world frame) (3 cartesian + 4 quaternion + 63 dim (21 keypoints) per arm) + images. : + head_pose : (world frame) + + #TODO: Add metadata to be a nested dict + + """ + episode_feats = dict() + + # file setup and opening + root_dir = episode_path.parent + + mps_sample_path = os.path.join(root_dir, ("mps_" + episode_path.stem + "_vrs")) + + hand_tracking_results_path = os.path.join( + mps_sample_path, "hand_tracking", "hand_tracking_results.csv" + ) + + closed_loop_pose_path = os.path.join( + mps_sample_path, "slam", "closed_loop_trajectory.csv" + ) + + eye_gaze_path = os.path.join( + mps_sample_path, "eye_gaze", "general_eye_gaze.csv" + ) + use_eye_gaze = os.path.exists(eye_gaze_path) + # TODO: in the future might write to sql on the failure due to mps processing failures + if not os.path.exists(hand_tracking_results_path): + raise FileNotFoundError( + f"Hand tracking results file not found at {hand_tracking_results_path}" + ) + if not os.path.exists(closed_loop_pose_path): + raise FileNotFoundError( + f"Closed loop pose file not found at {closed_loop_pose_path}" + ) + + vrs_reader = data_provider.create_vrs_data_provider(str(episode_path)) + + hand_tracking_results = mps.hand_tracking.read_hand_tracking_results( + hand_tracking_results_path + ) + + closed_loop_traj = mps.read_closed_loop_trajectory(closed_loop_pose_path) + if use_eye_gaze: + eye_gaze_results = mps.read_eyegaze(eye_gaze_path) + + time_domain: TimeDomain = TimeDomain.DEVICE_TIME + + stream_ids: Dict[str, StreamId] = { + "rgb": StreamId("214-1"), + "slam-left": StreamId("1201-1"), + "slam-right": StreamId("1201-2"), + } + stream_timestamps_ns: Dict[str, List[int]] = { + key: vrs_reader.get_timestamps_ns(stream_id, time_domain) + for key, stream_id in stream_ids.items() + } + + rgb_to_device_T = slam_to_rgb( + vrs_reader, height=height, width=width + ) # aria sophus SE3 + + hand_cartesian_pose = AriaVRSExtractor.get_ee_pose( + world_device_T=closed_loop_traj, + stream_timestamps_ns=stream_timestamps_ns, + hand_tracking_results=hand_tracking_results, + arm=arm, + ) + + hand_keypoints_pose = AriaVRSExtractor.get_hand_keypoints( + world_device_T=closed_loop_traj, + stream_timestamps_ns=stream_timestamps_ns, + hand_tracking_results=hand_tracking_results, + arm=arm, + ) + + head_pose = AriaVRSExtractor.get_head_pose( + world_device_T=closed_loop_traj, + device_rgb_T=rgb_to_device_T.inverse(), + stream_timestamps_ns=stream_timestamps_ns, + ) + if use_eye_gaze: + eye_gaze = AriaVRSExtractor.get_eye_gaze( + eye_gaze_results=eye_gaze_results, + stream_timestamps_ns=stream_timestamps_ns, + ) + + images = AriaVRSExtractor.get_images( + vrs_reader=vrs_reader, + stream_ids=stream_ids, + stream_timestamps_ns=stream_timestamps_ns, + height=height, + width=width, + ) + + if low_res: + images = downsample_hwc_uint8_in_chunks( + images, out_hw=(240, 320), chunk=256 + ) + + rgb_timestamps_ns = np.array(stream_timestamps_ns["rgb"]) + print(f"[DEBUG] LENGTH BEFORE CLEANING: {len(hand_cartesian_pose)}") + mask_data = [images, rgb_timestamps_ns] + filter_mask_data = [hand_cartesian_pose, hand_keypoints_pose, head_pose] + if use_eye_gaze: + mask_data.append(eye_gaze) + + ( + output_filter_mask_data, + output_mask_data, + ) = AriaVRSExtractor.clean_data( + filter_mask_data=filter_mask_data, + mask_data=mask_data, + ) + + hand_cartesian_pose = output_filter_mask_data[0] + hand_keypoints_pose = output_filter_mask_data[1] + head_pose = output_filter_mask_data[2] + + if use_eye_gaze: + eye_gaze = output_mask_data[2] + + images = output_mask_data[0] + rgb_timestamps_ns = output_mask_data[1] + + print(f"[DEBUG] LENGTH AFTER CLEANING: {len(hand_cartesian_pose)}") + + use_left_hand = arm == "left" or arm == "both" + use_right_hand = arm == "right" or arm == "both" + if use_left_hand: + episode_feats["left.obs_ee_pose"] = hand_cartesian_pose[..., :7] + episode_feats["left.obs_keypoints"] = hand_keypoints_pose[ + ..., 7 : 7 + 21 * 3 + ] + episode_feats["left.obs_wrist_pose"] = hand_keypoints_pose[..., :7] + + if use_right_hand: + if arm == "both": + episode_feats["right.obs_ee_pose"] = hand_cartesian_pose[..., 7:] + episode_feats["right.obs_keypoints"] = hand_keypoints_pose[ + ..., 7 + 21 * 3 + 7 : 7 + 21 * 3 + 7 + 21 * 3 + ] + episode_feats["right.obs_wrist_pose"] = hand_keypoints_pose[ + ..., 7 + 21 * 3 : 7 + 21 * 3 + 7 + ] + elif arm == "right": + episode_feats["right.obs_ee_pose"] = hand_cartesian_pose[..., :7] + episode_feats["right.obs_keypoints"] = hand_keypoints_pose[ + ..., 7 : 7 + 21 * 3 + ] + episode_feats["right.obs_wrist_pose"] = hand_keypoints_pose[..., :7] + episode_feats["images.front_1"] = images + episode_feats["obs_head_pose"] = head_pose + if use_eye_gaze: + episode_feats["obs_eye_gaze"] = eye_gaze + episode_feats["obs_rgb_timestamps_ns"] = rgb_timestamps_ns + + return episode_feats + + @staticmethod + def clean_data(filter_mask_data, mask_data): + """ + Clean data + Parameters + ---------- + actions : np.arrayoses + pose : np.array + images : np.array + Returns + ------- + actions, pose, images : tuple of np.array + cleaned data + """ + mask = np.ones(len(filter_mask_data[0]), dtype=bool) + for pose in filter_mask_data: + bad_data_mask = np.any(pose >= 1e8, axis=1) + mask = mask & ~bad_data_mask + + for i in range(len(filter_mask_data)): + filter_mask_data[i] = filter_mask_data[i][mask] + for i in range(len(mask_data)): + mask_data[i] = mask_data[i][mask] + + return filter_mask_data, mask_data + + @staticmethod + def clean_data_projection( + actions, pose, images, arm, CHUNK_LENGTH=CHUNK_LENGTH_ACT + ): + """ + Clean data + Parameters + ---------- + actions : np.array + pose : np.array + images : np.array + Returns + ------- + actions, pose, images : tuple of np.array + cleaned data + """ + actions_copy = actions.copy() + if arm == "both": + actions_left = actions_copy[..., :3] + actions_right = actions_copy[..., 6:9] + actions_copy = np.concatenate((actions_left, actions_right), axis=-1) + else: + actions_copy = actions_copy[..., :3] + + ac_dim = actions_copy.shape[-1] + actions_flat = actions_copy.reshape(-1, 3) + + N, C, H, W = images.shape + + if H == 480: + intrinsics = INTRINSICS["base"] + elif H == 240: + intrinsics = INTRINSICS["base_half"] + px = cam_frame_to_cam_pixels(actions_flat, intrinsics) + px = px.reshape((-1, CHUNK_LENGTH, ac_dim)) + if ac_dim == 3: + bad_data_mask = ( + (px[:, :, 0] < 0) + | (px[:, :, 0] > (W)) + | (px[:, :, 1] < 0) + | (px[:, :, 1] > (H)) + ) + elif ac_dim == 6: + BUFFER = 0 + bad_data_mask = ( + (px[:, :, 0] < 0 - BUFFER) + | (px[:, :, 0] > (W) + BUFFER) + | (px[:, :, 1] < 0) + # | (px[:, :, 1] > 480 + BUFFER) + | (px[:, :, 3] < 0 - BUFFER) + | (px[:, :, 3] > (H) + BUFFER) + | (px[:, :, 4] < 0) + # | (px[:, :, 4] > 480 + BUFFER) + ) + + px_diff = np.diff(px, axis=1) + px_diff = np.concatenate( + (px_diff, np.zeros((px_diff.shape[0], 1, px_diff.shape[-1]))), axis=1 + ) + px_diff = np.abs(px_diff) + bad_data_mask = bad_data_mask | np.any(px_diff > 100, axis=2) + + bad_data_mask = np.any(bad_data_mask, axis=1) + + actions = actions[~bad_data_mask] + images = images[~bad_data_mask] + pose = pose[~bad_data_mask] + + return actions, pose, images + + @staticmethod + def get_images( + vrs_reader, + stream_ids: dict, + stream_timestamps_ns: dict, + height=480, + width=640, + ): + """ + Get RGB Image from VRS + Parameters + ---------- + vrs_reader : VRS Data Provider + Object that reads and obtains data from VRS + stream_ids : dict + maps sensor keys to a list of ids for Aria + stream_timestamps_ns : dict + dict that maps sensor keys to a list of nanosecond timestamps in device time + Returns + ------- + images : np.array + rgb images undistorted to 480x640x3 + """ + images = [] + frame_length = len(stream_timestamps_ns["rgb"]) + + time_domain = TimeDomain.DEVICE_TIME + time_query_closest = TimeQueryOptions.CLOSEST + + for t in range(frame_length): + query_timestamp = stream_timestamps_ns["rgb"][t] + + sample_frame = vrs_reader.get_image_data_by_time_ns( + stream_ids["rgb"], + query_timestamp, + time_domain, + time_query_closest, + ) + + image_t = undistort_to_linear( + vrs_reader, + stream_ids, + raw_image=sample_frame[0].to_numpy_array(), + height=height, + width=width, + ) + + images.append(image_t) + images = np.array(images) + return images + + @staticmethod + def get_hand_keypoints( + world_device_T, + stream_timestamps_ns: dict, + hand_tracking_results, + arm: str, + ): + """ + Get Hand Keypoints from VRS + Parameters + ---------- + world_device_T : np.array + Transform from world coordinates to ARIA camera frame + stream_timestamps_ns : dict + hand_tracking_results : dict + arm : str + arm to get hand keypoints for + Returns + ------- + hand_keypoints : np.array + hand_keypoints + """ + frame_length = len(stream_timestamps_ns["rgb"]) + + keypoints = [] + + use_left_hand = arm == "left" or arm == "both" + use_right_hand = arm == "right" or arm == "both" + for t in range(frame_length): + query_timestamp = stream_timestamps_ns["rgb"][t] + hand_tracking_result_t = get_nearest_hand_tracking_result( + hand_tracking_results, query_timestamp + ) + world_device_T_t = get_nearest_pose(world_device_T, query_timestamp) + if world_device_T_t is not None: + world_device_T_t = world_device_T_t.transform_world_device + + right_confidence = getattr( + getattr(hand_tracking_result_t, "right_hand", None), "confidence", -1 + ) + left_confidence = getattr( + getattr(hand_tracking_result_t, "left_hand", None), "confidence", -1 + ) + left_obs_t = np.full(7 + 21 * 3, 1e9) + if ( + use_left_hand + and not left_confidence < 0 + and world_device_T_t is not None + ): + left_hand_keypoints = np.stack( + hand_tracking_result_t.left_hand.landmark_positions_device, axis=0 + ) + wrist_T = ( + hand_tracking_result_t.left_hand.transform_device_wrist + ) # Sophus SE3 + + world_wrist_T = world_device_T_t @ wrist_T + world_keypoints = ( + world_device_T_t @ left_hand_keypoints.T + ).T # keypoints are in device frame + + world_wrist_T = sp.SE3.from_matrix( + T_rot_orientation(world_wrist_T.to_matrix(), T_ROT_CAM) + ) + wrist_quat_and_translation = quat_translation_swap( + world_wrist_T.to_quat_and_translation() + ) + if wrist_quat_and_translation.ndim == 2: + wrist_quat_and_translation = wrist_quat_and_translation[0] + left_obs_t[:7] = wrist_quat_and_translation + left_obs_t[7:] = world_keypoints.flatten() + + right_obs_t = np.full(7 + 21 * 3, 1e9) + if ( + use_right_hand + and not right_confidence < 0 + and world_device_T_t is not None + ): + right_hand_keypoints = np.stack( + hand_tracking_result_t.right_hand.landmark_positions_device, axis=0 + ) + wrist_T = ( + hand_tracking_result_t.right_hand.transform_device_wrist + ) # Sophus SE3 + + world_wrist_T = world_device_T_t @ wrist_T + world_keypoints = ( + world_device_T_t @ right_hand_keypoints.T + ).T # keypoints are in device frame + + world_wrist_T = sp.SE3.from_matrix( + T_rot_orientation(world_wrist_T.to_matrix(), T_ROT_CAM) + ) + wrist_quat_and_translation = quat_translation_swap( + world_wrist_T.to_quat_and_translation() + ) + if wrist_quat_and_translation.ndim == 2: + wrist_quat_and_translation = wrist_quat_and_translation[0] + right_obs_t[:7] = wrist_quat_and_translation + right_obs_t[7:] = world_keypoints.flatten() + + if use_left_hand and use_right_hand: + keypoints_obs_t = np.concatenate((left_obs_t, right_obs_t), axis=-1) + elif use_left_hand: + keypoints_obs_t = left_obs_t + elif use_right_hand: + keypoints_obs_t = right_obs_t + else: + raise ValueError(f"Incorrect arm provided: {arm}") + keypoints.append(np.ravel(keypoints_obs_t)) + keypoints = np.array(keypoints) + return keypoints + + @staticmethod + def get_head_pose( + world_device_T, + device_rgb_T, + stream_timestamps_ns: dict, + ): + """ + Get Head Pose from VRS + Parameters + ---------- + world_device_T : np.array + Transform from world coordinates to ARIA camera frame + stream_timestamps_ns : dict + dict that maps sensor keys to a list of nanosecond timestamps in device time + + Returns + ------- + head_pose : np.array + head_pose + """ + head_pose = [] + frame_length = len(stream_timestamps_ns["rgb"]) + + rgb_to_rgbprime_rot = np.eye(4) + rgb_to_rgbprime_rot[:3, :3] = ROTATION_MATRIX.T + rgb_to_rgbprime_T = sp.SE3.from_matrix(rgb_to_rgbprime_rot) + rgbprime_to_rgb_T = rgb_to_rgbprime_T.inverse() + for t in range(frame_length): + query_timestamp = stream_timestamps_ns["rgb"][t] + world_device_T_t = get_nearest_pose(world_device_T, query_timestamp) + if world_device_T_t is not None: + world_device_T_t = world_device_T_t.transform_world_device + head_pose_obs_t = np.full(7, 1e9) + if world_device_T_t is not None: + world_rgb_T_t = world_device_T_t @ device_rgb_T @ rgbprime_to_rgb_T + head_pose_quat_and_translation = quat_translation_swap( + world_rgb_T_t.to_quat_and_translation() + ) + if head_pose_quat_and_translation.ndim == 2: + head_pose_quat_and_translation = head_pose_quat_and_translation[0] + head_pose_obs_t[:7] = head_pose_quat_and_translation + head_pose.append(np.ravel(head_pose_obs_t)) + head_pose = np.array(head_pose) + return head_pose + + @staticmethod + def get_eye_gaze( + eye_gaze_results, + stream_timestamps_ns: dict, + ): + gaze = [] + frame_length = len(stream_timestamps_ns["rgb"]) + for t in range(frame_length): + query_timestamp = stream_timestamps_ns["rgb"][t] + gaze_info = get_nearest_eye_gaze(eye_gaze_results, query_timestamp) + if gaze_info is None: + gaze.append([1e9, 1e9, 1e9]) + else: + gaze.append([gaze_info.yaw, gaze_info.pitch, gaze_info.depth]) + gaze = np.array(gaze) + return gaze + + @staticmethod + def get_ee_pose( + world_device_T, + stream_timestamps_ns: dict, + hand_tracking_results, + arm: str, + ): + """ + Get EE Pose from VRS + Parameters + ---------- + world_device_T : np.array + Transform from world coordinates to ARIA camera frame + stream_timestamps_ns : dict + dict that maps sensor keys to a list of nanosecond timestamps in device time + hand_tracking_results : dict + dict that maps sensor keys to a list of hand tracking results + arm : str + arm to get hand keypoints for + Returns + ------- + ee_pose : np.array + ee_pose (6D per arm) + -1 if no hand tracking data is available + """ + ee_pose = [] + frame_length = len(stream_timestamps_ns["rgb"]) + + use_left_hand = arm == "left" or arm == "both" + use_right_hand = arm == "right" or arm == "both" + + for t in range(frame_length): + query_timestamp = stream_timestamps_ns["rgb"][t] + hand_tracking_result_t = get_nearest_hand_tracking_result( + hand_tracking_results, query_timestamp + ) + world_device_T_t = get_nearest_pose(world_device_T, query_timestamp) + if world_device_T_t is not None: + world_device_T_t = world_device_T_t.transform_world_device + + right_confidence = getattr( + getattr(hand_tracking_result_t, "right_hand", None), "confidence", -1 + ) + left_confidence = getattr( + getattr(hand_tracking_result_t, "left_hand", None), "confidence", -1 + ) + + left_obs_t = np.full(7, 1e9) + if ( + use_left_hand + and not left_confidence < 0 + and world_device_T_t is not None + ): + left_palm_pose = ( + hand_tracking_result_t.left_hand.get_palm_position_device() + ) + left_wrist_pose = ( + hand_tracking_result_t.left_hand.get_wrist_position_device() + ) + left_palm_normal = hand_tracking_result_t.left_hand.wrist_and_palm_normal_device.palm_normal_device + + left_rot_matrix = compute_orientation_rotation_matrix( + palm_pose=left_palm_pose, + wrist_pose=left_wrist_pose, + palm_normal=left_palm_normal, + ) + left_T_t = np.eye(4) + left_T_t[:3, :3] = left_rot_matrix + left_T_t[:3, 3] = left_palm_pose + left_T_t = sp.SE3.from_matrix(left_T_t) + left_T_t = world_device_T_t @ left_T_t + left_T_t = sp.SE3.from_matrix( + T_rot_orientation(left_T_t.to_matrix(), T_ROT_CAM) + ) + + left_quat_and_translation = quat_translation_swap( + left_T_t.to_quat_and_translation() + ) + if left_quat_and_translation.ndim == 2: + left_quat_and_translation = left_quat_and_translation[0] + left_obs_t[:7] = left_quat_and_translation + + right_obs_t = np.full(7, 1e9) + if ( + use_right_hand + and not right_confidence < 0 + and world_device_T_t is not None + ): + right_palm_pose = ( + hand_tracking_result_t.right_hand.get_palm_position_device() + ) + right_wrist_pose = ( + hand_tracking_result_t.right_hand.get_wrist_position_device() + ) + right_palm_normal = hand_tracking_result_t.right_hand.wrist_and_palm_normal_device.palm_normal_device + + right_rot_matrix = compute_orientation_rotation_matrix( + palm_pose=right_palm_pose, + wrist_pose=right_wrist_pose, + palm_normal=right_palm_normal, + ) + right_T_t = np.eye(4) + right_T_t[:3, :3] = right_rot_matrix + right_T_t[:3, 3] = right_palm_pose + right_T_t = sp.SE3.from_matrix(right_T_t) + right_T_t = world_device_T_t @ right_T_t + right_T_t = sp.SE3.from_matrix( + T_rot_orientation(right_T_t.to_matrix(), T_ROT_CAM) + ) + right_quat_and_translation = quat_translation_swap( + right_T_t.to_quat_and_translation() + ) + if right_quat_and_translation.ndim == 2: + right_quat_and_translation = right_quat_and_translation[0] + right_obs_t[:7] = right_quat_and_translation + + if use_left_hand and use_right_hand: + ee_pose_obs_t = np.concatenate((left_obs_t, right_obs_t), axis=-1) + elif use_left_hand: + ee_pose_obs_t = left_obs_t + elif use_right_hand: + ee_pose_obs_t = right_obs_t + else: + raise ValueError(f"Incorrect arm provided: {arm}") + ee_pose.append(np.ravel(ee_pose_obs_t)) + ee_pose = np.array(ee_pose) + return ee_pose + + @staticmethod + def get_cameras(rgb_camera_key: str): + """ + Returns a list of rgb keys + Parameters + ---------- + rgb_camera_key : str + + Returns + ------- + rgb_cameras : list of str + A list of keys corresponding to rgb_cameras in the dataset. + """ + + rgb_cameras = [rgb_camera_key] + return rgb_cameras + + @staticmethod + def get_state(state_key: str): + """ + Returns a list of state keys + Parameters + ---------- + state_key : str + + Returns + ------- + states : list of str + A list of keys corresponding to states in the dataset. + """ + + states = [state_key] + return states + + @staticmethod + def define_features( + episode_feats: dict, image_compressed: bool = True, encode_as_video: bool = True + ) -> tuple: + """ + Define features from episode_feats (output of process_episode), including a metadata section. + + Parameters + ---------- + episode_feats : dict + The output of the process_episode method, containing feature data. + image_compressed : bool, optional + Whether the images are compressed, by default True. + encode_as_video : bool, optional + Whether to encode images as video or as images, by default True. + + Returns + ------- + tuple of dict[str, dict] + A dictionary where keys are feature names and values are dictionaries + containing feature information such as dtype, shape, and dimension names, + and a separate dictionary for metadata (unused for now) + """ + features = {} + metadata = {} + for key, value in episode_feats.items(): + if isinstance(value, dict): # Handle nested dictionaries recursively + nested_features, nested_metadata = AriaVRSExtractor.define_features( + value, image_compressed, encode_as_video + ) + features.update( + { + f"{key}.{nested_key}": nested_value + for nested_key, nested_value in nested_features.items() + } + ) + features.update( + { + f"{key}.{nested_key}": nested_value + for nested_key, nested_value in nested_metadata.items() + } + ) + elif isinstance(value, np.ndarray): + dtype = str(value.dtype) + if "images" in key: + dtype = "video" if encode_as_video else "image" + if image_compressed: + decompressed_sample = cv2.imdecode(value[0], 1) + shape = ( + decompressed_sample.shape[1], + decompressed_sample.shape[0], + decompressed_sample.shape[2], + ) + else: + shape = value.shape[1:] # Skip the frame count dimension + dim_names = ["channel", "height", "width"] + elif "actions" in key and len(value[0].shape) > 1: + shape = value[0].shape + dim_names = ["chunk_length", "action_dim"] + dtype = f"prestacked_{str(value.dtype)}" + else: + shape = value[0].shape + dim_names = [f"dim_{i}" for i in range(len(shape))] + features[key] = { + "dtype": dtype, + "shape": shape, + "names": dim_names, + } + elif isinstance(value, torch.Tensor): + dtype = str(value.dtype) + if "actions" in key and len(tuple(value[0].size())) > 1: + dim_names = ["chunk_length", "action_dim"] + dtype = f"prestacked_{str(value.dtype)}" + else: + dim_names = [f"dim_{i}" for i in range(len(shape))] + shape = tuple(value[0].size()) + dim_names = [f"dim_{i}" for i in range(len(shape))] + features[key] = { + "dtype": dtype, + "shape": shape, + "names": dim_names, + } + else: + metadata[key] = { + "dtype": "metadata", + "value": value, + } + + return features, metadata