diff --git a/egomimic/rldb/zarr/action_chunk_transforms.py b/egomimic/rldb/zarr/action_chunk_transforms.py index ea0ce6ad..cfee367d 100644 --- a/egomimic/rldb/zarr/action_chunk_transforms.py +++ b/egomimic/rldb/zarr/action_chunk_transforms.py @@ -13,6 +13,7 @@ from __future__ import annotations from abc import abstractmethod +from typing import Literal import numpy as np from projectaria_tools.core.sophus import SE3 diff --git a/egomimic/rldb/zarr/test_zarr_read.py b/egomimic/rldb/zarr/test_zarr_read.py new file mode 100644 index 00000000..59d99ac9 --- /dev/null +++ b/egomimic/rldb/zarr/test_zarr_read.py @@ -0,0 +1,299 @@ +""" +Test script for verifying Zarr episodes written by eva_to_zarr.py. + +Usage: + # Test a single episode + python test_zarr_read.py --zarr-path /path/to/episode.zarr + + # Test all episodes in a directory + python test_zarr_read.py --zarr-dir /path/to/zarr_dataset/name/ +""" + +import argparse +import sys +from pathlib import Path + +import numpy as np +import torch + +from egomimic.rldb.zarr.zarr_dataset_multi import ( + LocalEpisodeResolver, + MultiDataset, + ZarrDataset, + ZarrEpisode, +) + +# Keys written by eva_to_zarr.py (values from DATASET_KEY_MAPPINGS) +EXPECTED_NUMERIC_KEYS = {"obs_eepose", "obs_joint", "cmd_eepose", "cmd_joint"} +EXPECTED_IMAGE_KEYS = {"front_img_1", "right_wrist_img", "left_wrist_img"} + +# key_map for ZarrDataset: maps output names -> zarr key + optional horizon +KEY_MAP = { + "obs/eepose": {"zarr_key": "obs_eepose"}, + "obs/joint": {"zarr_key": "obs_joint"}, + "cmd/eepose": {"zarr_key": "cmd_eepose"}, + "cmd/joint": {"zarr_key": "cmd_joint"}, + "img/front": {"zarr_key": "front_img_1"}, + "img/right_wrist": {"zarr_key": "right_wrist_img"}, + "img/left_wrist": {"zarr_key": "left_wrist_img"}, +} + + +# --------------------------------------------------------------------------- +# Single-episode checks +# --------------------------------------------------------------------------- + +def check_episode(zarr_path: Path) 
-> bool: + """ + Run all checks on a single zarr episode. Returns True if all pass. + """ + print(f"\n{'='*60}") + print(f"Checking: {zarr_path.name}") + print(f"{'='*60}") + ok = True + + # --- 1. ZarrEpisode: metadata and raw keys --- + try: + ep = ZarrEpisode(zarr_path) + except Exception as e: + print(f" [FAIL] Could not open ZarrEpisode: {e}") + return False + + total_frames = ep.metadata.get("total_frames", 0) + fps = ep.metadata.get("fps", "?") + embodiment = ep.metadata.get("embodiment", "?") + task = ep.metadata.get("task", "?") + features = ep.metadata.get("features", {}) + + print(f" total_frames : {total_frames}") + print(f" fps : {fps}") + print(f" embodiment : {embodiment}") + print(f" task : {task!r}") + print(f" stored keys : {sorted(features.keys())}") + + # Presence checks + present_keys = set(features.keys()) + for key in EXPECTED_NUMERIC_KEYS | EXPECTED_IMAGE_KEYS: + if key in present_keys: + dtype = features[key].get("dtype", "?") + print(f" [OK ] {key:30s} dtype={dtype}") + else: + print(f" [MISSING] {key}") + ok = False + + # Language annotations (optional — written when --example-language-annotations is set) + has_lang = "language_annotations" in ep._store + if has_lang: + n = ep._store["language_annotations"].shape[0] + print(f" [OK ] {'language_annotations':30s} count={n}") + else: + print(f" [INFO] language_annotations not present") + + if total_frames == 0: + print(" [FAIL] total_frames == 0") + ok = False + + # --- 2. 
Raw read via ZarrEpisode --- + print("\n Raw reads (first frame):") + for key in sorted(EXPECTED_NUMERIC_KEYS): + if key not in present_keys: + continue + try: + data = ep.read({key: (0, None)}) + arr = data[key] + print(f" [OK ] {key:30s} shape={arr.shape} dtype={arr.dtype}") + except Exception as e: + print(f" [FAIL] {key}: {e}") + ok = False + + for key in sorted(EXPECTED_IMAGE_KEYS): + if key not in present_keys: + continue + try: + data = ep.read({key: (0, None)}) + raw = data[key] + # raw is JPEG bytes for single-frame read + nbytes = len(raw) if isinstance(raw, (bytes, bytearray, memoryview)) else raw.nbytes + print(f" [OK ] {key:30s} bytes={nbytes}") + except Exception as e: + print(f" [FAIL] {key}: {e}") + ok = False + + # --- 3. ZarrDataset: frame-level __getitem__ --- + print("\n ZarrDataset __getitem__(0):") + # Only include keys actually present; add language_annotations if stored + key_map_filtered = {k: v for k, v in KEY_MAP.items() if v["zarr_key"] in present_keys} + if has_lang: + key_map_filtered["language"] = {"zarr_key": "language_annotations"} + try: + ds = ZarrDataset(zarr_path, key_map=key_map_filtered) + assert len(ds) == total_frames, f"len(ds)={len(ds)} != total_frames={total_frames}" + frame = ds[0] + for out_key, val in frame.items(): + if isinstance(val, torch.Tensor): + finite = torch.isfinite(val).all().item() if val.is_floating_point() else True + tag = "OK " if finite else "WARN" + print(f" [{tag}] {out_key:22s} shape={tuple(val.shape)} dtype={val.dtype}") + elif isinstance(val, np.ndarray): + print(f" [OK ] {out_key:22s} shape={val.shape} dtype={val.dtype} (numpy)") + elif isinstance(val, str): + preview = val[:60] + "..." 
if len(val) > 60 else val + print(f" [OK ] {out_key:22s} text={preview!r}") + else: + # language_annotations must always resolve to str, never list/bytes + if out_key == "language": + print(f" [FAIL] {out_key:22s} expected str, got {type(val).__name__}") + ok = False + else: + print(f" [OK ] {out_key:22s} type={type(val).__name__}") + except Exception as e: + print(f" [FAIL] {e}") + import traceback; traceback.print_exc() + ok = False + + # --- 4. Language annotation span contents + spot-check span matching --- + if has_lang: + print(f"\n Language annotation spans:") + try: + spans = ds._load_language_annotations() + for i, ann in enumerate(spans): + print(f" [{i}] text={ann.get('text','')!r} " + f"start={ann.get('start_idx')} end={ann.get('end_idx')}") + + # Spot-check: for each span, frame at start_idx and end_idx must + # return a str containing the annotation text. + print(f"\n Span matching spot-checks:") + for ann in spans: + text = ann.get("text", "") + s, e = int(ann.get("start_idx", 0)), int(ann.get("end_idx", 0)) + for label, fidx in [("start", s), ("end", e)]: + result = ds._annotation_text_for_frame(fidx) + assert isinstance(result, str), f"Expected str, got {type(result)}" + assert text in result, ( + f"Frame {fidx} ({label} of span): expected {text!r} in {result!r}" + ) + print(f" [OK ] frame={fidx:4d} ({label}) -> {result!r}") + + # A frame beyond all spans should return an empty string + beyond = total_frames # one past the last valid index + result = ds._annotation_text_for_frame(beyond) + assert result == "", f"Expected '' for out-of-span frame {beyond}, got {result!r}" + print(f" [OK ] frame={beyond:4d} (out-of-span) -> {result!r}") + + except AssertionError as e: + print(f" [FAIL] {e}") + ok = False + except Exception as e: + print(f" [FAIL] {e}") + import traceback; traceback.print_exc() + ok = False + + # --- 5. 
ZarrDataset: last frame --- + print(f"\n ZarrDataset __getitem__({total_frames - 1}) (last frame):") + try: + frame_last = ds[total_frames - 1] + for out_key, val in frame_last.items(): + if isinstance(val, torch.Tensor): + finite = torch.isfinite(val).all().item() if val.is_floating_point() else True + status = "OK " if finite else "WARN" + print(f" [{status}] {out_key:22s} shape={tuple(val.shape)}") + elif isinstance(val, str): + preview = val[:60] + "..." if len(val) > 60 else val + print(f" [OK ] {out_key:22s} text={preview!r}") + else: + print(f" [OK ] {out_key:22s} type={type(val).__name__}") + except Exception as e: + print(f" [FAIL] {e}") + ok = False + + status_str = "PASS" if ok else "FAIL" + print(f"\n Result: [{status_str}]") + return ok + + +# --------------------------------------------------------------------------- +# Directory-level checks via LocalEpisodeResolver + MultiDataset +# --------------------------------------------------------------------------- + +def check_directory(zarr_dir: Path) -> bool: + """ + Load all zarr episodes in zarr_dir via LocalEpisodeResolver and MultiDataset. 
+ """ + print(f"\n{'='*60}") + print(f"Directory check: {zarr_dir}") + print(f"{'='*60}") + + zarr_paths = sorted(zarr_dir.glob("*.zarr")) + if not zarr_paths: + print(" No .zarr files found.") + return False + + print(f" Found {len(zarr_paths)} .zarr episode(s)") + + # Per-episode checks + results = {} + for p in zarr_paths: + results[p.name] = check_episode(p) + + # Summary + passed = sum(results.values()) + total = len(results) + print(f"\n{'='*60}") + print(f"Episode summary: {passed}/{total} passed") + for name, ok in sorted(results.items()): + tag = "PASS" if ok else "FAIL" + print(f" [{tag}] {name}") + + if total == 0: + return False + + # LocalEpisodeResolver + MultiDataset (train split) + print(f"\n LocalEpisodeResolver + MultiDataset test...") + try: + resolver = LocalEpisodeResolver( + folder_path=zarr_dir, + key_map=KEY_MAP, + ) + multi = MultiDataset._from_resolver( + resolver, + filters={}, # no extra filters beyond is_deleted=False default + mode="total", # use all episodes + ) + print(f" [OK ] MultiDataset total items: {len(multi)}") + sample = multi[0] + for k, v in sample.items(): + if isinstance(v, torch.Tensor): + print(f" {k:22s} shape={tuple(v.shape)}") + elif isinstance(v, str): + preview = v[:60] + "..." 
if len(v) > 60 else v + print(f" {k:22s} text={preview!r}") + else: + print(f" {k:22s} type={type(v).__name__}") + except Exception as e: + print(f" [FAIL] MultiDataset: {e}") + import traceback; traceback.print_exc() + + return passed == total + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser(description="Verify Zarr episodes from eva_to_zarr.py") + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument("--zarr-path", type=Path, help="Path to a single .zarr episode") + group.add_argument("--zarr-dir", type=Path, help="Directory containing .zarr episodes") + args = parser.parse_args() + + if args.zarr_path: + ok = check_episode(args.zarr_path) + else: + ok = check_directory(args.zarr_dir) + + sys.exit(0 if ok else 1) + + +if __name__ == "__main__": + main() diff --git a/egomimic/rldb/zarr/zarr_dataset_multi.py b/egomimic/rldb/zarr/zarr_dataset_multi.py index 409204d2..98504af5 100644 --- a/egomimic/rldb/zarr/zarr_dataset_multi.py +++ b/egomimic/rldb/zarr/zarr_dataset_multi.py @@ -19,11 +19,15 @@ """ from __future__ import annotations - -import json import logging import os import random +from pathlib import Path +from tracemalloc import start +import pandas as pd +import numpy as np +import torch +import zarr import subprocess import tempfile from pathlib import Path @@ -523,7 +527,7 @@ def __getitem__(self, idx): return data @classmethod - def _from_resolver(cls, resolver: EpisodeResolver, **kwargs): + def _from_resolver(cls, resolver: EpisodeResolver, sync_from_s3: bool = False, filters: dict | None = None, **kwargs): """ create a MultiDataset from an EpisodeResolver. 
@@ -538,9 +542,6 @@ def _from_resolver(cls, resolver: EpisodeResolver, **kwargs): """ # TODO add key_map and transform pass to children - sync_from_s3 = kwargs.pop("sync_from_s3", False) - filters = kwargs.pop("filters", {}) or {} - if isinstance(resolver, LocalEpisodeResolver): resolved = resolver.resolve( sync_from_s3=sync_from_s3, @@ -592,7 +593,6 @@ def init_episode(self): # Detect JPEG-encoded image keys from metadata self._image_keys = self._detect_image_keys() - self._json_keys = self._detect_json_keys() def _detect_image_keys(self) -> set[str]: """ @@ -705,12 +705,7 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: decoded = simplejpeg.decode_jpeg(jpeg_bytes, colorspace="RGB") # data[k] = torch.from_numpy(np.transpose(decoded, (2, 0, 1))).to(torch.float32) / 255.0 data[k] = np.transpose(decoded, (2, 0, 1)) / 255.0 - elif zarr_key in self._json_keys: - if isinstance(data[k], np.ndarray): - data[k] = [self._decode_json_entry(v) for v in data[k]] - else: - data[k] = self._decode_json_entry(data[k]) - + # Convert all numpy arrays in data to torch tensors # TODO add the transform list code here diff --git a/egomimic/robot/kinematics.py b/egomimic/robot/kinematics.py index d40f2a90..b0391ce3 100644 --- a/egomimic/robot/kinematics.py +++ b/egomimic/robot/kinematics.py @@ -165,8 +165,12 @@ def __init__( # Load MuJoCo model try: self.model = mujoco.MjModel.from_xml_path(self.model_path) - except Exception: + except Exception as e: # If direct loading fails, try creating a scene XML + print(f"Loading from URDF: {self.model_path} failed") + print(f"Error: {e}") + import traceback + traceback.print_exc() self.model = self._create_mujoco_model_from_urdf(model_path) self.data = mujoco.MjData(self.model) diff --git a/egomimic/scripts/aria_process/aria-cluster.yaml b/egomimic/scripts/aria_process/aria-cluster.yaml index 83f871d1..01c3e253 100644 --- a/egomimic/scripts/aria_process/aria-cluster.yaml +++ b/egomimic/scripts/aria_process/aria-cluster.yaml @@ -11,9 
+11,6 @@ auth: ssh_user: ubuntu ssh_private_key: ~/.ssh/rldb-base-pem.pem -file_mounts: - "/home/ubuntu/EgoVerse": "/home/elmo/Documents/projects/EgoVerse" - rsync_exclude: - "emimic/" - ".git/" @@ -22,7 +19,7 @@ rsync_exclude: - "logs/" - "*.parquet" -max_workers: 60 +max_workers: 140 idle_timeout_minutes: 5 available_node_types: @@ -31,10 +28,12 @@ available_node_types: InstanceType: t3a.2xlarge KeyName: rldb-base-pem ImageId: ami-08c4d70c31f91a5ac + IamInstanceProfile: + Arn: arn:aws:iam::556885871428:instance-profile/ray-autoscaler-v1 BlockDeviceMappings: - DeviceName: /dev/sda1 Ebs: { VolumeSize: 50 } - resources: { CPU: 4 } + resources: { CPU: 8 } min_workers: 0 max_workers: 0 @@ -43,24 +42,28 @@ available_node_types: InstanceType: r6a.2xlarge KeyName: rldb-base-pem ImageId: ami-08c4d70c31f91a5ac + IamInstanceProfile: + Arn: arn:aws:iam::556885871428:instance-profile/ray-autoscaler-v1 BlockDeviceMappings: - DeviceName: /dev/sda1 - Ebs: { VolumeSize: 50 } - resources: { CPU: 8, aria_small: 1 } + Ebs: { VolumeSize: 150 } + resources: { CPU: 2, aria_small: 1 } min_workers: 0 - max_workers: 50 + max_workers: 120 worker_big: node_config: InstanceType: r6a.8xlarge + IamInstanceProfile: + Arn: arn:aws:iam::556885871428:instance-profile/ray-autoscaler-v1 KeyName: rldb-base-pem ImageId: ami-08c4d70c31f91a5ac BlockDeviceMappings: - DeviceName: /dev/sda1 - Ebs: { VolumeSize: 100 } # often worth bumping - resources: { CPU: 32, aria_big: 1 } + Ebs: { VolumeSize: 200 } # often worth bumping + resources: { CPU: 8, aria_big: 1 } min_workers: 0 - max_workers: 10 + max_workers: 7 head_node_type: head_node @@ -101,8 +104,6 @@ initialization_commands: pip3 install --no-input -e . 
setup_commands: - - sudo mkdir -p /mnt/raw /mnt/processed - - | chmod +x ~/EgoVerse/egomimic/utils/aws/setup_secret.sh R2_SECRET_NAME=r2/rldb/credentials DB_SECRET_NAME=rds/appdb/appuser REGION=us-east-2 \ @@ -116,7 +117,7 @@ head_setup_commands: | grep -v ray_worker_gaurdrails.py \ | grep -v ray_worker_gaurdrails.lock ; \ echo 'CRON_TZ=America/New_York'; \ - echo '0 20 * * * flock -n /tmp/run_aria_conversion.lock /bin/bash -lc "set -a; . /home/ubuntu/.egoverse_env; set +a; /usr/bin/python3 ~/EgoVerse/egomimic/scripts/aria_process/run_aria_conversion.py --skip-if-done" >> ~/aria_conversion.log 2>&1'; \ + echo '0 20 * * * flock -n /tmp/run_aria_conversion.lock /bin/bash -lc "set -a; . /home/ubuntu/.egoverse_env; set +a; /usr/bin/python3 ~/EgoVerse/egomimic/scripts/aria_process/run_aria_conversion.py --skip-if-done --debug" >> ~/aria_conversion.log 2>&1'; \ echo '*/10 * * * * flock -n /tmp/ray_worker_gaurdrails.lock /usr/bin/python3 /home/ubuntu/EgoVerse/egomimic/utils/aws/budget_guardrails/ray_worker_gaurdrails.py >> /home/ubuntu/ray_worker_gaurdrails.log 2>&1') \ | crontab - || true diff --git a/egomimic/scripts/aria_process/aria_helper.py b/egomimic/scripts/aria_process/aria_helper.py index 362d13b6..14aa7b9e 100644 --- a/egomimic/scripts/aria_process/aria_helper.py +++ b/egomimic/scripts/aria_process/aria_helper.py @@ -2,21 +2,60 @@ from types import SimpleNamespace # import the real entry-point once -from egomimic.scripts.aria_process.aria_to_lerobot import main as aria_main +# from egomimic.scripts.aria_process.aria_to_lerobot import main as aria_main +from egomimic.scripts.aria_process.aria_to_zarr import main as aria_zarr_main -def lerobot_job( +# def lerobot_job( +# *, +# raw_path: str | Path, +# output_dir: str | Path, +# dataset_name: str, +# arm: str, +# description: str = "", +# ) -> None: +# """ +# Convert one trio to a LeRobot dataset. + +# Only the five arguments below are variable; everything else is fixed. 
+# """ +# raw_path = Path(raw_path).expanduser().resolve() +# output_dir = Path(output_dir).expanduser().resolve() + +# args = SimpleNamespace( +# raw_path=raw_path, +# output_dir=output_dir, +# name=dataset_name, +# arm=arm, +# description=description, +# # hard-wired defaults you specified +# dataset_repo_id=f"rpuns/{dataset_name}", +# fps=30, +# video_encoding=False, +# push=False, +# prestack=True, +# image_compressed=False, +# save_mp4=True, +# private=False, +# license="apache-2.0", +# nproc=16, +# nthreads=2, +# debug=False, +# benchmark=False +# ) + +# aria_main(args) + +def zarr_job( *, raw_path: str | Path, output_dir: str | Path, dataset_name: str, arm: str, description: str = "", -) -> None: +) -> tuple[Path, Path] | None: """ - Convert one trio to a LeRobot dataset. - - Only the five arguments below are variable; everything else is fixed. + Convert one trio to a Zarr dataset. """ raw_path = Path(raw_path).expanduser().resolve() output_dir = Path(output_dir).expanduser().resolve() @@ -24,23 +63,17 @@ def lerobot_job( args = SimpleNamespace( raw_path=raw_path, output_dir=output_dir, - name=dataset_name, arm=arm, - description=description, - # hard-wired defaults you specified - dataset_repo_id=f"rpuns/{dataset_name}", fps=30, - video_encoding=False, - push=False, - prestack=True, image_compressed=False, - save_mp4=True, - private=False, - license="apache-2.0", + video_encoding=False, nproc=16, nthreads=2, debug=False, benchmark=False, + save_mp4=True, + description=description, + dataset_name=dataset_name ) - aria_main(args) + return aria_zarr_main(args) \ No newline at end of file diff --git a/egomimic/scripts/aria_process/aria_to_zarr.py b/egomimic/scripts/aria_process/aria_to_zarr.py new file mode 100644 index 00000000..fe403a3f --- /dev/null +++ b/egomimic/scripts/aria_process/aria_to_zarr.py @@ -0,0 +1,1650 @@ +import argparse +from datetime import datetime, timezone +import logging +import os +from pathlib import Path +import shutil +import 
traceback +from typing import Any +from egomimic.rldb.zarr.zarr_writer import ZarrWriter +from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME +import cv2 +import h5py +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +import torch +import gc, ctypes +from enum import Enum +import projectaria_tools.core.sophus as sp + +from egomimic.utils.egomimicUtils import ( + prep_frame, + start_ffmpeg_mp4, + str2bool, + cam_frame_to_cam_pixels, + INTRINSICS, + interpolate_keys, + interpolate_arr, + interpolate_arr_euler, + transform_to_pose, + pose_to_transform, +) + +from projectaria_tools.core.calibration import CameraCalibration, DeviceCalibration +from projectaria_tools.core.sensor_data import TimeDomain, TimeQueryOptions + +from projectaria_tools.core import data_provider, mps + +from projectaria_tools.core.mps.utils import get_nearest_hand_tracking_result + +from projectaria_tools.core.mps.utils import ( + filter_points_from_confidence, + get_gaze_vector_reprojection, + get_nearest_eye_gaze, + get_nearest_pose, +) +from projectaria_tools.core.stream_id import StreamId + +from aria_utils import ( + build_camera_matrix, + compute_orientation_rotation_matrix, + undistort_to_linear, + slam_to_rgb, +) + +from egomimic.rldb.utils import EMBODIMENT + +import time + +import numpy as np + +import torch +import torch.nn.functional as F + +from scipy.spatial.transform import Rotation as R +import subprocess +import re +import threading +from contextlib import contextmanager +import psutil + +_root = psutil.Process(os.getpid()) + +def _proc_rss_mb(p: psutil.Process) -> float: + return p.memory_info().rss / (1024 ** 2) + +def cgroup_memory_peak_mb() -> float | None: + # cgroup v2 + candidates = [ + "/sys/fs/cgroup/memory.peak", + "/sys/fs/cgroup/memory.max_usage_in_bytes", # older v1 + ] + for p in candidates: + if os.path.exists(p): + try: + with open(p, "r") as f: + return int(f.read().strip()) / (1024 ** 2) + except (OSError, ValueError): + pass + 
return None + + +def _read_smaps_rollup_kb(pid: int) -> dict[str, int]: + out = {} + path = f"/proc/{pid}/smaps_rollup" + with open(path, "r") as f: + for line in f: + if ":" not in line: + continue + k, v = line.split(":", 1) + v = v.strip().split() + if len(v) >= 2 and v[1] == "kB": + out[k] = int(v[0]) + return out + +def tree_pss_mb() -> float: + procs = [_root] + try: + procs += _root.children(recursive=True) + except psutil.Error: + pass + + total_kb = 0 + for p in procs: + try: + d = _read_smaps_rollup_kb(p.pid) + if "Pss" in d: + total_kb += d["Pss"] + else: + # fallback + total_kb += p.memory_info().rss // 1024 + except Exception: + pass + return total_kb / 1024.0 + +def tree_mem_mb(include_children: bool = True, use_uss: bool = True) -> float: + root = psutil.Process(os.getpid()) + procs = [root] + if include_children: + try: + procs += root.children(recursive=True) + except Exception: + pass + + total = 0 + for p in procs: + try: + if use_uss and hasattr(p, "memory_full_info"): + total += p.memory_full_info().uss + else: + total += p.memory_info().rss + except Exception: + pass + return total / (1024 ** 2) + +class _Sampler: + def __init__(self, interval_s: float = 0.025): + self.interval_s = interval_s + self.ts = [] + self.mbs = [] + self._stop = threading.Event() + self._t = None + self._errored = False + + def start(self): + self._t = threading.Thread(target=self._run, daemon=True) + self._t.start() + + def _run(self): + t0 = time.time() + while not self._stop.is_set(): + t = time.time() - t0 + try: + mb = tree_pss_mb() + except Exception: + self._errored = True + time.sleep(self.interval_s) + continue + self.ts.append(t) + self.mbs.append(mb) + time.sleep(self.interval_s) + + def stop(self): + self._stop.set() + if self._t is not None: + self._t.join() + + +@contextmanager +def mem_section(name: str, sample_interval_s: float = 0.2, plot: bool = True, enabled: bool = False): + if not enabled: + yield + return + + start = tree_pss_mb() + sampler = 
_Sampler(interval_s=sample_interval_s) + sampler.start() + t0 = time.time() + try: + yield + finally: + sampler.stop() + end = tree_pss_mb() + dt = time.time() - t0 + + peak = max(sampler.mbs) if sampler.mbs else end + print(f"[{name}] end={end:.2f} MB delta={end-start:+.2f} MB peak={peak:.2f} MB time={dt:.2f}s") + + if plot and sampler.mbs and sampler.ts: + import matplotlib.pyplot as plt + n = min(len(sampler.ts), len(sampler.mbs)) + if n > 1: + plt.plot(sampler.ts[:n], sampler.mbs[:n]) + plt.xlabel("time (s)") + plt.ylabel("tree RSS (MB)") + plt.tight_layout() + plt.savefig(f"{_safe_name(name)}.png", dpi=150) + plt.close() + +def _safe_name(s: str) -> str: + return re.sub(r"[^a-zA-Z0-9._-]+", "_", s).strip("_") + +## CHANGE THIS TO YOUR DESIRED CACHE FOR HF +os.environ["HF_HOME"] = "~/.cache/huggingface" + +HORIZON_DEFAULT = 10 +STEP_DEFAULT = 3.0 +EPISODE_LENGTH = 100 +CHUNK_LENGTH_ACT = 100 + +ROTATION_MATRIX = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]]) + + +# NOTE: Replaced by transform ee_pose +# def transform_actions(actions): +# if actions.shape[-1] == 3: +# actions[..., 0] *= -1 # Multiply x by -1 +# actions[..., 1] *= -1 # Multiply y by -1 +# elif actions.shape[-1] == 6: +# actions[..., 0] *= -1 # Multiply x by -1 for first set +# actions[..., 1] *= -1 # Multiply y by -1 for first set +# actions[..., 3] *= -1 # Multiply x by -1 for second set +# actions[..., 4] *= -1 # Multiply y by -1 for second set +# return actions + +PERMUTE = np.array([[0,0,1], [1,0,0], [0,1,0]]) + +def SE3_permute_rot(T: np.ndarray) -> np.ndarray: + """ + Permute the rotation matrix of a SE(3) transformation. + """ + rot = T[:3, :3] + rot = rot @ PERMUTE + T[:3, :3] = rot + return T + +def timestamp_ms_to_episode_hash(timestamp_ms: int) -> str: + """ + Convert UTC epoch milliseconds -> string like "2026-01-12-03-47-29-664000". 
+ (microseconds are always 6 digits; last 3 digits will be 000 because input is ms) + """ + if not isinstance(timestamp_ms, int): + raise TypeError("timestamp_ms must be an int (UTC epoch milliseconds).") + + dt = datetime.fromtimestamp(timestamp_ms / 1000.0, tz=timezone.utc) + return dt.strftime("%Y-%m-%d-%H-%M-%S-%f") + + + +def pose_tx_ty_tz_qx_qy_qz_qw_to_SE3(pose): + """ + pose: iterable [tx, ty, tz, qx, qy, qz, qw] (quat is x,y,z,w) + returns: (4,4) SE(3) homogeneous transform + """ + tx, ty, tz, qx, qy, qz, qw = map(float, pose) + + rot = R.from_quat([qx, qy, qz, qw]) # scipy expects [x, y, z, w] + T = np.eye(4, dtype=np.float64) + T[:3, :3] = rot.as_matrix() + T[:3, 3] = [tx, ty, tz] + return T + + +def downsample_hwc_uint8_in_chunks( + images: np.ndarray, # (T,H,W,3) uint8 + out_hw=(240, 320), + chunk: int = 256, +) -> np.ndarray: + assert images.dtype == np.uint8 and images.ndim == 4 and images.shape[-1] == 3 + T, H, W, C = images.shape + outH, outW = out_hw + + out = np.empty((T, outH, outW, 3), dtype=np.uint8) + + for s in range(0, T, chunk): + e = min(s + chunk, T) + x = torch.from_numpy(images[s:e]).permute(0, 3, 1, 2).to(torch.float32) / 255.0 # (B,3,H,W) + x = F.interpolate(x, size=(outH, outW), mode="bilinear", align_corners=False) + x = (x * 255.0).clamp(0, 255).to(torch.uint8) # (B,3,outH,outW) + out[s:e] = x.permute(0, 2, 3, 1).cpu().numpy() + del x + + return out + +def compute_camera_relative_pose(pose, cam_t_inv, cam_offset): + """ + pose (6,) : np.array + x y z y p r + cam_t_inv (4, 4) : np.array + camera intrinsics inverse of timestep t + cam_offset (4, 4) : np.array + camera intrinsics of offset + + returns pose_t (6,) : np.array + future pose in camera t frame x y z y p r + """ + T_offset_pose = pose_to_transform(pose) + undo_rotation = np.eye(4) + undo_rotation[:3, :3] = ROTATION_MATRIX + + T_unrotated = undo_rotation @ T_offset_pose + T_world = np.dot(cam_offset, T_unrotated) + T_camera = np.dot(cam_t_inv, T_world) + + redo_rotation = 
np.eye(4) + redo_rotation[:3, :3] = ROTATION_MATRIX.T + T_final = redo_rotation @ T_camera + + pose_t = transform_to_pose(T_final) + return pose_t + + + +def quat_translation_swap(quat_translation: np.ndarray) -> np.ndarray: + """ + Swap the quaternion and translation in a (N, 7) array. + Parameters + ---------- + quat_translation : np.ndarray + (N, 7) array of quaternion and translation + Returns + ------- + np.ndarray: + (N, 7) array of translation and quaternion + """ + return np.concatenate((quat_translation[..., 4:7], quat_translation[..., 0:4]), axis=-1) + +def get_hand_pose_in_camera_frame(hand_data, cam_t_inv, cam_offset, transform): + """ + Process a single hand's data to compute the 6-dof pose in the camera-t frame. + + Args: + hand_data: hand data from mps: + - palm_position_device + - wrist_position_device + - wrist_and_palm_normal_device.palm_normal_device + cam_t_inv (np.ndarray): Inverse transformation matrix for the camera at timestep t. + cam_offset (np.ndarray): Transformation matrix for the camera offset. + transform: The transform used in transform_coordinates. + + Returns: + np.ndarray: 6-dof pose (translation + Euler angles) in the camera-t frame. + Returns np.full(6, 1e9) if the palm position is not detected. 
+ """ + if hand_data is None or not np.any(hand_data.get_palm_position_device()): + return np.full(6, 1e9) + + palm_pose = hand_data.get_palm_position_device() + wrist_pose = hand_data.get_wrist_position_device() + palm_normal = hand_data.wrist_and_palm_normal_device.palm_normal_device + + if hand_data.confidence < 0: + pose_offset = np.full(6, 1e9) + return pose_offset + + x_axis, y_axis, z_axis = compute_coordinate_frame( + palm_pose=palm_pose, wrist_pose=wrist_pose, palm_normal=palm_normal + ) + + palm_pose, x, y, z = transform_coordinates( + palm_pose=palm_pose, + x_axis=x_axis, + y_axis=y_axis, + z_axis=z_axis, + transform=transform, + ) + + palm_euler = coordinate_frame_to_ypr(x, y, z) + pose_offset = np.concatenate((palm_pose, palm_euler), axis=None) + + pose_offset_in_camera_t = compute_camera_relative_pose( + pose_offset, cam_t_inv=cam_t_inv, cam_offset=cam_offset + ) + return pose_offset_in_camera_t + + +class AriaVRSExtractor: + TAGS = ["aria", "robotics", "vrs"] + + @staticmethod + def process_episode(episode_path, arm: str, low_res=False, benchmark=False): + f""" + Extracts all feature keys from a given episode and returns as a dictionary + Parameters + ---------- + episode_path : str or Path + Path to the VRS file containing the episode data. + arm : str + String for which arm to add data for + Returns + ------- + episode_feats : dict + Dictionary mapping keys in the episode to episode features, for example: + hand. : (world frame) (6D per arm) + hand. : (world frame) (3 cartesian + 4 quaternion + 63 dim (21 keypoints) per arm) + images. 
: + head_pose : (world frame) + + #TODO: Add metadata to be a nested dict + + """ + episode_feats = dict() + + # file setup and opening + filename = episode_path.name + root_dir = episode_path.parent + + mps_sample_path = os.path.join(root_dir, ("mps_" + episode_path.stem + "_vrs")) + + hand_tracking_results_path = os.path.join( + mps_sample_path, "hand_tracking", "hand_tracking_results.csv" + ) + + closed_loop_pose_path = os.path.join( + mps_sample_path, "slam", "closed_loop_trajectory.csv" + ) + + + vrs_reader = data_provider.create_vrs_data_provider(str(episode_path)) + + hand_tracking_results = mps.hand_tracking.read_hand_tracking_results( + hand_tracking_results_path + ) + + closed_loop_traj = mps.read_closed_loop_trajectory(closed_loop_pose_path) + + device_calibration = vrs_reader.get_device_calibration() + + time_domain: TimeDomain = TimeDomain.DEVICE_TIME + time_query_closest: TimeQueryOptions = TimeQueryOptions.CLOSEST + + stream_ids: Dict[str, StreamId] = { + "rgb": StreamId("214-1"), + "slam-left": StreamId("1201-1"), + "slam-right": StreamId("1201-2"), + } + stream_labels: Dict[str, str] = { + key: vrs_reader.get_label_from_stream_id(stream_id) + for key, stream_id in stream_ids.items() + } + stream_timestamps_ns: Dict[str, List[int]] = { + key: vrs_reader.get_timestamps_ns(stream_id, time_domain) + for key, stream_id in stream_ids.items() + } + + mps_data_paths_provider = mps.MpsDataPathsProvider(mps_sample_path) + mps_data_paths = mps_data_paths_provider.get_data_paths() + mps_reader = mps.MpsDataProvider(mps_data_paths) + + rgb_to_device_T = slam_to_rgb(vrs_reader) # aria sophus SE3 + + # ee_pose + # TODO: this will be useful for the future - when we add rotation and other state keys + # TODO: understand what this is for (Elmo) + state_key = AriaVRSExtractor.get_state("ee_pose")[0] + + hand_cartesian_pose = AriaVRSExtractor.get_ee_pose( + world_device_T=closed_loop_traj, + stream_timestamps_ns=stream_timestamps_ns, + 
hand_tracking_results=hand_tracking_results, + arm=arm, + ) + + hand_keypoints_pose = AriaVRSExtractor.get_hand_keypoints( + world_device_T=closed_loop_traj, + stream_timestamps_ns=stream_timestamps_ns, + hand_tracking_results=hand_tracking_results, + arm=arm, + ) + + head_pose = AriaVRSExtractor.get_head_pose( + world_device_T=closed_loop_traj, + device_rgb_T=rgb_to_device_T.inverse(), + stream_timestamps_ns=stream_timestamps_ns, + ) + + # rgb_camera + # TODO: this will be useful for the future - when we add other camera modalities + camera_key = AriaVRSExtractor.get_cameras("front_img_1")[0] + + images = AriaVRSExtractor.get_images( + vrs_reader=vrs_reader, + stream_ids=stream_ids, + stream_timestamps_ns=stream_timestamps_ns, + benchmark=benchmark, + ) + + + if low_res: + images = downsample_hwc_uint8_in_chunks(images, out_hw=(240, 320), chunk=256) + + # with mem_section("process_episode.torch_from_numpy_permute", sample_interval_s=0.1, plot=False): + # images = torch.from_numpy(images).permute(0, 3, 1, 2).float() + + # if low_res: + # with mem_section("process_episode.interpolate", sample_interval_s=0.1, plot=False): + # images = F.interpolate( + # images, size=(240, 320), mode="bilinear", align_corners=False + # ) + + # with mem_section("process_episode.byte_numpy", sample_interval_s=0.1, plot=False): + # images = images.byte().numpy() + + + print(f"[DEBUG] LENGTH BEFORE CLEANING: {len(hand_cartesian_pose)}") + [hand_cartesian_pose, hand_keypoints_pose, head_pose], images = AriaVRSExtractor.clean_data( + poses=[hand_cartesian_pose, hand_keypoints_pose, head_pose], images=images + ) + # actions, pose, images = AriaVRSExtractor.clean_data_projection(actions=actions, pose=pose, images=images, arm=arm) + print(f"[DEBUG] LENGTH AFTER CLEANING: {len(hand_cartesian_pose)}") + + episode_feats["left.obs_ee_pose"] = hand_cartesian_pose[..., :7] + episode_feats["right.obs_ee_pose"] = hand_cartesian_pose[..., 7:] + episode_feats["left.obs_keypoints"] = 
hand_keypoints_pose[..., 7:7 + 21*3] + episode_feats["right.obs_keypoints"] = hand_keypoints_pose[..., 7 + 21*3 + 7: 7 + 21*3 + 7 + 21*3] + episode_feats["left.obs_wrist_pose"] = hand_keypoints_pose[..., :7] + episode_feats["right.obs_wrist_pose"] = hand_keypoints_pose[..., 7 + 21*3: 7 + 21*3 + 7] + episode_feats["images.front_1"] = images + episode_feats["obs_head_pose"] = head_pose + + return episode_feats + + @staticmethod + def clean_data(poses, images): + """ + Clean data + Parameters + ---------- + actions : np.array + pose : np.array + images : np.array + Returns + ------- + actions, pose, images : tuple of np.array + cleaned data + """ + mask_poses = np.ones(len(poses[0]), dtype=bool) + for pose in poses: + bad_data_mask = np.any(pose >= 1e8, axis=1) + mask_poses = mask_poses & ~bad_data_mask + + for i in range(len(poses)): + poses[i] = poses[i][mask_poses] + clean_images = images[mask_poses] + + return poses, clean_images + + + @staticmethod + def iter_images(episode_path, chunk_length=64, height=720, width=960, focal_mult=2): + """ + Iterate over images from VRS + Parameters + ---------- + vrs_reader : VRS Data Provider + Object that reads and obtains data from VRS + stream_ids : dict + maps sensor keys to a list of ids for Aria + stream_ids=stream_ids, + stream_timestamps_ns=stream_timestamps_ns, + benchmark=benchmark, height=height, width=width, focal_mult=focal_mult + """ + vrs_reader = data_provider.create_vrs_data_provider(str(episode_path)) + stream_ids: Dict[str, StreamId] = { + "rgb": StreamId("214-1"), + "slam-left": StreamId("1201-1"), + "slam-right": StreamId("1201-2"), + } + time_domain = TimeDomain.DEVICE_TIME + time_query_closest = TimeQueryOptions.CLOSEST + stream_timestamps_ns: Dict[str, List[int]] = { + key: vrs_reader.get_timestamps_ns(stream_id, time_domain) + for key, stream_id in stream_ids.items() + } + + images = [] + frame_length = len(stream_timestamps_ns["rgb"]) + num_batches = frame_length // chunk_length + + + for t in 
range(num_batches): + batch_images = [] + for i in range(chunk_length): + query_timestamp = stream_timestamps_ns["rgb"][t * chunk_length + i] + sample_frame = vrs_reader.get_image_data_by_time_ns( + stream_ids["rgb"], + query_timestamp, + time_domain, + time_query_closest, + ) + image_t = undistort_to_linear( + vrs_reader, stream_ids, raw_image=sample_frame[0].to_numpy_array(), height=height, width=width, focal_mult=focal_mult + ) + batch_images.append(image_t) + batch_images = np.array(batch_images) + yield batch_images + + + @staticmethod + def clean_data_projection( + actions, pose, images, arm, CHUNK_LENGTH=CHUNK_LENGTH_ACT + ): + """ + Clean data + Parameters + ---------- + actions : np.array + pose : np.array + images : np.array + Returns + ------- + actions, pose, images : tuple of np.array + cleaned data + """ + actions_copy = actions.copy() + if arm == "bimanual": + actions_left = actions_copy[..., :3] + actions_right = actions_copy[..., 6:9] + actions_copy = np.concatenate((actions_left, actions_right), axis=-1) + else: + actions_copy = actions_copy[..., :3] + + ac_dim = actions_copy.shape[-1] + actions_flat = actions_copy.reshape(-1, 3) + + N, C, H, W = images.shape + + if H == 480: + intrinsics = INTRINSICS["base"] + elif H == 240: + intrinsics = INTRINSICS["base_half"] + px = cam_frame_to_cam_pixels(actions_flat, intrinsics) + px = px.reshape((-1, CHUNK_LENGTH, ac_dim)) + if ac_dim == 3: + bad_data_mask = ( + (px[:, :, 0] < 0) + | (px[:, :, 0] > (W)) + | (px[:, :, 1] < 0) + | (px[:, :, 1] > (H)) + ) + elif ac_dim == 6: + BUFFER = 0 + bad_data_mask = ( + (px[:, :, 0] < 0 - BUFFER) + | (px[:, :, 0] > (W) + BUFFER) + | (px[:, :, 1] < 0) + # | (px[:, :, 1] > 480 + BUFFER) + | (px[:, :, 3] < 0 - BUFFER) + | (px[:, :, 3] > (H) + BUFFER) + | (px[:, :, 4] < 0) + # | (px[:, :, 4] > 480 + BUFFER) + ) + + px_diff = np.diff(px, axis=1) + px_diff = np.concatenate( + (px_diff, np.zeros((px_diff.shape[0], 1, px_diff.shape[-1]))), axis=1 + ) + px_diff = 
np.abs(px_diff) + bad_data_mask = bad_data_mask | np.any(px_diff > 100, axis=2) + + bad_data_mask = np.any(bad_data_mask, axis=1) + + actions = actions[~bad_data_mask] + images = images[~bad_data_mask] + pose = pose[~bad_data_mask] + + return actions, pose, images + + @staticmethod + def get_images( + vrs_reader, + stream_ids: dict, + stream_timestamps_ns: dict, + benchmark=False, + ): + """ + Get RGB Image from VRS + Parameters + ---------- + vrs_reader : VRS Data Provider + Object that reads and obtains data from VRS + stream_ids : dict + maps sensor keys to a list of ids for Aria + stream_timestamps_ns : dict + dict that maps sensor keys to a list of nanosecond timestamps in device time + Returns + ------- + images : np.array + rgb images undistorted to 480x640x3 + """ + images = [] + frame_length = len(stream_timestamps_ns["rgb"]) + + time_domain = TimeDomain.DEVICE_TIME + time_query_closest = TimeQueryOptions.CLOSEST + + for t in range(frame_length): + query_timestamp = stream_timestamps_ns["rgb"][t] + + sample_frame = vrs_reader.get_image_data_by_time_ns( + stream_ids["rgb"], + query_timestamp, + time_domain, + time_query_closest, + ) + + image_t = undistort_to_linear( + vrs_reader, stream_ids, raw_image=sample_frame[0].to_numpy_array() + ) + + + images.append(image_t) + with mem_section("get_images.list_to_numpy_array", sample_interval_s=0.1, plot=False, enabled=benchmark): + images = np.array(images) + return images + + @staticmethod + def get_hand_keypoints( + world_device_T, + stream_timestamps_ns: dict, + hand_tracking_results, + arm: str, + ): + """ + Get Hand Keypoints from VRS + Parameters + ---------- + world_device_T : np.array + Transform from world coordinates to ARIA camera frame + stream_timestamps_ns : dict + hand_tracking_results : dict + arm : str + arm to get hand keypoints for + Returns + ------- + hand_keypoints : np.array + hand_keypoints + """ + hand_keypoints = [] + frame_length = len(stream_timestamps_ns["rgb"]) + + time_domain = 
TimeDomain.DEVICE_TIME + time_query_closest = TimeQueryOptions.CLOSEST + + ee_pose = [] + + use_left_hand = (arm == "left" or arm == "bimanual") + use_right_hand = (arm == "right" or arm == "bimanual") + for t in range(frame_length): + query_timestamp = stream_timestamps_ns["rgb"][t] + hand_tracking_result_t = get_nearest_hand_tracking_result( + hand_tracking_results, query_timestamp + ) + world_device_T_t = get_nearest_pose(world_device_T, query_timestamp) + if world_device_T_t is not None: + world_device_T_t = world_device_T_t.transform_world_device + + right_confidence = getattr( + getattr(hand_tracking_result_t, "right_hand", None), "confidence", -1 + ) + left_confidence = getattr( + getattr(hand_tracking_result_t, "left_hand", None), "confidence", -1 + ) + left_obs_t = np.full(7 + 21 * 3, 1e9) + if use_left_hand and not left_confidence < 0 and world_device_T_t is not None: + left_hand_keypoints = np.stack(hand_tracking_result_t.left_hand.landmark_positions_device, axis=0) + wrist_T = hand_tracking_result_t.left_hand.transform_device_wrist # Sophus SE3 + + world_wrist_T = world_device_T_t @ wrist_T + world_keypoints = (world_device_T_t @ left_hand_keypoints.T).T # keypoints are in device frame + + world_wrist_T = sp.SE3.from_matrix(SE3_permute_rot(world_wrist_T.to_matrix())) + wrist_quat_and_translation = quat_translation_swap(world_wrist_T.to_quat_and_translation()) + if wrist_quat_and_translation.ndim == 2: + wrist_quat_and_translation = wrist_quat_and_translation[0] + left_obs_t[:7] = wrist_quat_and_translation + left_obs_t[7:] = world_keypoints.flatten() + + + right_obs_t = np.full(7 + 21 * 3, 1e9) + if use_right_hand and not right_confidence < 0 and world_device_T_t is not None: + right_hand_keypoints = np.stack(hand_tracking_result_t.right_hand.landmark_positions_device, axis=0) + wrist_T = hand_tracking_result_t.right_hand.transform_device_wrist # Sophus SE3 + + world_wrist_T = world_device_T_t @ wrist_T + world_keypoints = (world_device_T_t @ 
right_hand_keypoints.T).T # keypoints are in device frame + + world_wrist_T = sp.SE3.from_matrix(SE3_permute_rot(world_wrist_T.to_matrix())) + wrist_quat_and_translation = quat_translation_swap(world_wrist_T.to_quat_and_translation()) + if wrist_quat_and_translation.ndim == 2: + wrist_quat_and_translation = wrist_quat_and_translation[0] + right_obs_t[:7] = wrist_quat_and_translation + right_obs_t[7:] = world_keypoints.flatten() + + if use_left_hand and use_right_hand: + ee_pose_obs_t = np.concatenate((left_obs_t, right_obs_t), axis=-1) + elif use_left_hand: + ee_pose_obs_t = left_obs_t + elif use_right_hand: + ee_pose_obs_t = right_obs_t + else: + raise ValueError(f"Incorrect arm provided: {arm}") + ee_pose.append(np.ravel(ee_pose_obs_t)) + ee_pose = np.array(ee_pose) + return ee_pose + + @staticmethod + def get_head_pose( + world_device_T, + device_rgb_T, + stream_timestamps_ns: dict, + ): + """ + Get Head Pose from VRS + Parameters + ---------- + world_device_T : np.array + Transform from world coordinates to ARIA camera frame + stream_timestamps_ns : dict + dict that maps sensor keys to a list of nanosecond timestamps in device time + + Returns + ------- + head_pose : np.array + head_pose + """ + head_pose = [] + frame_length = len(stream_timestamps_ns["rgb"]) + + time_domain = TimeDomain.DEVICE_TIME + time_query_closest = TimeQueryOptions.CLOSEST + rgb_to_rgbprime_rot = np.eye(4) + rgb_to_rgbprime_rot[:3, :3] = ROTATION_MATRIX.T + rgb_to_rgbprime_T = sp.SE3.from_matrix(rgb_to_rgbprime_rot) + rgbprime_to_rgb_T = rgb_to_rgbprime_T.inverse() + for t in range(frame_length): + query_timestamp = stream_timestamps_ns["rgb"][t] + world_device_T_t = get_nearest_pose(world_device_T, query_timestamp) + if world_device_T_t is not None: + world_device_T_t = world_device_T_t.transform_world_device + head_pose_obs_t = np.full(7, 1e9) + if world_device_T_t is not None: + world_rgb_T_t = world_device_T_t @ device_rgb_T @ rgbprime_to_rgb_T + head_pose_quat_and_translation = 
quat_translation_swap(world_rgb_T_t.to_quat_and_translation()) + if head_pose_quat_and_translation.ndim == 2: + head_pose_quat_and_translation = head_pose_quat_and_translation[0] + head_pose_obs_t[:7] = head_pose_quat_and_translation + head_pose.append(np.ravel(head_pose_obs_t)) + head_pose = np.array(head_pose) + return head_pose + + @staticmethod + def get_ee_pose( + world_device_T, + stream_timestamps_ns: dict, + hand_tracking_results, + arm: str, + ): + """ + Get EE Pose from VRS + Parameters + ---------- + world_device_T : np.array + Transform from world coordinates to ARIA camera frame + stream_timestamps_ns : dict + dict that maps sensor keys to a list of nanosecond timestamps in device time + hand_tracking_results : dict + dict that maps sensor keys to a list of hand tracking results + arm : str + arm to get hand keypoints for + Returns + ------- + ee_pose : np.array + ee_pose (6D per arm) + -1 if no hand tracking data is available + """ + ee_pose = [] + frame_length = len(stream_timestamps_ns["rgb"]) + + time_domain = TimeDomain.DEVICE_TIME + time_query_closest = TimeQueryOptions.CLOSEST + + use_left_hand = (arm == "left" or arm == "bimanual") + use_right_hand = (arm == "right" or arm == "bimanual") + + + + for t in range(frame_length): + query_timestamp = stream_timestamps_ns["rgb"][t] + hand_tracking_result_t = get_nearest_hand_tracking_result( + hand_tracking_results, query_timestamp + ) + world_device_T_t = get_nearest_pose(world_device_T, query_timestamp) + if world_device_T_t is not None: + world_device_T_t = world_device_T_t.transform_world_device + device_world_T_t = world_device_T_t.inverse() + + right_confidence = getattr( + getattr(hand_tracking_result_t, "right_hand", None), "confidence", -1 + ) + left_confidence = getattr( + getattr(hand_tracking_result_t, "left_hand", None), "confidence", -1 + ) + + + + left_obs_t = np.full(7, 1e9) + if use_left_hand and not left_confidence < 0 and world_device_T_t is not None: + left_palm_pose = ( + 
hand_tracking_result_t.left_hand.get_palm_position_device() + ) + left_wrist_pose = ( + hand_tracking_result_t.left_hand.get_wrist_position_device() + ) + left_palm_normal = hand_tracking_result_t.left_hand.wrist_and_palm_normal_device.palm_normal_device + + left_rot_matrix = compute_orientation_rotation_matrix( + palm_pose=left_palm_pose, + wrist_pose=left_wrist_pose, + palm_normal=left_palm_normal, + ) + left_T_t = np.eye(4) + left_T_t[:3, :3] = left_rot_matrix + left_T_t[:3, 3] = left_palm_pose + left_T_t = sp.SE3.from_matrix(left_T_t) + left_T_t = world_device_T_t @ left_T_t + left_T_t = sp.SE3.from_matrix(SE3_permute_rot(left_T_t.to_matrix())) + + left_quat_and_translation = quat_translation_swap(left_T_t.to_quat_and_translation()) + if left_quat_and_translation.ndim == 2: + left_quat_and_translation = left_quat_and_translation[0] + left_obs_t[:7] = left_quat_and_translation + + right_obs_t = np.full(7, 1e9) + if use_right_hand and not right_confidence < 0 and world_device_T_t is not None: + right_palm_pose = ( + hand_tracking_result_t.right_hand.get_palm_position_device() + ) + right_wrist_pose = ( + hand_tracking_result_t.right_hand.get_wrist_position_device() + ) + right_palm_normal = hand_tracking_result_t.right_hand.wrist_and_palm_normal_device.palm_normal_device + + right_rot_matrix = compute_orientation_rotation_matrix( + palm_pose=right_palm_pose, + wrist_pose=right_wrist_pose, + palm_normal=right_palm_normal, + ) + right_T_t = np.eye(4) + right_T_t[:3, :3] = right_rot_matrix + right_T_t[:3, 3] = right_palm_pose + right_T_t = sp.SE3.from_matrix(right_T_t) + right_T_t = world_device_T_t @ right_T_t + right_T_t = sp.SE3.from_matrix(SE3_permute_rot(right_T_t.to_matrix())) + right_quat_and_translation = quat_translation_swap(right_T_t.to_quat_and_translation()) + if right_quat_and_translation.ndim == 2: + right_quat_and_translation = right_quat_and_translation[0] + right_obs_t[:7] = right_quat_and_translation + + if use_left_hand and use_right_hand: + 
ee_pose_obs_t = np.concatenate((left_obs_t, right_obs_t), axis=-1) + elif use_left_hand: + ee_pose_obs_t = left_obs_t + elif use_right_hand: + ee_pose_obs_t = right_obs_t + else: + raise ValueError(f"Incorrect arm provided: {arm}") + ee_pose.append(np.ravel(ee_pose_obs_t)) + ee_pose = np.array(ee_pose) + return ee_pose + + @staticmethod + def get_cameras(rgb_camera_key: str): + """ + Returns a list of rgb keys + Parameters + ---------- + rgb_camera_key : str + + Returns + ------- + rgb_cameras : list of str + A list of keys corresponding to rgb_cameras in the dataset. + """ + + rgb_cameras = [rgb_camera_key] + return rgb_cameras + + @staticmethod + def get_state(state_key: str): + """ + Returns a list of state keys + Parameters + ---------- + state_key : str + + Returns + ------- + states : list of str + A list of keys corresponding to states in the dataset. + """ + + states = [state_key] + return states + + @staticmethod + def iter_episode_frames( + episode_path: str | Path, + features: dict[str, dict], + image_compressed: bool, + arm: str, + prestack: bool = False, + benchmark: bool = False, + ): + episode_feats = AriaVRSExtractor.process_episode( + episode_path, arm=arm, benchmark=benchmark + ) + + episode_name = episode_path.name + # check if episode is timestamped + if not "-" in episode_name: + episode_name = timestamp_ms_to_episode_hash(int(episode_name)) + + try: + num_frames = next(iter(episode_feats["head_pose"].values())).shape[0] + + for frame_idx in range(num_frames): + frame = {} + + for feature_id, _info in features.items(): + if feature_id.startswith("observations."): + key = feature_id.split(".", 1)[-1] # "images.front_img_1" / "state.ee_pose" + value = episode_feats["observations"].get(key, None) + else: + value = episode_feats.get(feature_id, None) + + if value is None: + frame = None + break + + if isinstance(value, np.ndarray): + if "images" in feature_id: + if image_compressed: + img = cv2.imdecode(value[frame_idx], 1) # HWC BGR uint8 + 
frame[feature_id] = ( + torch.from_numpy(img) + .permute(2, 0, 1) + .contiguous() + ) # CHW uint8 + else: + frame[feature_id] = ( + torch.from_numpy(value[frame_idx]) + .permute(2, 0, 1) + .contiguous() + ) # HWC -> CHW + else: + frame[feature_id] = torch.from_numpy(value[frame_idx]) + elif isinstance(value, torch.Tensor): + frame[feature_id] = value[frame_idx] + else: + frame = None + break + + if frame is not None: + yield frame + finally: + del episode_feats + + + @staticmethod + def define_features( + episode_feats: dict, image_compressed: bool = True, encode_as_video: bool = True + ) -> tuple: + """ + Define features from episode_feats (output of process_episode), including a metadata section. + + Parameters + ---------- + episode_feats : dict + The output of the process_episode method, containing feature data. + image_compressed : bool, optional + Whether the images are compressed, by default True. + encode_as_video : bool, optional + Whether to encode images as video or as images, by default True. 
+ + Returns + ------- + tuple of dict[str, dict] + A dictionary where keys are feature names and values are dictionaries + containing feature information such as dtype, shape, and dimension names, + and a separate dictionary for metadata (unused for now) + """ + features = {} + metadata = {} + for key, value in episode_feats.items(): + if isinstance(value, dict): # Handle nested dictionaries recursively + nested_features, nested_metadata = AriaVRSExtractor.define_features( + value, image_compressed, encode_as_video + ) + features.update( + { + f"{key}.{nested_key}": nested_value + for nested_key, nested_value in nested_features.items() + } + ) + features.update( + { + f"{key}.{nested_key}": nested_value + for nested_key, nested_value in nested_metadata.items() + } + ) + elif isinstance(value, np.ndarray): + dtype = str(value.dtype) + if "images" in key: + dtype = "video" if encode_as_video else "image" + if image_compressed: + decompressed_sample = cv2.imdecode(value[0], 1) + shape = ( + decompressed_sample.shape[1], + decompressed_sample.shape[0], + decompressed_sample.shape[2], + ) + else: + shape = value.shape[1:] # Skip the frame count dimension + dim_names = ["channel", "height", "width"] + elif "actions" in key and len(value[0].shape) > 1: + shape = value[0].shape + dim_names = ["chunk_length", "action_dim"] + dtype = f"prestacked_{str(value.dtype)}" + else: + shape = value[0].shape + dim_names = [f"dim_{i}" for i in range(len(shape))] + features[key] = { + "dtype": dtype, + "shape": shape, + "names": dim_names, + } + elif isinstance(value, torch.Tensor): + dtype = str(value.dtype) + if "actions" in key and len(tuple(value[0].size())) > 1: + dim_names = ["chunk_length", "action_dim"] + dtype = f"prestacked_{str(value.dtype)}" + else: + dim_names = [f"dim_{i}" for i in range(len(shape))] + shape = tuple(value[0].size()) + dim_names = [f"dim_{i}" for i in range(len(shape))] + features[key] = { + "dtype": dtype, + "shape": shape, + "names": dim_names, + } + 
class DatasetConverter:
    """
    Convert Aria VRS recordings into Zarr episodes plus preview MP4s.

    Parameters
    ----------
    raw_path : Path or str
        Directory containing the *.vrs recordings and their mps_* folders.
    fps : int
        Frames per second for the dataset.
    arm : str, optional
        'left', 'right', or 'bimanual'; selects the embodiment tag.
    encode_as_videos : bool, optional
        Whether to encode images as videos, by default True.
    image_compressed : bool, optional
        Whether the images are compressed, by default True.
    image_writer_processes : int, optional
        Number of processes for writing images, by default 0.
    image_writer_threads : int, optional
        Number of threads for writing images, by default 0.
    debug : bool, optional
        Keep only the first two episodes, by default False.
    benchmark : bool, optional
        Enable per-section RAM benchmarking, by default False.

    Methods
    -------
    extract_episode(episode_path, task_description='', ...)
        Extracts frames from a single episode and writes Zarr + MP4.
    extract_episodes(episode_description='', ...)
        Iterates over episodes, skipping failures.
    """

    # maps the arm selection to the embodiment tag stored with the dataset
    _EMBODIMENTS = {
        "bimanual": "aria_bimanual",
        "right": "aria_right_arm",
        "left": "aria_left_arm",
    }

    def __init__(
        self,
        raw_path: Path | str,
        fps: int,
        arm: str = "",
        encode_as_videos: bool = True,
        image_compressed: bool = True,
        image_writer_processes: int = 0,
        image_writer_threads: int = 0,
        debug: bool = False,
        benchmark: bool = False,
    ):
        self.raw_path = raw_path if isinstance(raw_path, Path) else Path(raw_path)
        self.fps = fps
        self.arm = arm
        self.image_compressed = image_compressed
        self.image_writer_threads = image_writer_threads
        self.image_writer_processes = image_writer_processes
        self.encode_as_videos = encode_as_videos
        self.benchmark = benchmark
        if self.benchmark:
            print("Benchmark mode enabled. This will plot the RAM usage of each section.")

        self.logger = logging.getLogger(self.__class__.__name__)
        self.logger.setLevel(logging.INFO)

        # BUG FIX: only attach a console handler once -- constructing the
        # converter repeatedly used to stack handlers and duplicate every
        # log line.
        if not self.logger.handlers:
            console_handler = logging.StreamHandler()
            console_handler.setLevel(logging.INFO)
            formatter = logging.Formatter("%(asctime)s - [%(name)s] - %(message)s")
            console_handler.setFormatter(formatter)
            self.logger.addHandler(console_handler)

        self.logger.info(f"{'-' * 10} Aria VRS -> Lerobot Converter {'-' * 10}")
        self.logger.info(f"Processing Aria VRS dataset from {self.raw_path}")
        self.logger.info(f"FPS: {self.fps}")
        self.logger.info(f"Arm: {self.arm}")
        self.logger.info(f"Image compressed: {self.image_compressed}")
        self.logger.info(f"Encoding images as videos: {self.encode_as_videos}")
        self.logger.info(f"#writer processes: {self.image_writer_processes}")
        self.logger.info(f"#writer threads: {self.image_writer_threads}")

        self._mp4_path = None  # set from main() if --save-mp4
        self._mp4_writer = None  # lazy-initialized in extract_episode()
        self.episode_list = list(self.raw_path.glob("*.vrs"))

        self.feats_to_zarr_keys = {}

        if debug:
            self.episode_list = self.episode_list[:2]

        # BUG FIX: unknown arm values used to leave `self.embodiment` unset,
        # deferring the failure to an AttributeError much later.
        try:
            self.embodiment = self._EMBODIMENTS[self.arm]
        except KeyError:
            raise ValueError(f"Incorrect arm provided: {self.arm}") from None
self.arm == "left": + self.embodiment = "aria_left_arm" + + def save_preview_mp4(self, image_frames: list[dict], output_path: Path, fps: int, image_compressed: bool): + """ + Save a single half-resolution, web-compatible MP4 using H.264 (libx264). + No fallbacks. Requires `ffmpeg` with libx264 on PATH. + + Each frame dict must contain: + 'observations.images.front_img_1' -> torch.Tensor (C,H,W) uint8 + """ + + imgs = image_frames + + # Compute half-res (force even dims for yuv420p) + C, H, W = imgs[0].shape + outW, outH = W // 2, H // 2 + if outW % 2: outW -= 1 + if outH % 2: outH -= 1 + if outW <= 0 or outH <= 0: + raise ValueError(f"[MP4] Invalid output size: {outW}x{outH}") + + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + rgb_frames = [] + for chw in imgs: + # chw: (C,H,W) uint8, BGR from cv2.imdecode earlier + t = chw.detach().cpu() + if t.dtype != torch.uint8: + t = t.to(torch.uint8) + + # If grayscale, repeat to 3 channels + if t.shape[0] == 1: + t = t.repeat(3, 1, 1) + + # Resize to (outH, outW) + t_resized = F.interpolate( + t.unsqueeze(0), # (1,C,H,W) + size=(outH, outW), + mode="bilinear", + align_corners=False, + ).squeeze(0) # (C,outH,outW) + + # BGR -> RGB, then (H,W,C) + hwc = t_resized.permute(1, 2, 0).contiguous() # (H,W,3), uint8 + rgb_frames.append(hwc) + + video_tensor = torch.stack(rgb_frames, dim=0) # (T, H, W, 3) uint8 + + # ----------------------------- + # 1) Try torchvision.write_video + # ----------------------------- + try: + from torchvision.io import write_video + + write_video( + filename=str(output_path), + video_array=video_tensor, + fps=float(fps), + video_codec="libx264", # H.264, web-compatible + options={"crf": "23", "preset": "veryfast"}, + ) + print( + f"[MP4] Saved web-compatible H.264 preview via torchvision to {output_path}" + ) + return + except Exception as e: + print( + f"[MP4] torchvision.io.write_video failed ({e}); trying ffmpeg CLI fallback..." 
+ ) + + # ----------------------------- + # 2) Fallback: ffmpeg CLI (libx264) + # ----------------------------- + ffmpeg = shutil.which("ffmpeg") + if ffmpeg is None: + raise RuntimeError( + "[MP4] Could not write web-compatible MP4:\n" + " - torchvision.io.write_video is unavailable or failed\n" + " - `ffmpeg` CLI not found on PATH\n" + "Install either torchvision with video support or ffmpeg+libx264." + ) + + # For ffmpeg rawvideo, we need BGR24 frames of shape (outH, outW, 3) + # We can convert our RGB hwc tensors back to BGR numpy. + cmd = [ + ffmpeg, + "-y", + "-f", + "rawvideo", + "-vcodec", + "rawvideo", + "-pix_fmt", + "bgr24", + "-s", + f"{outW}x{outH}", + "-r", + str(fps), + "-i", + "-", # stdin + "-an", + "-c:v", + "libx264", + "-pix_fmt", + "yuv420p", + "-profile:v", + "baseline", + "-level", + "3.0", + "-movflags", + "+faststart", + "-preset", + "veryfast", + "-crf", + "23", + str(output_path), + ] + + proc = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + try: + for hwc_rgb in rgb_frames: + # hwc_rgb: (H,W,3), RGB uint8 + np_rgb = hwc_rgb.numpy() + # RGB -> BGR + np_bgr = np_rgb[..., ::-1] + proc.stdin.write(np_bgr.tobytes()) + finally: + if proc.stdin: + proc.stdin.flush() + proc.stdin.close() + + ret = proc.wait() + if ret != 0: + stderr = proc.stderr.read().decode(errors="ignore") if proc.stderr else "" + raise RuntimeError( + f"[MP4] ffmpeg/libx264 encoding failed (exit {ret}).\n{stderr}" + ) + + print( + f"[MP4] Saved web-compatible H.264 preview via ffmpeg CLI to {output_path}" + ) + + def extract_episode_iterative(self, episode_path, task_description: str = ""): + """ + TODO: Implement the iterative approach to save memory. + Extracts frames from an episode and saves them to the dataset. + Parameters + ---------- + episode_path : str + The path to the episode file. + task_description : str, optional + A description of the task associated with the episode (default is an empty string). 
+ Returns + ------- + None + """ + writer = ZarrWriter( + episode_path=episode_path, + fps=self.fps, + embodiment=self.embodiment, + enable_sharding=False, + task="", + ) + with writer.write_incremental(total_frames=total_frames) as inc: + image_frames = [] + for i, frame in enumerate(AriaVRSExtractor.iter_episode_frames(episode_path, self.features, self.image_compressed, self.arm, self.prestack, self.benchmark)): + self.buffer.append(frame) + if self._mp4_path is not None: + image = frame["observations.images.front_img_1"] + image_frames.append(image) + + if len(self.buffer) == EPISODE_LENGTH: + + for f in self.buffer: + self.dataset.add_frame(f) + + + self.logger.info(f"Saving Episode after {i + 1} frames...") + self.dataset.save_episode(task=task_description) + self.buffer.clear() + if self._mp4_path is not None: + ep_stem = Path(episode_path).stem + mp4_path = self._mp4_path / f"{ep_stem}_video.mp4" + self.save_preview_mp4(image_frames, mp4_path, self.fps, self.image_compressed) + + def extract_episode(self, episode_path, task_description: str = "", output_dir: Path = Path("."), dataset_name: str = ""): + """ + Extracts frames from an episode and saves them to the dataset. + Parameters + ---------- + episode_path : str + The path to the episode file. + task_description : str, optional + A description of the task associated with the episode (default is an empty string). 
+ Returns + ------- + None + """ + episode_name = dataset_name + + episode_feats = AriaVRSExtractor.process_episode( + episode_path=episode_path, + arm=self.arm, + benchmark=self.benchmark, + ) + numeric_data = {} + + image_data = {} + for key, value in episode_feats.items(): + if "images" in key: + if key in self.feats_to_zarr_keys: + image_data[self.feats_to_zarr_keys[key]] = value + else: + image_data[key] = value + else: + if key in self.feats_to_zarr_keys: + numeric_data[self.feats_to_zarr_keys[key]] = value + else: + numeric_data[key] = value + zarr_path = ZarrWriter.create_and_write( + episode_path=output_dir / f"{episode_name}.zarr", + numeric_data=numeric_data if numeric_data else None, + image_data=image_data if image_data else None, + fps=self.fps, + embodiment=self.embodiment, + enable_sharding=False, + task="", + ) + mp4_path = output_dir / f"{episode_name}.mp4" + W, H = 960, 720 + p = start_ffmpeg_mp4(mp4_path, W, H, fps=30, pix_fmt="rgb24") + for video_images in AriaVRSExtractor.iter_images(episode_path, chunk_length=256, height=H, width=W, focal_mult=3): + for image in video_images: + image = prep_frame(image, H, W) + if image is None: + continue + p.stdin.write(image.tobytes()) + p.stdin.close() + p.wait() + return zarr_path, mp4_path + + + def extract_episodes(self, episode_description: str = "", output_dir: Path = Path("."), dataset_name: str = ""): + """ + Extracts episodes from the episode list and processes them. + Parameters + ---------- + episode_description : str, optional + A description of the task to be passed to the extract_episode method (default is ''). + Raises + ------ + Exception + If an error occurs during the processing of an episode, it will be caught and printed. + Notes + ----- + After processing all episodes, the dataset is consolidated. 
+ """ + + os.makedirs(output_dir, exist_ok=True) + with mem_section("extract_episodes", enabled=self.benchmark): + for episode_path in self.episode_list: + try: + return self.extract_episode(episode_path, task_description=episode_description, output_dir=output_dir, dataset_name=dataset_name) + except Exception as e: + self.logger.error(f"Error processing episode {episode_path}: {e}") + traceback.print_exc() + continue + + return None + +def argument_parse(): + parser = argparse.ArgumentParser( + description="Convert Aria VRS dataset to LeRobot-Robomimic hybrid and push to Hugging Face hub." + ) + + # Required arguments + parser.add_argument("--dataset-name", type=str, required=True, help="Name for dataset") + parser.add_argument( + "--raw-path", + type=Path, + required=True, + help="Directory containing the vrs, vrs_json, and the processed mps folder.", + ) + parser.add_argument( + "--fps", type=int, required=True, help="Frames per second for the dataset." + ) + # Optional arguments + parser.add_argument( + "--description", + type=str, + default="Aria recorded dataset.", + help="Description of the dataset.", + ) + parser.add_argument( + "--arm", + type=str, + choices=["left", "right", "bimanual"], + default="bimanual", + help="Specify the arm for processing.", + ) + parser.add_argument( + "--image-compressed", + type=str2bool, + default=False, + help="Set to True if the images are compressed.", + ) + parser.add_argument( + "--video-encoding", + type=str2bool, + default=False, + help="Set to True to encode images as videos.", + ) + + # Performance tuning arguments + parser.add_argument( + "--nproc", type=int, default=8, help="Number of image writer processes." + ) + parser.add_argument( + "--nthreads", type=int, default=2, help="Number of image writer threads." + ) + + # Debugging and output configuration + parser.add_argument( + "--output-dir", + type=Path, + default=Path(LEROBOT_HOME), + help="Directory where the processed dataset will be stored. 
Defaults to LEROBOT_HOME.", + ) + parser.add_argument( + "--debug", action="store_true", help="Store only 2 episodes for debug purposes." + ) + + parser.add_argument( + "--save-mp4", + type=str2bool, + default=True, + help="If True, save a single half-resolution MP4 with all frames across episodes.", + ) + + parser.add_argument( + "--benchmark", + action="store_true", + help="Run benchmark mode. Which include printing out the peak RAM usage of each section.", + ) + + args = parser.parse_args() + + return args + + +def main(args): + """ + Convert ARIA VRS files and push to Hugging Face hub. + + Parameters + ---------- + args : argparse.Namespace + Parsed command-line arguments. + """ + print( + args.video_encoding, + "-------------------------------------------------------------------------------------------------------", + ) + + # Initialize the dataset converter + converter = DatasetConverter( + raw_path=args.raw_path, + fps=args.fps, + arm=args.arm, + image_compressed=args.image_compressed, + encode_as_videos=args.video_encoding, + image_writer_processes=args.nproc, + image_writer_threads=args.nthreads, + debug=args.debug, + benchmark=args.benchmark, + ) + + gc.collect() + ctypes.CDLL("libc.so.6").malloc_trim(0) + # Extract episodes + return converter.extract_episode(episode_description=args.description, output_dir=args.output_dir, dataset_name=args.dataset_name) + +if __name__ == "__main__": + args = argument_parse() + main(args) diff --git a/egomimic/scripts/aria_process/aria_utils.py b/egomimic/scripts/aria_process/aria_utils.py index 599e9a47..a66b97f6 100644 --- a/egomimic/scripts/aria_process/aria_utils.py +++ b/egomimic/scripts/aria_process/aria_utils.py @@ -27,11 +27,11 @@ def build_camera_matrix(provider, pose_t): return T_world_rgb_camera -def undistort_to_linear(provider, stream_ids, raw_image, camera_label="rgb"): +def undistort_to_linear(provider, stream_ids, raw_image, camera_label="rgb", height=480, width=640, focal_mult=2): camera_label = 
provider.get_label_from_stream_id(stream_ids[camera_label]) calib = provider.get_device_calibration().get_camera_calib(camera_label) warped = calibration.get_linear_camera_calibration( - 480, 640, 133.25430222 * 2, camera_label, calib.get_transform_device_camera() + height, width, 133.25430222 * focal_mult, camera_label, calib.get_transform_device_camera() ) warped_image = calibration.distort_by_calibration(raw_image, warped, calib) warped_rot = np.rot90(warped_image, k=3) @@ -106,8 +106,7 @@ def slam_to_rgb(provider): return transform - -def compute_coordinate_frame(palm_pose, wrist_pose, palm_normal): +def compute_orientation_rotation_matrix(palm_pose, wrist_pose, palm_normal): x_axis = wrist_pose - palm_pose x_axis = np.ravel(x_axis) / np.linalg.norm(x_axis) z_axis = np.ravel(palm_normal) / np.linalg.norm(palm_normal) @@ -117,30 +116,8 @@ def compute_coordinate_frame(palm_pose, wrist_pose, palm_normal): x_axis = np.cross(z_axis, y_axis) x_axis = np.ravel(x_axis) / np.linalg.norm(x_axis) - return -1 * x_axis, y_axis, z_axis - - -def transform_coordinates(palm_pose, x_axis, y_axis, z_axis, transform): - palm_pose_h = np.append(palm_pose, 1) - x_axis_h = np.append(x_axis, 0) - y_axis_h = np.append(y_axis, 0) - z_axis_h = np.append(z_axis, 0) - - # Apply SLAM-to-RGB transformation - transformed_palm_pose = (transform @ palm_pose_h)[:3] - transformed_x_axis = (transform @ x_axis_h)[:3] - transformed_y_axis = (transform @ y_axis_h)[:3] - transformed_z_axis = (transform @ z_axis_h)[:3] - - # Apply additional rotation transpose - rot_T = ROTATION_MATRIX.T # Compute the transpose - final_palm_pose = rot_T @ transformed_palm_pose - final_x_axis = rot_T @ transformed_x_axis - final_y_axis = rot_T @ transformed_y_axis - final_z_axis = rot_T @ transformed_z_axis - - return final_palm_pose, final_x_axis, final_y_axis, final_z_axis - + rot_matrix = np.column_stack([-1 * x_axis, y_axis, z_axis]) + return rot_matrix def coordinate_frame_to_ypr(x_axis, y_axis, z_axis): rot_matrix 
= np.column_stack([x_axis, y_axis, z_axis]) diff --git a/egomimic/scripts/aria_process/run_aria_conversion.py b/egomimic/scripts/aria_process/run_aria_conversion.py index c3b80f2e..05aeabc5 100644 --- a/egomimic/scripts/aria_process/run_aria_conversion.py +++ b/egomimic/scripts/aria_process/run_aria_conversion.py @@ -25,16 +25,18 @@ from datetime import datetime, timezone from pathlib import Path from typing import Any, Dict, Iterator, Tuple +from tqdm import tqdm + +import ray +from ray.exceptions import OutOfMemoryError, RayTaskError, WorkerCrashedError + +from egomimic.utils.aws.aws_data_utils import get_boto3_s3_client, load_env, s3_sync_to_local, upload_dir_to_s3, get_cloudpathlib_s3_client import boto3 import ray # --- Conversion wrapper ------------------------------------------------------ -from aria_helper import lerobot_job -from cloudpathlib import S3Path -from ray.exceptions import OutOfMemoryError, RayTaskError, WorkerCrashedError - -from egomimic.utils.aws.aws_data_utils import s3_sync_to_local, upload_dir_to_s3 +from aria_helper import zarr_job # --- SQL helpers -------------------------------------------------------------- from egomimic.utils.aws.aws_sql import ( @@ -46,16 +48,16 @@ ) # --- Paths ------------------------------------------------------------------- -RAW_REMOTE_PREFIX = os.environ.get("RAW_REMOTE_PREFIX", "s3://rldb/raw_v2/aria").rstrip( - "/" -) +RAW_REMOTE_PREFIX = os.environ.get("RAW_REMOTE_PREFIX", "s3://rldb/raw_v2/aria/").rstrip("/") PROCESSED_ROOT = Path("/home/ubuntu/processed") PROCESSED_LOCAL_ROOT = Path( os.environ.get("PROCESSED_LOCAL_ROOT", "/home/ubuntu/processed") ).resolve() PROCESSED_REMOTE_PREFIX = os.environ.get( - "PROCESSED_REMOTE_PREFIX", "s3://rldb/processed_v2/aria" -).rstrip("/") + "PROCESSED_REMOTE_PREFIX", "s3://rldb/processed_v3/aria/" +).rstrip( + "/" +) BUCKET = os.environ.get("BUCKET", "rldb") LOG_ROOT = Path( @@ -65,12 +67,12 @@ ) ).resolve() - # --- Utilities 
--------------------------------------------------------------- def ensure_path_ready(p: str | Path | S3Path, retries: int = 30) -> bool: if isinstance(p, str): if p.startswith("s3://"): - p = S3Path(p) + s3_client = get_cloudpathlib_s3_client() + p = S3Path(p, client=s3_client) else: p = Path(p) for _ in range(retries): @@ -122,7 +124,8 @@ def iter_vrs_bundles(root_s3: str) -> Iterator[Tuple[S3Path, S3Path, S3Path]]: root_s3: like "s3://rldb/raw_v2/aria/" Returns S3Path objects (cloudpathlib), not local filesystem paths. """ - root = S3Path(root_s3) + s3_client = get_cloudpathlib_s3_client() + root = S3Path(root_s3, client=s3_client) for vrs in sorted(root.glob("*.vrs"), key=lambda p: p.name): name = vrs.stem @@ -150,7 +153,8 @@ def iter_vrs_bundles_fast(root_s3: str) -> Iterator[Tuple[S3Path, S3Path, S3Path Uses a single `root.walk(...)` traversal and avoids per-path `.exists()` / `.is_dir()`. """ - root = S3Path(root_s3) + s3_client = get_cloudpathlib_s3_client() + root = S3Path(root_s3, client=s3_client) vrs_by_name: dict[str, S3Path] = {} has_json: set[str] = set() @@ -279,7 +283,6 @@ def convert_one_bundle_impl( s3_processed_dir: str, dataset_name: str, arm: str, - description: str, ) -> tuple[str, str, int]: """ Perform conversion for a single episode. 
@@ -288,11 +291,12 @@ def convert_one_bundle_impl( • mp4_path: per-episode MP4 ('' if not created) • total_frames: -1 if unknown/failure """ - vrs = S3Path(vrs) if isinstance(vrs, str) else vrs - jsonf = S3Path(jsonf) if isinstance(jsonf, str) else jsonf - mps_dir = S3Path(mps_dir) if isinstance(mps_dir, str) else mps_dir + s3_client = get_cloudpathlib_s3_client() + boto3_client = get_boto3_s3_client() + vrs = S3Path(vrs, client=s3_client) if isinstance(vrs, str) else vrs + jsonf = S3Path(jsonf, client=s3_client) if isinstance(jsonf, str) else jsonf + mps_dir = S3Path(mps_dir, client=s3_client) if isinstance(mps_dir, str) else mps_dir - s3 = boto3.client("s3") stem = vrs.stem LOG_ROOT.mkdir(parents=True, exist_ok=True) log_path = LOG_ROOT / f"{stem}-{uuid.uuid4().hex[:8]}.log" @@ -311,10 +315,8 @@ def convert_one_bundle_impl( mps_dir, ] - raw_bucket, raw_prefix = _parse_s3_uri( - RAW_REMOTE_PREFIX, default_bucket=BUCKET - ) - raw_root = S3Path(RAW_REMOTE_PREFIX) + raw_bucket, raw_prefix = _parse_s3_uri(RAW_REMOTE_PREFIX, default_bucket=BUCKET) + raw_root = S3Path(RAW_REMOTE_PREFIX, client=s3_client) for t in targets: if not ensure_path_ready(t): @@ -330,36 +332,40 @@ def convert_one_bundle_impl( if t.is_dir(): s3_sync_to_local(raw_bucket, t_key, str(link)) else: - s3.download_file(raw_bucket, t_key, str(link)) + boto3_client.download_file(raw_bucket, t_key, str(link)) except Exception as e: print(f"[ERR] aws copy failed for {t}: {e}", flush=True) shutil.rmtree(tmp_dir, ignore_errors=True) return "", "", -1 + # TODO remove dataset_name we no longer use it for path or anything ds_parent = Path(out_dir) ds_parent.mkdir(parents=True, exist_ok=True) - ds_path = ds_parent / dataset_name + ds_path = ds_parent / f"{dataset_name}.zarr" try: print(f"[INFO] Converting: {stem} → {ds_path} (arm={arm})", flush=True) - lerobot_job( + zarr_path, mp4_path = zarr_job( raw_path=str(tmp_dir), output_dir=str(ds_parent), dataset_name=dataset_name, arm=arm, - description=description or 
"", ) frames = -1 - info = ds_path / "meta/info.json" + info = zarr_path / "zarr.json" + if info.exists(): try: meta = json.loads(info.read_text()) - frames = int(meta.get("total_frames", -1)) - except Exception: + frames = int(meta.get("attributes", {}).get("total_frames", -1)) + except Exception as e: + print(f"[ERR] Failed to load info: {e}") frames = -1 + else: + print(f"[ERR] Info not found: {info}") - candidate = ds_parent / f"{stem}_video.mp4" + candidate = mp4_path if candidate.exists(): mp4_str = str(candidate) else: @@ -375,6 +381,7 @@ def convert_one_bundle_impl( ) ds_s3_prefix = f"{out_prefix.rstrip('/')}/{ds_rel}".strip("/") upload_dir_to_s3(str(ds_path), out_bucket, prefix=ds_s3_prefix) + shutil.rmtree(str(ds_path), ignore_errors=True) if mp4_str: mp4_path = Path(mp4_str) if mp4_path.exists(): @@ -387,13 +394,15 @@ def convert_one_bundle_impl( "/" ) try: - s3.upload_file(str(mp4_path), out_bucket, mp4_s3_key) + boto3_client.upload_file(str(mp4_path), out_bucket, mp4_s3_key) + mp4_path.unlink(missing_ok=True) except Exception as e: raise Exception( f"Failed to upload mp4 {mp4_path} to S3: {e}" ) except Exception as e: print(f"[ERR] Failed to upload {ds_path} to S3: {e}", flush=True) + traceback.print_exc() return "", "", -2 return str(ds_path), mp4_str, frames @@ -406,12 +415,12 @@ def convert_one_bundle_impl( shutil.rmtree(tmp_dir, ignore_errors=True) -@ray.remote(num_cpus=8, resources={"aria_small": 1}) +@ray.remote(num_cpus=2, resources={"aria_small": 1}) def convert_one_bundle_small(*args, **kwargs): return convert_one_bundle_impl(*args, **kwargs) -@ray.remote(num_cpus=32, resources={"aria_big": 1}) +@ray.remote(num_cpus=8, resources={"aria_big": 1}) def convert_one_bundle_big(*args, **kwargs): return convert_one_bundle_impl(*args, **kwargs) @@ -457,17 +466,17 @@ def launch( print(f"[SKIP] {name}: no matching row in SQL (app.episodes)", flush=True) continue - processed_path = (row.processed_path or "").strip() + processed_path = 
(row.zarr_processed_path or "").strip() if skip_if_done and len(processed_path) > 0: print( - f"[SKIP] {name}: already has processed_path='{processed_path}'", + f"[SKIP] {name}: already has zarr_processed_path='{processed_path}'", flush=True, ) continue - if row.processing_error != "": + if row.zarr_processing_error != "": print( - f"[INFO] skipping episode hash: {row.episode_hash} due to processing error", + f"[INFO] skipping episode hash: {row.episode_hash} due to zarr processing error", flush=True, ) continue @@ -482,8 +491,6 @@ def launch( dataset_name = episode_key out_dir = PROCESSED_ROOT s3out_dir = PROCESSED_REMOTE_PREFIX - description = row.task_description or "" - if dry: ds_path = (PROCESSED_ROOT / dataset_name).resolve() stem = vrs.stem @@ -494,10 +501,9 @@ def launch( print( f"[DRY] {name}: arm={arm} | out_dir={out_dir}/{dataset_name}\n" - f" desc='{description[:60]}'\n" f" would write to SQL:\n" - f" processed_path={mapped_ds}\n" - f" mp4_path={mapped_mp4}", + f" zarr_processed_path={mapped_ds}\n" + f" zarr_mp4_path={mapped_mp4}", flush=True, ) continue @@ -510,7 +516,6 @@ def launch( str(s3out_dir), dataset_name, arm, - description, ) start_time = time.time() @@ -551,36 +556,36 @@ def launch( row.num_frames = int(frames) if frames is not None else -1 if row.num_frames > 0: - row.processed_path = _map_processed_local_to_remote(ds_path) - row.mp4_path = _map_processed_local_to_remote(mp4_path) - row.processing_error = "" + row.zarr_processed_path = _map_processed_local_to_remote(ds_path) + row.zarr_mp4_path = _map_processed_local_to_remote(mp4_path) + row.zarr_processing_error = "" elif row.num_frames == -2: - row.processed_path = "" - row.mp4_path = "" - row.processing_error = "Upload Failed" + row.zarr_processed_path = "" + row.zarr_mp4_path = "" + row.zarr_processing_error = "Upload Failed" elif row.num_frames == -1: - row.processed_path = "" - row.mp4_path = "" - row.processing_error = "Zero Frames" + row.zarr_processed_path = "" + 
row.zarr_mp4_path = "" + row.zarr_processing_error = "Zero Frames" else: - row.processed_path = "" - row.mp4_path = "" - row.processing_error = "Conversion Failed Unhandled Error" + row.zarr_processed_path = "" + row.zarr_mp4_path = "" + row.zarr_processing_error = "Conversion Failed Unhandled Error" update_episode(engine, row) print( f"[OK] Updated SQL for {episode_key}: " - f"processed_path={row.processed_path}, num_frames={row.num_frames}, " + f"zarr_processed_path={row.zarr_processed_path}, num_frames={row.num_frames}, " f"duration_sec={duration_sec:.2f}", flush=True, ) - if row.num_frames > 0 and row.processed_path: + if row.num_frames > 0 and row.zarr_processed_path: benchmark_rows.append( { "episode_key": episode_key, - "processed_path": row.processed_path, - "mp4_path": row.mp4_path, + "processed_path": row.zarr_processed_path, + "mp4_path": row.zarr_mp4_path, "num_frames": row.num_frames, "duration_sec": duration_sec, } @@ -610,13 +615,13 @@ def launch( # mark failed in SQL (so skip-if-done won't think it's done) row.num_frames = -1 - row.processed_path = "" - row.mp4_path = "" - row.processing_error = f"{type(e).__name__}: {e}" + row.zarr_processed_path = "" + row.zarr_mp4_path = "" + row.zarr_processing_error = f"{type(e).__name__}: {e}" try: update_episode(engine, row) print( - f"[FAIL] Marked SQL failed for {episode_key} (cleared processed_path)", + f"[FAIL] Marked SQL failed for {episode_key} (cleared zarr_processed_path)", flush=True, ) except Exception as ee: @@ -657,7 +662,7 @@ def main(): p.add_argument( "--skip-if-done", action="store_true", - help="Skip episodes that already have a processed_path in SQL", + help="Skip episodes that already have a zarr_processed_path in SQL", ) p.add_argument( "--ray-address", default="auto", help="Ray cluster address (default: auto)" @@ -671,6 +676,18 @@ def main(): p.add_argument("--debug", action="store_true") args = p.parse_args() + env_vars = {} + load_env() + for k in [ + "R2_ACCESS_KEY_ID", + 
"R2_SECRET_ACCESS_KEY", + "R2_SESSION_TOKEN", # optional + "R2_ENDPOINT_URL", # optional; include if your helper expects it + ]: + v = os.environ.get(k) + if v: + env_vars[k] = v + if args.debug: runtime_env = { "working_dir": "/home/ubuntu/EgoVerse", @@ -685,8 +702,8 @@ def main(): ], } else: - runtime_env = None - + runtime_env = {} + runtime_env["env_vars"] = env_vars ray.init(address=args.ray_address, runtime_env=runtime_env) launch( dry=args.dry_run, diff --git a/egomimic/scripts/eva_process/eva-cluster.yaml b/egomimic/scripts/eva_process/eva-cluster.yaml index 0ba3dfb8..60c5a948 100644 --- a/egomimic/scripts/eva_process/eva-cluster.yaml +++ b/egomimic/scripts/eva_process/eva-cluster.yaml @@ -20,9 +20,11 @@ rsync_exclude: - "**/__pycache__/" - "*.pyc" - "logs/" + - "*.parquet" + - "**/pytracik/**" -max_workers: 50 -idle_timeout_minutes: 10 +max_workers: 140 +idle_timeout_minutes: 5 available_node_types: head_node: @@ -30,6 +32,8 @@ available_node_types: InstanceType: t3a.2xlarge KeyName: rldb-base-pem ImageId: ami-08c4d70c31f91a5ac + IamInstanceProfile: + Arn: arn:aws:iam::556885871428:instance-profile/ray-autoscaler-v1 BlockDeviceMappings: - DeviceName: /dev/sda1 Ebs: { VolumeSize: 50 } @@ -37,17 +41,33 @@ available_node_types: min_workers: 0 max_workers: 0 - worker_node: + worker_small: node_config: - InstanceType: c5.18xlarge + InstanceType: r6a.2xlarge KeyName: rldb-base-pem ImageId: ami-08c4d70c31f91a5ac + IamInstanceProfile: + Arn: arn:aws:iam::556885871428:instance-profile/ray-autoscaler-v1 BlockDeviceMappings: - DeviceName: /dev/sda1 - Ebs: { VolumeSize: 50 } - resources: { CPU: 72 } + Ebs: { VolumeSize: 150 } + resources: { CPU: 2, eva_small: 1 } min_workers: 0 - max_workers: 50 + max_workers: 120 + + worker_big: + node_config: + InstanceType: r6a.8xlarge + IamInstanceProfile: + Arn: arn:aws:iam::556885871428:instance-profile/ray-autoscaler-v1 + KeyName: rldb-base-pem + ImageId: ami-08c4d70c31f91a5ac + BlockDeviceMappings: + - DeviceName: /dev/sda1 + 
Ebs: { VolumeSize: 200 } # often worth bumping + resources: { CPU: 8, eva_big: 1 } + min_workers: 0 + max_workers: 10 head_node_type: head_node @@ -60,79 +80,42 @@ initialization_commands: - sudo DEBIAN_FRONTEND=noninteractive apt-get install -yq libgl1 libpq-dev awscli - sudo DEBIAN_FRONTEND=noninteractive apt-get install -yq ffmpeg libavcodec-dev libavformat-dev libswscale-dev libx264-dev - | - echo "=== Testing FFmpeg installation ===" - if command -v ffmpeg >/dev/null 2>&1; then - ffmpeg -hide_banner -version | head -n 1 - echo "[OK] FFmpeg is installed and accessible." - else - echo "[FAIL] FFmpeg not found in PATH." - fi - echo "=== FFmpeg test done ===" + echo "=== Testing FFmpeg installation ===" + if command -v ffmpeg >/dev/null 2>&1; then + ffmpeg -hide_banner -version | head -n 1 + echo "[OK] FFmpeg is installed and accessible." + else + echo "[FAIL] FFmpeg not found in PATH." + fi + echo "=== FFmpeg test done ===" - | - cd ~/EgoVerse - pip3 install --no-input -r requirements-ray.txt - pip3 install --no-input -e external/lerobot - pip3 install --no-input -e . - + cd ~/EgoVerse + pip3 install --no-input -r requirements-ray.txt + pip3 install --no-input -e external/lerobot + pip3 install --no-input -e . setup_commands: - sudo mkdir -p /mnt/raw /mnt/processed - | - chmod +x ~/EgoVerse/egomimic/utils/aws/setup_secret.sh - R2_SECRET_NAME=r2/rldb/credentials DB_SECRET_NAME=rds/appdb/appuser REGION=us-east-2 \ - bash ~/EgoVerse/egomimic/utils/aws/setup_secret.sh - - | - set -a - . /home/ubuntu/.egoverse_env - set +a - printf '%s:%s\n' "$R2_ACCESS_KEY_ID" "$R2_SECRET_ACCESS_KEY" | sudo tee /etc/passwd-s3fs >/dev/null - sudo chmod 600 /etc/passwd-s3fs - - - | - if mountpoint -q /mnt/raw; then - echo "/mnt/raw already mounted" - else - set -a - . 
/home/ubuntu/.egoverse_env - set +a - sudo s3fs rldb:/raw_v2/eva /mnt/raw \ - -o ro \ - -o url="$AWS_ENDPOINT_URL_S3" \ - -o passwd_file=/etc/passwd-s3fs \ - -o allow_other \ - -o umask=000 \ - -o use_path_request_style \ - -o nonempty - fi - + chmod +x ~/EgoVerse/egomimic/utils/aws/setup_secret.sh + R2_SECRET_NAME=r2/rldb/credentials DB_SECRET_NAME=rds/appdb/appuser REGION=us-east-2 \ + bash ~/EgoVerse/egomimic/utils/aws/setup_secret.sh - | - if mountpoint -q /mnt/processed; then - echo "/mnt/processed already mounted" - else - set -a - . /home/ubuntu/.egoverse_env - set +a - sudo s3fs rldb:/processed_v2/eva /mnt/processed \ - -o url="$AWS_ENDPOINT_URL_S3" \ - -o passwd_file=/etc/passwd-s3fs \ - -o allow_other \ - -o umask=000 \ - -o use_path_request_style \ - -o multipart_size=64 \ - -o parallel_count=20 \ - -o del_cache \ - -o complement_stat \ - -o nonempty - fi + set -a + . /home/ubuntu/.egoverse_env + set +a + printf '%s:%s\n' "$R2_ACCESS_KEY_ID" "$R2_SECRET_ACCESS_KEY" | sudo tee /etc/passwd-s3fs >/dev/null + sudo chmod 600 /etc/passwd-s3fs +head_setup_commands: - | - (crontab -l 2>/dev/null \ - | grep -v run_eva_conversion.py \ - | grep -v ray_worker_gaurdrails.py \ - | grep -v ray_worker_gaurdrails.lock ; \ - echo 'CRON_TZ=America/New_York'; \ - echo '0 22 * * * flock -n /tmp/run_eva_conversion.lock /bin/bash -lc "set -a; . 
/home/ubuntu/.egoverse_env; set +a; /usr/bin/python3 ~/EgoVerse/egomimic/scripts/eva_process/run_eva_conversion.py --skip-if-done" >> ~/eva_conversion_$(date +\%Y-\%m-\%d-\%H-\%M-\%S).log 2>&1'; \ - echo '*/10 * * * * flock -n /tmp/ray_worker_gaurdrails.lock /usr/bin/python3 /home/ubuntu/EgoVerse/egomimic/utils/aws/budget_guardrails/ray_worker_gaurdrails.py >> /home/ubuntu/ray_worker_gaurdrails.log 2>&1') \ - | crontab - || true + (crontab -l 2>/dev/null \ + | grep -v run_eva_conversion.py \ + | grep -v ray_worker_gaurdrails.py \ + | grep -v ray_worker_gaurdrails.lock ; \ + echo 'CRON_TZ=America/New_York'; \ + echo '0 22 * * * flock -n /tmp/run_eva_conversion.lock /bin/bash -lc "set -a; . /home/ubuntu/.egoverse_env; set +a; /usr/bin/python3 ~/EgoVerse/egomimic/scripts/eva_process/run_eva_conversion.py --skip-if-done" >> ~/eva_conversion_$(date +\%Y-\%m-\%d-\%H-\%M-\%S).log 2>&1'; \ + echo '*/10 * * * * flock -n /tmp/ray_worker_gaurdrails.lock /usr/bin/python3 /home/ubuntu/EgoVerse/egomimic/utils/aws/budget_guardrails/ray_worker_gaurdrails.py >> /home/ubuntu/ray_worker_gaurdrails.log 2>&1') \ + | crontab - || true - crontab -l || true diff --git a/egomimic/scripts/eva_process/eva_helper.py b/egomimic/scripts/eva_process/eva_helper.py index 4ff6fa50..ceb2c35e 100644 --- a/egomimic/scripts/eva_process/eva_helper.py +++ b/egomimic/scripts/eva_process/eva_helper.py @@ -1,8 +1,8 @@ from pathlib import Path from types import SimpleNamespace -from egomimic.scripts.eva_process.eva_to_lerobot import main as eva_main - +from egomimic.scripts.eva_process.eva_to_zarr import main as zarr_main +from egomimic.scripts.eva_process.eva_to_lerobot import main as lerobot_main def lerobot_job( *, @@ -13,6 +13,7 @@ def lerobot_job( description: str = "", extrinsics_key: str = "x5Dec13_2", ) -> None: + raw_path = Path(raw_path).expanduser().resolve() output_dir = Path(output_dir).expanduser().resolve() @@ -37,4 +38,36 @@ def lerobot_job( save_mp4=True, ) - eva_main(args) + 
lerobot_main(args) + + +def zarr_job( + *, + raw_path: str | Path, + output_dir: str | Path, + dataset_name: str, + arm: str, + description: str = "", + extrinsics_key: str = "x5Dec13_2", + chunk_timesteps: int = 100, +) -> None: + + raw_path = Path(raw_path).expanduser().resolve() + output_dir = Path(output_dir).expanduser().resolve() + + args = SimpleNamespace( + raw_path=raw_path, + output_dir=output_dir, + name=dataset_name, + fps=30, + arm=arm, + extrinsics_key=extrinsics_key, + description=description, + image_compressed=False, + prestack=False, + debug=False, + save_mp4=True, + chunk_timesteps=chunk_timesteps, + ) + + zarr_main(args) diff --git a/egomimic/scripts/eva_process/eva_to_lerobot.py b/egomimic/scripts/eva_process/eva_to_lerobot.py index a5d71aef..68f7ab59 100644 --- a/egomimic/scripts/eva_process/eva_to_lerobot.py +++ b/egomimic/scripts/eva_process/eva_to_lerobot.py @@ -1218,6 +1218,157 @@ def define_features( return features, metadata +def save_preview_mp4( + frames: list[dict], output_path: Path, fps: int, image_compressed: bool +): + """ + Save a half-resolution, web-compatible MP4 (H.264, yuv420p). + + Strategy: + 1. Try torchvision.io.write_video (H.264 via FFmpeg libs, no CLI). + 2. If that fails, fall back to ffmpeg CLI via subprocess. + 3. If both fail, raise a RuntimeError. + + Expects each frame dict to contain: + 'observations.images.front_img_1' -> torch.Tensor (C,H,W), uint8, BGR. 
+ """ + img_key = "observations.images.front_img_1" + imgs = [f[img_key] for f in frames if img_key in f] + if not imgs: + print(f"[MP4] No frames with key '{img_key}' found — skipping video save.") + return + + # Assume imgs[0] is (C,H,W) + C, H, W = imgs[0].shape + + # Compute half-res (force even dims for yuv420p) + outW, outH = W // 2, H // 2 + if outW % 2: + outW -= 1 + if outH % 2: + outH -= 1 + if outW <= 0 or outH <= 0: + raise ValueError(f"[MP4] Invalid output size: {outW}x{outH}") + + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Build resized RGB frames once + rgb_frames = [] + for chw in imgs: + # chw: (C,H,W) uint8, BGR from cv2.imdecode earlier + t = chw.detach().cpu() + if t.dtype != torch.uint8: + t = t.to(torch.uint8) + + # If grayscale, repeat to 3 channels + if t.shape[0] == 1: + t = t.repeat(3, 1, 1) + + # Resize to (outH, outW) + t_resized = F.interpolate( + t.unsqueeze(0), # (1,C,H,W) + size=(outH, outW), + mode="bilinear", + align_corners=False, + ).squeeze(0) # (C,outH,outW) + + hwc = t_resized.permute(1, 2, 0).contiguous() # (H,W,3), uint8 + rgb_frames.append(hwc) + + video_tensor = torch.stack(rgb_frames, dim=0) # (T, H, W, 3) uint8 + + # 1) Try torchvision.write_video + try: + from torchvision.io import write_video + + write_video( + filename=str(output_path), + video_array=video_tensor, + fps=float(fps), + video_codec="libx264", # H.264, web-compatible + options={"crf": "23", "preset": "veryfast"}, + ) + print( + f"[MP4] Saved web-compatible H.264 preview via torchvision to {output_path}" + ) + return + except Exception as e: + print( + f"[MP4] torchvision.io.write_video failed ({e}); trying ffmpeg CLI fallback..." 
+ ) + + # 2) Fallback: ffmpeg CLI (libx264) + ffmpeg = shutil.which("ffmpeg") + if ffmpeg is None: + raise RuntimeError( + "[MP4] Could not write web-compatible MP4:\n" + " - torchvision.io.write_video is unavailable or failed\n" + " - `ffmpeg` CLI not found on PATH\n" + "Install either torchvision with video support or ffmpeg+libx264." + ) + + # For ffmpeg rawvideo, we need BGR24 frames of shape (outH, outW, 3) + # We can convert our RGB hwc tensors back to BGR numpy. + cmd = [ + ffmpeg, + "-y", + "-f", + "rawvideo", + "-vcodec", + "rawvideo", + "-pix_fmt", + "bgr24", + "-s", + f"{outW}x{outH}", + "-r", + str(fps), + "-i", + "-", # stdin + "-an", + "-c:v", + "libx264", + "-pix_fmt", + "yuv420p", + "-profile:v", + "baseline", + "-level", + "3.0", + "-movflags", + "+faststart", + "-preset", + "veryfast", + "-crf", + "23", + str(output_path), + ] + + proc = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + try: + for hwc_rgb in rgb_frames: + # hwc_rgb: (H,W,3), RGB uint8 + np_rgb = hwc_rgb.numpy() + # RGB -> BGR + np_bgr = np_rgb[..., ::-1] + proc.stdin.write(np_bgr.tobytes()) + finally: + if proc.stdin: + proc.stdin.flush() + proc.stdin.close() + + ret = proc.wait() + if ret != 0: + stderr = proc.stderr.read().decode(errors="ignore") if proc.stderr else "" + raise RuntimeError(f"[MP4] ffmpeg/libx264 encoding failed (exit {ret}).\n{stderr}") + + print(f"[MP4] Saved web-compatible H.264 preview via ffmpeg CLI to {output_path}") + + class DatasetConverter: """ A class to convert datasets to Lerobot format. @@ -1337,162 +1488,7 @@ def __init__( def save_preview_mp4( self, frames: list[dict], output_path: Path, fps: int, image_compressed: bool ): - """ - Save a half-resolution, web-compatible MP4 (H.264, yuv420p). - - Strategy: - 1. Try torchvision.io.write_video (H.264 via FFmpeg libs, no CLI). - 2. If that fails, fall back to ffmpeg CLI via subprocess. - 3. If both fail, raise a RuntimeError. 
- - Expects each frame dict to contain: - 'observations.images.front_img_1' -> torch.Tensor (C,H,W), uint8, BGR. - """ - img_key = "observations.images.front_img_1" - imgs = [f[img_key] for f in frames if img_key in f] - if not imgs: - print(f"[MP4] No frames with key '{img_key}' found — skipping video save.") - return - - # Assume imgs[0] is (C,H,W) - C, H, W = imgs[0].shape - - # Compute half-res (force even dims for yuv420p) - outW, outH = W // 2, H // 2 - if outW % 2: - outW -= 1 - if outH % 2: - outH -= 1 - if outW <= 0 or outH <= 0: - raise ValueError(f"[MP4] Invalid output size: {outW}x{outH}") - - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - - # ----------------------------- - # Build resized RGB frames once - # ----------------------------- - rgb_frames = [] - for chw in imgs: - # chw: (C,H,W) uint8, BGR from cv2.imdecode earlier - t = chw.detach().cpu() - if t.dtype != torch.uint8: - t = t.to(torch.uint8) - - # If grayscale, repeat to 3 channels - if t.shape[0] == 1: - t = t.repeat(3, 1, 1) - - # Resize to (outH, outW) - t_resized = F.interpolate( - t.unsqueeze(0), # (1,C,H,W) - size=(outH, outW), - mode="bilinear", - align_corners=False, - ).squeeze(0) # (C,outH,outW) - - hwc = t_resized.permute(1, 2, 0).contiguous() # (H,W,3), uint8 - rgb_frames.append(hwc) - - video_tensor = torch.stack(rgb_frames, dim=0) # (T, H, W, 3) uint8 - - # ----------------------------- - # 1) Try torchvision.write_video - # ----------------------------- - try: - from torchvision.io import write_video - - write_video( - filename=str(output_path), - video_array=video_tensor, - fps=float(fps), - video_codec="libx264", # H.264, web-compatible - options={"crf": "23", "preset": "veryfast"}, - ) - print( - f"[MP4] Saved web-compatible H.264 preview via torchvision to {output_path}" - ) - return - except Exception as e: - print( - f"[MP4] torchvision.io.write_video failed ({e}); trying ffmpeg CLI fallback..." 
- ) - - # ----------------------------- - # 2) Fallback: ffmpeg CLI (libx264) - # ----------------------------- - ffmpeg = shutil.which("ffmpeg") - if ffmpeg is None: - raise RuntimeError( - "[MP4] Could not write web-compatible MP4:\n" - " - torchvision.io.write_video is unavailable or failed\n" - " - `ffmpeg` CLI not found on PATH\n" - "Install either torchvision with video support or ffmpeg+libx264." - ) - - # For ffmpeg rawvideo, we need BGR24 frames of shape (outH, outW, 3) - # We can convert our RGB hwc tensors back to BGR numpy. - cmd = [ - ffmpeg, - "-y", - "-f", - "rawvideo", - "-vcodec", - "rawvideo", - "-pix_fmt", - "bgr24", - "-s", - f"{outW}x{outH}", - "-r", - str(fps), - "-i", - "-", # stdin - "-an", - "-c:v", - "libx264", - "-pix_fmt", - "yuv420p", - "-profile:v", - "baseline", - "-level", - "3.0", - "-movflags", - "+faststart", - "-preset", - "veryfast", - "-crf", - "23", - str(output_path), - ] - - proc = subprocess.Popen( - cmd, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - try: - for hwc_rgb in rgb_frames: - # hwc_rgb: (H,W,3), RGB uint8 - np_rgb = hwc_rgb.numpy() - # RGB -> BGR - np_bgr = np_rgb[..., ::-1] - proc.stdin.write(np_bgr.tobytes()) - finally: - if proc.stdin: - proc.stdin.flush() - proc.stdin.close() - - ret = proc.wait() - if ret != 0: - stderr = proc.stderr.read().decode(errors="ignore") if proc.stderr else "" - raise RuntimeError( - f"[MP4] ffmpeg/libx264 encoding failed (exit {ret}).\n{stderr}" - ) - - print( - f"[MP4] Saved web-compatible H.264 preview via ffmpeg CLI to {output_path}" - ) + save_preview_mp4(frames, output_path, fps, image_compressed) def extract_episode(self, episode_path, task_description: str = ""): extrinsics = EXTRINSICS[self.extrinsics_key] diff --git a/egomimic/scripts/eva_process/eva_to_zarr.py b/egomimic/scripts/eva_process/eva_to_zarr.py new file mode 100644 index 00000000..7069741e --- /dev/null +++ b/egomimic/scripts/eva_process/eva_to_zarr.py @@ -0,0 +1,446 @@ +""" 
+Convert Eva HDF5 episodes to Zarr format. + +Mirrors the main(args) interface of eva_to_lerobot.py so that +run_eva_conversion.py can swap between LeRobot and Zarr backends. +""" + +import argparse +import logging +import traceback +from pathlib import Path + +import numpy as np +import torch +from scipy.spatial.transform import Rotation as R + +from egomimic.scripts.eva_process.zarr_utils import EvaHD5Extractor +from egomimic.rldb.zarr.zarr_writer import ZarrWriter +from egomimic.utils.egomimicUtils import EXTRINSICS, str2bool, xyzw_to_wxyz +"" +logger = logging.getLogger(__name__) + +DATASET_KEY_MAPPINGS = { + "observations.state.eepose": "obs_ee_pose", + "observations.state.joint_positions": "obs_joints", + "actions_base_cartesian": "cmd_ee_pose", + "actions_joints": "cmd_joints", + "observations.images.front_img_1": "images.front_1", + "observations.images.right_wrist_img": "images.right_wrist", + "observations.images.left_wrist_img": "images.left_wrist", +} + +# --------------------------------------------------------------------------- +# Preview MP4 +# --------------------------------------------------------------------------- + + +R_t_e = np.array([ + [0, 0, 1], + [-1, 0, 0], + [0, -1, 0], +], dtype=float) + +def eva_reorientation(quat: np.ndarray) -> np.ndarray: + rotation = R.from_quat(quat).as_matrix() + rotation = R_t_e @ rotation + return R.from_matrix(rotation).as_quat() + + +def save_preview_mp4( + images_tchw: np.ndarray, output_path: Path, fps: int +) -> None: + """Save a half-resolution, web-compatible MP4 from a (T,C,H,W) uint8 array. + + Tries torchvision.io.write_video first, then falls back to ffmpeg CLI. + + Parameters + ---------- + images_tchw : np.ndarray + Image array with shape (T, C, H, W), uint8. + output_path : Path + Destination path for the .mp4 file. + fps : int + Frames per second. 
+ """ + import torch.nn.functional as F + + imgs = torch.from_numpy(np.asarray(images_tchw)) # (T,C,H,W) + if imgs.ndim != 4 or len(imgs) == 0: + logger.warning("No valid frames for preview MP4 — skipping.") + return + + _, _, H, W = imgs.shape + + # Half-resolution with even dims for yuv420p + outW, outH = W // 2, H // 2 + outW -= outW % 2 + outH -= outH % 2 + if outW <= 0 or outH <= 0: + raise ValueError(f"Invalid output size: {outW}x{outH}") + + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + rgb_frames = [] + for chw in imgs: + t = chw.detach().cpu().to(torch.uint8) + if t.shape[0] == 1: + t = t.repeat(3, 1, 1) + t_resized = F.interpolate( + t.unsqueeze(0).float(), + size=(outH, outW), + mode="bilinear", + align_corners=False, + ).squeeze(0).to(torch.uint8) + rgb_frames.append(t_resized.permute(1, 2, 0).contiguous()) # (H,W,3) + + video_tensor = torch.stack(rgb_frames, dim=0) # (T,H,W,3) uint8 + + # Try torchvision first + try: + import torchvision.io + + torchvision.io.write_video( + str(output_path), video_tensor, fps=fps, + video_codec="libx264", + options={"crf": "23", "pix_fmt": "yuv420p"}, + ) + return + except Exception: + logger.debug("torchvision.io.write_video failed, trying ffmpeg CLI.") + + # Fallback: pipe raw frames to ffmpeg + import subprocess, shutil + + ffmpeg = shutil.which("ffmpeg") + if ffmpeg is None: + raise RuntimeError("Neither torchvision.io.write_video nor ffmpeg CLI available.") + + cmd = [ + ffmpeg, "-y", + "-f", "rawvideo", "-vcodec", "rawvideo", + "-s", f"{outW}x{outH}", "-pix_fmt", "rgb24", + "-r", str(fps), "-i", "-", + "-c:v", "libx264", "-pix_fmt", "yuv420p", + "-crf", "23", "-preset", "fast", + str(output_path), + ] + proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.PIPE) + proc.stdin.write(video_tensor.numpy().tobytes()) + proc.stdin.close() + _, stderr = proc.communicate() + if proc.returncode != 0: + raise RuntimeError(f"ffmpeg failed (rc={proc.returncode}): 
{stderr.decode()}") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _arm_to_embodiment(arm: str) -> str: + """Map arm string to embodiment identifier.""" + return { + "left": "eva_left_arm", + "right": "eva_right_arm", + "both": "eva_bimanual", + }.get(arm, "eva_bimanual") + + +def _separate_numeric_and_image(episode_feats: dict): + """Split process_episode() output into numeric and image dicts. + + * Keys containing "images" are treated as image data + * Images are transposed from (T,C,H,W) to (T,H,W,C) because + process_episode() stores (T,C,H,W) while ZarrWriter + expects (T,H,W,3) for JPEG encoding. + * metadata.* keys are skipped (they are per-timestep constants + like embodiment id that are stored in zarr attrs instead). + """ + numeric_data: dict[str, np.ndarray] = {} + image_data: dict[str, np.ndarray] = {} + allowed_keys = set(DATASET_KEY_MAPPINGS.keys()) + + for key, value in episode_feats.items(): + if key.startswith("metadata."): + continue + + # Flatten one nested level (e.g., observations -> observations.images.front_img_1) + if isinstance(value, dict): + for nested_key, nested_value in value.items(): + full_key = f"{key}.{nested_key}" + if full_key not in allowed_keys: + continue + + zarr_key = DATASET_KEY_MAPPINGS[full_key] + arr = np.asarray(nested_value) + + if "images" in full_key: + # Transpose (T,C,H,W) -> (T,H,W,C) when needed + if arr.ndim == 4 and arr.shape[1] in (1, 3, 4) and arr.shape[2] > arr.shape[1]: + arr = arr.transpose(0, 2, 3, 1) + image_data[zarr_key] = arr + else: + numeric_data[zarr_key] = arr + else: + if key not in allowed_keys: + continue + + zarr_key = DATASET_KEY_MAPPINGS[key] + numeric_data[zarr_key] = np.asarray(value) + + return numeric_data, image_data + + +_SPLIT_KEYS = {"obs_ee_pose", "cmd_ee_pose", "obs_joints", "cmd_joints"} + + +def _split_per_arm(numeric_data: dict, arm: str) -> dict: 
+ """Split combined arm arrays into per-arm keys with gripper separated. + + Bimanual layout (T, 14): + [0:6] left xyz+ypr, [6] left gripper, + [7:13] right xyz+ypr, [13] right gripper. + Single-arm layout (T, 7): + [0:6] xyz+ypr, [6] gripper. + + Produces keys like ``left.obs_eepose`` (T,6), ``right.gripper`` (T,1), etc. + Gripper is taken from ``cmd_joints`` only (commanded state). + """ + out = {k: v for k, v in numeric_data.items() if k not in _SPLIT_KEYS} + + side = {"left": "left", "right": "right"}.get(arm) # None for "both" + + for base_key in _SPLIT_KEYS: + arr = numeric_data.get(base_key) + if arr is None: + continue + + # (TODO) aria to zarr change to use both instead of bimanual + if arm == "both": + if "joints" in base_key: + left_joints = arr[:, 0:6] + right_joints = arr[:, 7:13] + out[f"left.{base_key}"] = left_joints + out[f"right.{base_key}"] = right_joints + out["left.gripper"] = arr[:, 6:7] + out["right.gripper"] = arr[:, 13:14] + else: + left_ypr = arr[:, 3:6] + right_ypr = arr[:, 10:13] + left_quat = eva_reorientation(R.from_euler("ZYX", left_ypr, degrees=False).as_quat()) + right_quat = eva_reorientation(R.from_euler("ZYX", right_ypr, degrees=False).as_quat()) + left_quat = xyzw_to_wxyz(left_quat) + right_quat = xyzw_to_wxyz(right_quat) + left_translation = arr[:, 0:3] + right_translation = arr[:, 7:10] + left_translation_quat = np.concatenate([left_translation, left_quat], axis=-1) + right_translation_quat = np.concatenate([right_translation, right_quat], axis=-1) + out[f"left.{base_key}"] = left_translation_quat + out[f"right.{base_key}"] = right_translation_quat + out["left.gripper"] = arr[:, 6:7] + out["right.gripper"] = arr[:, 13:14] + else: + if "joints" in base_key: + out[f"{side}.{base_key}"] = arr[:, :6] + out[f"{side}.gripper"] = arr[:, 6:7] + else: + translation = arr[:, 0:3] + quat = eva_reorientation(R.from_euler("ZYX", arr[:, 3:6], degrees=False).as_quat()) + quat = xyzw_to_wxyz(quat) + translation_quat = 
np.concatenate([translation, quat], axis=-1) + out[f"{side}.{base_key}"] = translation_quat + gripper = arr[:, 6:7] + out[f"{side}.gripper"] = gripper + + return out +def _infer_total_frames( + numeric_data: dict[str, np.ndarray], image_data: dict[str, np.ndarray] +) -> int: + """Infer episode length from numeric/image arrays.""" + for arr in numeric_data.values(): + return int(len(arr)) + for arr in image_data.values(): + return int(len(arr)) + return 0 + + +def _build_example_annotations( + total_frames: int, description: str +) -> list[tuple[str, int, int]]: + """Create simple example language spans for demo/testing.""" + if total_frames <= 0: + return [] + + end_idx = total_frames - 1 + mid_idx = end_idx // 2 + task_text = description.strip() or "perform the demonstrated task" + + annotations: list[tuple[str, int, int]] = [ + (task_text, 0, end_idx), + ] + + if end_idx >= 3: + annotations.append(("reach and grasp", 0, mid_idx)) + if mid_idx + 1 <= end_idx: + annotations.append(("lift and place", mid_idx + 1, end_idx)) + + return annotations + + +# --------------------------------------------------------------------------- +# Single-episode conversion +# --------------------------------------------------------------------------- + +def convert_episode( + hdf5_path: Path, + zarr_episode_path: Path, + arm: str, + extrinsics_key: str, + fps: int, + description: str = "", + mp4_dir: Path | None = None, + chunk_timesteps: int = 100, + + example_annotations: bool = False, +) -> Path: + """Process one HDF5 file and write a .zarr episode. + + Returns the zarr episode path on success. 
def convert_episode(
    hdf5_path: Path,
    zarr_episode_path: Path,
    arm: str,
    extrinsics_key: str,
    fps: int,
    description: str = "",
    mp4_dir: Path | None = None,
    chunk_timesteps: int = 100,
    example_annotations: bool = False,
) -> Path:
    """Process one HDF5 file and write a .zarr episode.

    Returns the zarr episode path on success.
    """
    hdf5_path = Path(hdf5_path)
    zarr_episode_path = Path(zarr_episode_path)

    episode_feats = EvaHD5Extractor.process_episode(
        episode_path=hdf5_path,
        arm=arm,
        extrinsics=EXTRINSICS[extrinsics_key],
    )

    numeric_data, image_data = _separate_numeric_and_image(episode_feats)
    numeric_data = _split_per_arm(numeric_data, arm)
    total_frames = _infer_total_frames(numeric_data, image_data)

    if example_annotations:
        annotations = _build_example_annotations(total_frames, description)
    else:
        annotations = []

    ZarrWriter.create_and_write(
        episode_path=zarr_episode_path,
        numeric_data=numeric_data or None,
        image_data=image_data or None,
        embodiment=_arm_to_embodiment(arm),
        fps=fps,
        task=description,
        annotations=annotations,
        chunk_timesteps=chunk_timesteps,
        enable_sharding=True,
    )

    logger.info("Wrote zarr episode: %s", zarr_episode_path)

    # Optional preview MP4 rendered from the front camera stream.
    observations = episode_feats.get("observations") or {}
    front_key = "images.front_img_1"
    if mp4_dir is not None and front_key in observations:
        mp4_path = Path(mp4_dir) / f"{hdf5_path.stem}_video.mp4"
        try:
            logger.info("Saving preview MP4 to: %s", mp4_path)
            save_preview_mp4(np.asarray(observations[front_key]), mp4_path, fps)
            logger.info("Saved preview MP4: %s", mp4_path)
        except Exception:
            logger.warning(
                "Failed to save preview MP4 at %s:\n%s", mp4_path, traceback.format_exc()
            )

    return zarr_episode_path
def main(args) -> None:
    """Convert Eva HDF5 dataset to Zarr episodes.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments (same shape as eva_to_lerobot).
    """
    raw_path = Path(args.raw_path)
    output_base = Path(args.output_dir)

    hdf5_files = sorted(raw_path.glob("*.hdf5"))
    if not hdf5_files:
        raise ValueError(f"No .hdf5 files found in {raw_path}")

    EvaHD5Extractor.check_format(
        hdf5_files, image_compressed=getattr(args, "image_compressed", False)
    )

    # Per-run options resolved once (args is read-only, so hoisting is safe).
    shared_kwargs = dict(
        arm=args.arm,
        extrinsics_key=getattr(args, "extrinsics_key", "x5Dec13_2"),
        fps=getattr(args, "fps", 30),
        description=getattr(args, "description", ""),
        mp4_dir=output_base if getattr(args, "save_mp4", False) else None,
        chunk_timesteps=getattr(args, "chunk_timesteps", 100),
        example_annotations=getattr(
            args,
            "example_language_annotations",
            getattr(args, "example_annotations", False),
        ),
    )

    for episode_file in hdf5_files:
        target = output_base / f"{episode_file.stem}.zarr"
        try:
            convert_episode(
                hdf5_path=episode_file,
                zarr_episode_path=target,
                **shared_kwargs,
            )
        except Exception:
            # Keep converting the remaining episodes on failure.
            logger.error("Error converting %s:\n%s", episode_file, traceback.format_exc())
+ ) + parser.add_argument("--raw-path", type=Path, required=True, help="Directory containing raw HDF5 files.") + parser.add_argument("--fps", type=int, default=30, help="Frames per second.") + parser.add_argument("--output-dir", type=Path, required=True, help="Root output directory.") + parser.add_argument("--arm", type=str, choices=["left", "right", "both"], default="both") + parser.add_argument("--extrinsics-key", type=str, default="x5Dec13_2") + parser.add_argument("--image-compressed", type=str2bool, default=False) + parser.add_argument("--description", type=str, default="") + parser.add_argument("--save-mp4", type=str2bool, default=False) + parser.add_argument("--chunk-timesteps", type=int, default=100, help="Timesteps per zarr chunk for numeric arrays.") + parser.add_argument( + "--example-language-annotations", + type=str2bool, + default=False, + help="If true, write simple example language spans into each episode.", + ) + parser.add_argument("--debug", action="store_true") + parser.add_argument("--nproc", type=int, default=12) + parser.add_argument("--nthreads", type=int, default=2) + + return parser.parse_args() + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main(argument_parse()) diff --git a/egomimic/scripts/eva_process/run_eva_conversion.py b/egomimic/scripts/eva_process/run_eva_conversion.py index 4be223a0..39ee009a 100644 --- a/egomimic/scripts/eva_process/run_eva_conversion.py +++ b/egomimic/scripts/eva_process/run_eva_conversion.py @@ -2,33 +2,48 @@ from __future__ import annotations import argparse +import contextlib +import csv import json import os import shutil +import sys import time import traceback import uuid from datetime import datetime, timezone from pathlib import Path -from typing import Iterator, Tuple +from typing import Any, Dict, Iterator, Tuple import ray -from eva_helper import lerobot_job +from cloudpathlib import S3Path +from ray.exceptions import OutOfMemoryError, RayTaskError, WorkerCrashedError 
+from eva_helper import lerobot_job, zarr_job + +from egomimic.utils.aws.aws_data_utils import ( + get_boto3_s3_client, + get_cloudpathlib_s3_client, + load_env, + upload_dir_to_s3, +) from egomimic.utils.aws.aws_sql import ( create_default_engine, episode_hash_to_table_row, + episode_table_to_df, update_episode, ) -RAW_ROOT = Path("/mnt/raw") -PROCESSED_ROOT = Path("/mnt/processed") -PROCESSED_LOCAL_ROOT = Path( - os.environ.get("PROCESSED_LOCAL_ROOT", "/mnt/processed") +RAW_REMOTE_PREFIX = os.environ.get("RAW_REMOTE_PREFIX", "s3://rldb/raw_v2/eva").rstrip("/") +PROCESSED_LOCAL_ROOT = Path(os.environ.get("PROCESSED_LOCAL_ROOT", "/home/ubuntu/processed")).resolve() +PROCESSED_REMOTE_PREFIX = os.environ.get("PROCESSED_REMOTE_PREFIX", "s3://rldb/processed_v3/eva").rstrip("/") +BUCKET = os.environ.get("BUCKET", "rldb") +LOG_ROOT = Path( + os.environ.get( + "EVA_CONVERSION_LOG_ROOT", + str(PROCESSED_LOCAL_ROOT / "eva_conversion_logs"), + ) ).resolve() -PROCESSED_REMOTE_PREFIX = os.environ.get( - "PROCESSED_REMOTE_PREFIX", "s3://rldb/processed_v2/eva" -).rstrip("/") DEFAULT_EXTRINSICS_KEY = "x5Dec13_2" @@ -77,14 +92,22 @@ def _load_extrinsics_key_from_json(meta_json: Path) -> str: return DEFAULT_EXTRINSICS_KEY -def iter_hdf5_bundles(root: Path) -> Iterator[Tuple[Path, str]]: - for data in sorted(root.glob("*.hdf5")): - name = data.stem - - meta_json = root / f"{name}_metadata.json" - extrinsics_key = _load_extrinsics_key_from_json(meta_json) - - yield data, extrinsics_key +def iter_hdf5_bundles_s3(root_s3: str) -> Iterator[Tuple[S3Path, str]]: + """Walk R2 for *.hdf5 files; load extrinsics_key from sidecar JSON if present.""" + s3_client = get_cloudpathlib_s3_client() + root = S3Path(root_s3, client=s3_client) + for hdf5 in sorted(root.glob("*.hdf5"), key=lambda p: p.name): + name = hdf5.stem + meta_json_s3 = root / f"{name}_metadata.json" + extrinsics_key = DEFAULT_EXTRINSICS_KEY + try: + if meta_json_s3.exists(): + obj = json.loads(meta_json_s3.read_text()) + if 
isinstance(obj, dict) and isinstance(obj.get("extrinsics_key"), str) and obj["extrinsics_key"]: + extrinsics_key = obj["extrinsics_key"] + except Exception: + pass + yield hdf5, extrinsics_key def infer_arm_from_robot_name(robot_name: str | None) -> str: @@ -107,163 +130,464 @@ def _load_episode_key(name: str) -> str | None: return None -@ray.remote(num_cpus=12) -def convert_one_bundle( - data_h5: str, +def _is_oom_exception(e: Exception) -> bool: + if isinstance(e, OutOfMemoryError): + return True + if isinstance(e, (RayTaskError, WorkerCrashedError)): + s = str(e).lower() + return ( + ("outofmemory" in s) + or ("out of memory" in s) + or ("oom" in s) + or ("killed" in s) + ) + s = str(e).lower() + return ("outofmemory" in s) or ("out of memory" in s) or ("oom" in s) + + +class _Tee: + def __init__(self, *streams): + self._streams = streams + + def write(self, data: str) -> int: + for s in self._streams: + s.write(data) + s.flush() + return len(data) + + def flush(self) -> None: + for s in self._streams: + s.flush() + + def isatty(self) -> bool: + return False + + +def _parse_s3_uri(uri: str, *, default_bucket: str | None = None) -> tuple[str, str]: + """ + Parse s3 URI or key prefix. + - "s3://bucket/prefix" -> ("bucket", "prefix") + - "prefix" -> (default_bucket, "prefix") + """ + uri = (uri or "").strip() + if uri.startswith("s3://"): + rest = uri[len("s3://"):] + bucket, _, key_prefix = rest.partition("/") + return bucket, key_prefix.strip("/") + if default_bucket is None: + raise ValueError(f"Expected s3://... 
but got '{uri}' and no default_bucket provided") + return default_bucket, uri.strip("/") + + +def convert_one_bundle_impl( + data_h5_s3: str, out_dir: str, + s3_processed_dir: str, dataset_name: str, arm: str, description: str, extrinsics_key: str, + backend: str = "zarr", ) -> tuple[str, str, int]: - stem = Path(data_h5).stem + s3_client = get_boto3_s3_client() + hdf5_s3 = S3Path(data_h5_s3) + stem = hdf5_s3.stem + + LOG_ROOT.mkdir(parents=True, exist_ok=True) + log_path = LOG_ROOT / f"{stem}-{uuid.uuid4().hex[:8]}.log" + tmp_dir = Path.home() / "temp_eva_processing" / f"{stem}-{uuid.uuid4().hex[:6]}" tmp_dir.mkdir(parents=True, exist_ok=True) - targets = [ - Path(data_h5).resolve(strict=True), - ] + with log_path.open("a", encoding="utf-8") as log_fh: + tee_out = _Tee(sys.stdout, log_fh) + tee_err = _Tee(sys.stderr, log_fh) + with contextlib.redirect_stdout(tee_out), contextlib.redirect_stderr(tee_err): + print(f"[LOG] {stem}: {log_path}", flush=True) - for t in targets: - if not ensure_path_ready(t): - print(f"[ERR] missing {t}", flush=True) - shutil.rmtree(tmp_dir, ignore_errors=True) - return "", "", -1 - link = tmp_dir / t.name - try: - os.symlink(t, link, target_is_directory=t.is_dir()) - except FileExistsError: - pass + raw_bucket, raw_prefix = _parse_s3_uri(RAW_REMOTE_PREFIX, default_bucket=BUCKET) + raw_root = S3Path(RAW_REMOTE_PREFIX) - ds_parent = Path(out_dir) - ds_parent.mkdir(parents=True, exist_ok=True) - ds_path = ds_parent / dataset_name + rel = hdf5_s3.relative_to(raw_root).as_posix() + t_key = f"{raw_prefix.rstrip('/')}/{rel}".strip("/") + local_hdf5 = tmp_dir / hdf5_s3.name + try: + s3_client.download_file(raw_bucket, t_key, str(local_hdf5)) + except Exception as e: + print(f"[ERR] aws download failed for {data_h5_s3}: {e}", flush=True) + shutil.rmtree(tmp_dir, ignore_errors=True) + return "", "", -1 - try: - print( - f"[INFO] Converting: {stem} → {ds_path} (arm={arm}, extrinsics_key={extrinsics_key})", - flush=True, - ) - lerobot_job( - 
raw_path=str(tmp_dir), - output_dir=str(ds_parent), - dataset_name=dataset_name, - arm=arm, - description=description or "", - extrinsics_key=extrinsics_key, - ) + ds_parent = Path(out_dir) + ds_parent.mkdir(parents=True, exist_ok=True) + ds_path = ds_parent / dataset_name - frames = -1 - info = ds_path / "meta/info.json" - if info.exists(): try: - meta = json.loads(info.read_text()) - frames = int(meta.get("total_frames", -1)) - except Exception: - frames = -1 - - mp4_candidates = list(ds_parent.glob(f"*{stem}*_video.mp4")) + list( - ds_path.glob("**/*_video.mp4") - ) - mp4_str = str(mp4_candidates[0]) if mp4_candidates else "" + print( + f"[INFO] Converting: {stem} → {ds_path} (arm={arm}, extrinsics_key={extrinsics_key})", + flush=True, + ) + job_kwargs = dict( + raw_path=str(tmp_dir), + output_dir=str(ds_parent), + dataset_name=dataset_name, + arm=arm, + description=description or "", + extrinsics_key=extrinsics_key, + ) + if backend == "zarr": + zarr_job(**job_kwargs) + else: + lerobot_job(**job_kwargs) - return str(ds_path), mp4_str, frames + frames = -1 + mp4_str = "" + path_for_sql = str(ds_path) + zarr_store_path = None + if backend == "zarr": + zarr_store_path = ds_parent / f"{stem}.zarr" + info = zarr_store_path / "zarr.json" + print(f"[DEBUG] Zarr metadata path: {info}", flush=True) + if info.exists(): + try: + meta = json.loads(info.read_text()) + print(f"[DEBUG] Zarr metadata keys: {list(meta.keys())}", flush=True) + frames = int(meta.get("attributes", {}).get("total_frames", -1)) + except Exception as e: + print(f"[ERR] Failed to parse zarr metadata {info}: {e}", flush=True) + frames = -1 + else: + print(f"[ERR] Zarr metadata not found: {info}", flush=True) + path_for_sql = f"{PROCESSED_REMOTE_PREFIX}/{stem}.zarr" + else: + info = ds_path / "meta/info.json" + if info.exists(): + try: + meta = json.loads(info.read_text()) + frames = int(meta.get("total_frames", -1)) + except Exception: + frames = -1 + + mp4_candidates = 
list(ds_parent.glob(f"*{stem}*_video.mp4")) + list( + ds_path.glob("**/*_video.mp4") + ) + mp4_str = str(mp4_candidates[0]) if mp4_candidates else "" + + try: + out_bucket, out_prefix = _parse_s3_uri(s3_processed_dir, default_bucket=BUCKET) + if backend == "zarr" and zarr_store_path is not None: + ds_s3_prefix = f"{out_prefix.rstrip('/')}/{stem}.zarr".strip("/") + upload_dir_to_s3(str(zarr_store_path), out_bucket, prefix=ds_s3_prefix) + shutil.rmtree(str(zarr_store_path), ignore_errors=True) + print(f"[CLEANUP] Removed local zarr store: {zarr_store_path}", flush=True) + else: + ds_rel = ds_path.resolve().relative_to(PROCESSED_LOCAL_ROOT).as_posix() + ds_s3_prefix = f"{out_prefix.rstrip('/')}/{ds_rel}".strip("/") + upload_dir_to_s3(str(ds_path), out_bucket, prefix=ds_s3_prefix) + shutil.rmtree(str(ds_path), ignore_errors=True) + print(f"[CLEANUP] Removed local dataset dir: {ds_path}", flush=True) + if mp4_str: + mp4_obj = Path(mp4_str) + if mp4_obj.exists(): + mp4_rel = mp4_obj.resolve().relative_to(PROCESSED_LOCAL_ROOT).as_posix() + mp4_s3_key = f"{out_prefix.rstrip('/')}/{mp4_rel}".strip("/") + s3_client.upload_file(str(mp4_obj), out_bucket, mp4_s3_key) + mp4_obj.unlink(missing_ok=True) + print(f"[CLEANUP] Removed local mp4: {mp4_obj}", flush=True) + except Exception as e: + print(f"[ERR] Failed to upload {ds_path} to S3: {e}", flush=True) + return path_for_sql, mp4_str, -2 + + return path_for_sql, mp4_str, frames + + except Exception as e: + err_msg = f"[FAIL] {stem}: {e}\n{traceback.format_exc()}" + print(err_msg, flush=True) + return path_for_sql, "", -1 + finally: + shutil.rmtree(tmp_dir, ignore_errors=True) + + +@ray.remote(num_cpus=2, resources={"eva_small": 1}) +def convert_one_bundle_small(*args, **kwargs): + return convert_one_bundle_impl(*args, **kwargs) + + +@ray.remote(num_cpus=8, resources={"eva_big": 1}) +def convert_one_bundle_big(*args, **kwargs): + return convert_one_bundle_impl(*args, **kwargs) + + +def launch( + dry: bool = False, + 
skip_if_done: bool = False, + backend: str = "zarr", + episode_hashes: list[str] | None = None, +): + engine = create_default_engine() + pending: Dict[ray.ObjectRef, Dict[str, Any]] = {} + benchmark_rows = [] - except Exception as e: - err_msg = f"[FAIL] {stem}: {e}\n{traceback.format_exc()}" - print(err_msg, flush=True) - return str(ds_path), "", -1 - finally: - shutil.rmtree(tmp_dir, ignore_errors=True) + df = episode_table_to_df(engine) + for hdf5_s3, extrinsics_key in iter_hdf5_bundles_s3(RAW_REMOTE_PREFIX): + name = hdf5_s3.stem + episode_key = _load_episode_key(name) -def launch(dry: bool = False, skip_if_done: bool = False): - engine = create_default_engine() - pending: dict = {} + row_match = df[df["episode_hash"] == episode_key] + if len(row_match) == 1: + row = row_match.iloc[0] + elif len(row_match) > 1: + print("[WARNING] Duplicate episode hash", flush=True) + row = row_match.iloc[0] + else: + row = None - for data_h5, extrinsics_key in iter_hdf5_bundles(RAW_ROOT): - name = data_h5.stem - episode_key = _load_episode_key(name) if not episode_key: - print(f"[SKIP] {name}: could not derive DB episode key") + print(f"[SKIP] {name}: could not derive DB episode key", flush=True) + continue + + if episode_hashes is not None and episode_key not in episode_hashes: + print( + f"[SKIP] {name}: episode_key '{episode_key}' not in provided episode_hashes list", + flush=True, + ) + print(f"[SKIP] {name}: no matching row in SQL (app.episodes)", flush=True) continue - row = episode_hash_to_table_row(engine, episode_key) if row is None: - print(f"[SKIP] {name}: no matching row in SQL (app.episodes)") + print(f"[SKIP] {name}: no matching row in SQL (app.episodes)", flush=True) continue - processed_path = (row.processed_path or "").strip() + if backend == "zarr": + processed_path = (row.zarr_processed_path or "").strip() + processing_error = row.zarr_processing_error + path_field_name = "zarr_processed_path" + else: + processed_path = (row.processed_path or "").strip() + 
processing_error = row.processing_error + path_field_name = "processed_path" + if skip_if_done and len(processed_path) > 0: - print(f"[SKIP] {name}: already has processed_path='{processed_path}'") + print(f"[SKIP] {name}: already has {path_field_name}='{processed_path}'", flush=True) + continue + + if processing_error != "": + print( + f"[INFO] skipping episode hash: {row.episode_hash} due to processing error", + flush=True, + ) continue if row.is_deleted: - print(f"[SKIP] {name}: episode marked as deleted in SQL") + print(f"[SKIP] {name}: episode marked as deleted in SQL", flush=True) continue + print(f"[INFO] processing {name}: episode_key={episode_key}", flush=True) + arm = infer_arm_from_robot_name(getattr(row, "robot_name", None)) - dataset_name = episode_key - out_dir = PROCESSED_ROOT + dataset_name = hdf5_s3.stem + out_dir = PROCESSED_LOCAL_ROOT + s3out_dir = PROCESSED_REMOTE_PREFIX description = row.task_description or "" if dry: - ds_path = (PROCESSED_ROOT / dataset_name).resolve() - mp4_candidate = PROCESSED_ROOT / f"{name}_video.mp4" - mapped_ds = _map_processed_local_to_remote(ds_path) + ds_path = (PROCESSED_LOCAL_ROOT / dataset_name).resolve() + mp4_candidate = PROCESSED_LOCAL_ROOT / f"{name}_video.mp4" + if backend == "zarr": + # Zarr stored flat under eva: prefix/.zarr + mapped_ds = f"{PROCESSED_REMOTE_PREFIX}/{name}.zarr" + else: + mapped_ds = _map_processed_local_to_remote(ds_path) mapped_mp4 = _map_processed_local_to_remote(mp4_candidate) + path_field_name = "zarr_processed_path" if backend == "zarr" else "processed_path" + mp4_field_name = "zarr_mp4_path" if backend == "zarr" else "mp4_path" print( f"[DRY] {name}: arm={arm} | out_dir={out_dir}/{dataset_name}\n" f" desc-bytes={len(description.encode('utf-8'))}\n" f" extrinsics_key={extrinsics_key}\n" f" would write to SQL:\n" - f" processed_path={mapped_ds}\n" - f" mp4_path={mapped_mp4}" + f" {path_field_name}={mapped_ds}\n" + f" {mp4_field_name}={mapped_mp4}", + flush=True, ) continue - ref = 
convert_one_bundle.remote( - str(data_h5), + args_tuple = ( + str(hdf5_s3), str(out_dir), + str(s3out_dir), dataset_name, arm, description, extrinsics_key, + backend, ) - pending[ref] = (episode_key, dataset_name) + + start_time = time.time() + ref = convert_one_bundle_small.remote(*args_tuple) + pending[ref] = { + "episode_key": episode_key, + "dataset_name": dataset_name, + "start_time": start_time, + "size": "small", + "args": args_tuple, + "backend": backend, + } if dry or not pending: return while pending: - done_refs, _ = ray.wait(list(pending), num_returns=1) + done_refs, _ = ray.wait(list(pending.keys()), num_returns=1) ref = done_refs[0] - ds_path, mp4_path, frames = ray.get(ref) - episode_key, _dataset_name = pending.pop(ref) + info = pending.pop(ref) + + episode_key = info["episode_key"] + start_time = info["start_time"] + duration_sec = time.time() - start_time row = episode_hash_to_table_row(engine, episode_key) if row is None: - print(f"[WARN] Episode {episode_key}: row disappeared before update?") + print(f"[WARN] Episode {episode_key}: row disappeared before update?", flush=True) continue - row.num_frames = int(frames) if frames is not None else -1 - if row.num_frames > 0: - row.processed_path = _map_processed_local_to_remote(ds_path) - row.mp4_path = _map_processed_local_to_remote(mp4_path) - row.processing_error = "" - else: - row.processed_path = "" - row.mp4_path = "" - row.processing_error = "Zero Frames" - try: + ds_path, mp4_path, frames = ray.get(ref) + backend = info.get("backend", "zarr") + + row.num_frames = int(frames) if frames is not None else -1 + if backend == "zarr": + mapped_ds = ds_path + else: + mapped_ds = _map_processed_local_to_remote(ds_path) + mapped_mp4 = _map_processed_local_to_remote(mp4_path) + + if backend == "zarr": + if row.num_frames > 0: + row.zarr_processed_path = mapped_ds + row.zarr_mp4_path = mapped_mp4 + row.zarr_processing_error = "" + elif row.num_frames == -2: + row.zarr_processed_path = "" + 
row.zarr_mp4_path = "" + row.zarr_processing_error = "Upload Failed" + elif row.num_frames == -1: + row.zarr_processed_path = "" + row.zarr_mp4_path = "" + row.zarr_processing_error = "Zero Frames" + else: + row.zarr_processed_path = "" + row.zarr_mp4_path = "" + row.zarr_processing_error = "Conversion Failed Unhandled Error" + path_value = row.zarr_processed_path + else: + if row.num_frames > 0: + row.processed_path = mapped_ds + row.mp4_path = mapped_mp4 + row.processing_error = "" + elif row.num_frames == -2: + row.processed_path = "" + row.mp4_path = "" + row.processing_error = "Upload Failed" + elif row.num_frames == -1: + row.processed_path = "" + row.mp4_path = "" + row.processing_error = "Zero Frames" + else: + row.processed_path = "" + row.mp4_path = "" + row.processing_error = "Conversion Failed Unhandled Error" + path_value = row.processed_path + update_episode(engine, row) + path_field_name = "zarr_processed_path" if backend == "zarr" else "processed_path" print( f"[OK] Updated SQL for {episode_key}: " - f"processed_path={row.processed_path}, num_frames={row.num_frames}" + f"{path_field_name}={path_value}, num_frames={row.num_frames}, " + f"duration_sec={duration_sec:.2f}", + flush=True, + ) + + if row.num_frames > 0 and path_value: + mp4_val = row.zarr_mp4_path if backend == "zarr" else row.mp4_path + benchmark_rows.append( + { + "episode_key": episode_key, + "processed_path": path_value, + "mp4_path": mp4_val, + "num_frames": row.num_frames, + "duration_sec": duration_sec, + } + ) + + except Exception as e: + backend = info.get("backend", "zarr") + if _is_oom_exception(e) and info.get("size") == "small": + print( + f"[OOM] Episode {episode_key} failed on SMALL. 
Retrying on BIG...", + flush=True, + ) + args_tuple = info["args"] + ref2 = convert_one_bundle_big.remote(*args_tuple) + pending[ref2] = { + **info, + "start_time": time.time(), + "size": "big", + } + continue + + print( + f"[FAIL] Episode {episode_key} task failed ({info.get('size', '?')}): " + f"{type(e).__name__}: {e}", + flush=True, + ) + + row.num_frames = -1 + error_msg = f"{type(e).__name__}: {e}" + path_field_name = "zarr_processed_path" if backend == "zarr" else "processed_path" + + if backend == "zarr": + row.zarr_mp4_path = "" + row.zarr_processed_path = "" + row.zarr_processing_error = error_msg + else: + row.mp4_path = "" + row.processed_path = "" + row.processing_error = error_msg + + try: + update_episode(engine, row) + print( + f"[FAIL] Marked SQL failed for {episode_key} (cleared {path_field_name})", + flush=True, + ) + except Exception as ee: + print(f"[ERR] SQL update failed for failed episode {episode_key}: {ee}", flush=True) + + if benchmark_rows: + timing_file = Path("./eva_conversion_timings.csv") + file_exists = timing_file.exists() + fieldnames = [ + "episode_key", + "processed_path", + "mp4_path", + "num_frames", + "duration_sec", + ] + try: + with timing_file.open("a", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + if not file_exists: + writer.writeheader() + for bench_row in benchmark_rows: + writer.writerow(bench_row) + print( + f"[BENCH] wrote {len(benchmark_rows)} entries → {timing_file.resolve()}", + flush=True, ) except Exception as e: - print(f"[ERR] SQL update failed for {episode_key}: {e}") + print(f"[ERR] Failed to write benchmark CSV {timing_file}: {e}", flush=True) def main(): @@ -274,10 +598,59 @@ def main(): action="store_true", help="Skip episodes that already have a processed_path in SQL", ) + p.add_argument( + "--backend", + type=str, + choices=["lerobot", "zarr"], + default="zarr", + help="Output backend: 'zarr' (default) or 'lerobot'", + ) + p.add_argument( + "--ray-address", default="auto", 
help="Ray cluster address (default: auto)" + ) + p.add_argument( + "--episode-hash", + action="append", + dest="episode_hashes", + help="Episode hash to process. Can be specified multiple times to process multiple episodes.", + ) + p.add_argument("--debug", action="store_true") args = p.parse_args() - ray.init(address="auto") - launch(dry=args.dry_run, skip_if_done=args.skip_if_done) + env_vars = {} + load_env() + for k in [ + "R2_ACCESS_KEY_ID", + "R2_SECRET_ACCESS_KEY", + "R2_SESSION_TOKEN", + "R2_ENDPOINT_URL", + ]: + v = os.environ.get(k) + if v: + env_vars[k] = v + + if args.debug: + runtime_env = { + "working_dir": "/home/ubuntu/EgoVerse", + "excludes": [ + "**/.git/**", + "external/openpi/third_party/aloha/**", + "**/*.pack", + "**/__pycache__/**", + "external/openpi/**", + ], + } + else: + runtime_env = {} + runtime_env["env_vars"] = env_vars + + ray.init(address=args.ray_address, runtime_env=runtime_env) + launch( + dry=args.dry_run, + skip_if_done=args.skip_if_done, + backend=args.backend, + episode_hashes=args.episode_hashes, + ) if __name__ == "__main__": diff --git a/egomimic/scripts/eva_process/test_zarr_path_update.py b/egomimic/scripts/eva_process/test_zarr_path_update.py new file mode 100644 index 00000000..e69de29b diff --git a/egomimic/scripts/eva_process/zarr_utils.py b/egomimic/scripts/eva_process/zarr_utils.py index 6154c7f4..bd481dd9 100644 --- a/egomimic/scripts/eva_process/zarr_utils.py +++ b/egomimic/scripts/eva_process/zarr_utils.py @@ -79,7 +79,6 @@ def assert_fk_matches_eepose( skip_first: int = 1, zero_eps: float = 1e-12, ): - # prestack handling base0 = ( actions_base_cartesian[:, 0] if actions_base_cartesian.ndim == 3 @@ -425,19 +424,6 @@ def build_future_windows(arr: np.ndarray, horizon: int) -> np.ndarray: return out -def prestack_with_mode( - arr: np.ndarray, - horizon: int, - chunk_length: int, - mode: str = "linear", -) -> np.ndarray: - windows = build_future_windows(arr, horizon) - if mode == "euler": - return 
interpolate_arr_euler(windows, chunk_length) - else: - return interpolate_arr(windows, chunk_length) - - def joint_to_pose(pose, arm, left_extrinsics=None, right_extrinsics=None, no_rot=False): """ pose: (T, ACTION_DIM) @@ -570,7 +556,7 @@ def joint_to_pose(pose, arm, left_extrinsics=None, right_extrinsics=None, no_rot fk_positions = np.concatenate([fk_positions, fk_ypr, gripper], axis=1) - return fk_positions, base_fk_positions + return fk_positions, base_fk_positions.astype(np.float32) class EvaHD5Extractor: @@ -578,7 +564,7 @@ class EvaHD5Extractor: @staticmethod def process_episode( - episode_path, arm, extrinsics, prestack=False, low_res=False, no_rot=False + episode_path, arm, extrinsics, low_res=False, no_rot=False ): """ Extracts all feature keys from a given episode and returns as a dictionary @@ -590,8 +576,6 @@ def process_episode( String for which arm to add data for extrinsics : np.array camera extrinsic, It is a tuple of (left_extrinsics, right_extrinsics) if arm is both - prestack : bool - prestack the future actions or not Returns ------- episode_feats : dict @@ -680,17 +664,12 @@ def process_episode( ) # actions - joint_actions, cartesian_actions, base_cartesian_actions = ( - EvaHD5Extractor.get_action( - episode["action"][:], - arm=arm, - prestack=prestack, - HORIZON=HORIZON_BASE, - CHUNK_LENGTH=CHUNK_LENGTH_BASE, - left_extrinsics=left_extrinsics, - right_extrinsics=right_extrinsics, - no_rot=no_rot, - ) + joint_actions, cartesian_actions, base_cartesian_actions = EvaHD5Extractor.get_action( + episode["action"][:], + arm=arm, + left_extrinsics=left_extrinsics, + right_extrinsics=right_extrinsics, + no_rot=no_rot, ) # dbg = print_fk_eepose_diffs( # base_cartesian_actions, @@ -709,13 +688,7 @@ def process_episode( episode_feats["actions_cartesian"] = cartesian_actions episode_feats["actions_base_cartesian"] = base_cartesian_actions - episode_feats["actions_eef_cartesian"] = EvaHD5Extractor.get_eef_action( - 
actions_cartesian_base=base_cartesian_actions, - arm=arm, - ref_index=0, - ) - - episode_feats["observations"]["state.joint_positions"] = episode_feats[ + episode_feats["observations"][f"state.joint_positions"] = episode_feats[ "observations" ]["state.joint_positions"][:, joint_start:joint_end] @@ -743,9 +716,6 @@ def process_episode( def get_action( actions: np.array, arm: str, - prestack: bool = False, - HORIZON: int = HORIZON_BASE, - CHUNK_LENGTH: int = CHUNK_LENGTH_BASE, left_extrinsics=None, right_extrinsics=None, no_rot: bool = False, @@ -772,148 +742,8 @@ def get_action( joint_actions = joint_actions[:, joint_start:joint_end] - if prestack: - horizon = HORIZON - - joint_actions = prestack_with_mode( - np.asarray(joint_actions), - horizon=horizon, - chunk_length=CHUNK_LENGTH, - mode="linear", - ) - - cart_np = np.asarray(cartesian_actions) - base_np = np.asarray(base_cartesian_actions) - - if arm == "both": - left_base = base_np[:, :7] - right_base = base_np[:, 7:14] - left_cam = cart_np[:, :7] - right_cam = cart_np[:, 7:14] - - left_base = prestack_with_mode( - left_base, - horizon=horizon, - chunk_length=CHUNK_LENGTH, - mode="euler", - ) - right_base = prestack_with_mode( - right_base, - horizon=horizon, - chunk_length=CHUNK_LENGTH, - mode="euler", - ) - left_cam = prestack_with_mode( - left_cam, - horizon=horizon, - chunk_length=CHUNK_LENGTH, - mode="euler", - ) - right_cam = prestack_with_mode( - right_cam, - horizon=horizon, - chunk_length=CHUNK_LENGTH, - mode="euler", - ) - - base_cartesian_actions = np.concatenate( - [left_base, right_base], axis=-1 - ) - cartesian_actions = np.concatenate([left_cam, right_cam], axis=-1) - else: - base_cartesian_actions = prestack_with_mode( - base_np, - horizon=horizon, - chunk_length=CHUNK_LENGTH, - mode="euler", - ) - cartesian_actions = prestack_with_mode( - cart_np, - horizon=horizon, - chunk_length=CHUNK_LENGTH, - mode="euler", - ) - return (joint_actions, cartesian_actions, base_cartesian_actions) - 
@staticmethod - def get_eef_action( - actions_cartesian_base: np.ndarray, - arm: str, - ref_index: int = 0, - ) -> np.ndarray: - """Compute relative EEF actions from base-frame Cartesian actions. - - Supports both prestacked (N, S, D) and non-prestacked (N, D) inputs. - For prestacked: each row has S future actions, relative to ref_index action. - For non-prestacked: computes relative pose from previous frame (or identity for first). - - Args: - actions_cartesian_base: Shape (N, S, D) or (N, D) where D=7 per arm (xyz + ypr + gripper) - arm: "left", "right", or "both" - ref_index: Reference index for relative transform (only used for prestacked) - - Returns: - Relative actions with same shape as input - """ - if arm == "both": - left_base = actions_cartesian_base[..., :7] - right_base = actions_cartesian_base[..., 7:14] - - left_rel = EvaHD5Extractor.get_eef_action( - left_base, arm="left", ref_index=ref_index - ) - right_rel = EvaHD5Extractor.get_eef_action( - right_base, arm="right", ref_index=ref_index - ) - return np.concatenate([left_rel, right_rel], axis=-1) - - # Handle both 2D (N, D) and 3D (N, S, D) inputs - is_2d = actions_cartesian_base.ndim == 2 - if is_2d: - # Add sequence dimension: (N, D) -> (N, 1, D) - actions_cartesian_base = actions_cartesian_base[:, np.newaxis, :] - - N, S, D = actions_cartesian_base.shape - - p = actions_cartesian_base[..., :3] # (N, S, 3) - ypr = actions_cartesian_base[..., 3:6] # (N, S, 3) - g = actions_cartesian_base[..., 6:7] # (N, S, 1) - - ypr_flat = ypr.reshape(-1, 3) # (N*S, 3) - R_flat = R.from_euler("ZYX", ypr_flat, degrees=False).as_matrix() # (N*S, 3, 3) - R_seq = R_flat.reshape(N, S, 3, 3) # (N, S, 3, 3) - - T_seq = np.zeros((N, S, 4, 4), dtype=np.float32) - T_seq[..., :3, :3] = R_seq - T_seq[..., :3, 3] = p - T_seq[..., 3, 3] = 1.0 - - T0 = T_seq[:, ref_index, :, :] # (N, 4, 4) - T0_inv = np.linalg.inv(T0) # (N, 4, 4) - - T_rel = T0_inv[:, None, :, :] @ T_seq # (N, S, 4, 4) - - p_rel = T_rel[..., :3, 3] # (N, S, 3) 
- R_rel = T_rel[..., :3, :3] # (N, S, 3, 3) - - R_rel_flat = R_rel.reshape(-1, 3, 3) # (N*S, 3, 3) - ypr_rel_flat = R.from_matrix(R_rel_flat).as_euler( - "ZYX", degrees=False - ) # (N*S, 3) - ypr_rel = ypr_rel_flat.reshape(N, S, 3) # (N, S, 3) - - actions_rel = np.empty_like(actions_cartesian_base) - actions_rel[..., :3] = p_rel - actions_rel[..., 3:6] = ypr_rel - actions_rel[..., 6:7] = g # keep gripper as-is - - # Remove sequence dimension if input was 2D - if is_2d: - actions_rel = actions_rel[:, 0, :] # (N, 1, D) -> (N, D) - - return actions_rel - @staticmethod def get_ee_pose( qpos: np.array, @@ -1065,7 +895,6 @@ def extract_episode_frames( image_compressed: bool, arm: str, extrinsics: dict, - prestack: bool = False, ) -> list[dict[str, torch.Tensor]]: """ Extract frames from an episode by processing it and using the feature dictionary. @@ -1082,8 +911,6 @@ def extract_episode_frames( The arm to process (e.g., 'left', 'right', or 'both'). extrinsics : dict Camera extrinsics for the episode. - prestack : bool, optional - Whether to precompute action chunks, by default False. 
Returns ------- @@ -1092,7 +919,7 @@ def extract_episode_frames( """ frames = [] episode_feats = EvaHD5Extractor.process_episode( - episode_path, arm=arm, extrinsics=extrinsics, prestack=prestack + episode_path, arm=arm, extrinsics=extrinsics ) num_frames = next(iter(episode_feats["observations"].values())).shape[0] for frame_idx in range(num_frames): @@ -1183,8 +1010,7 @@ def define_features( dim_names = ["channel", "height", "width"] elif "actions" in key and len(value[0].shape) > 1: shape = value[0].shape - dim_names = ["chunk_length", "action_dim"] - dtype = f"prestacked_{str(value.dtype)}" + dim_names = ["action_dim"] else: shape = value[0].shape dim_names = [f"dim_{i}" for i in range(len(shape))] @@ -1197,8 +1023,7 @@ def define_features( dtype = str(value.dtype) shape = tuple(value[0].size()) if "actions" in key and len(tuple(value[0].size())) > 1: - dim_names = ["chunk_length", "action_dim"] - dtype = f"prestacked_{str(value.dtype)}" + dim_names = ["action_dim"] else: dim_names = [f"dim_{i}" for i in range(len(shape))] dim_names = [f"dim_{i}" for i in range(len(shape))] diff --git a/egomimic/scripts/load_episode.sh b/egomimic/scripts/load_episode.sh new file mode 100755 index 00000000..b405b8bc --- /dev/null +++ b/egomimic/scripts/load_episode.sh @@ -0,0 +1,23 @@ +set -a +source ~/.egoverse_env +set +a + +export AWS_ACCESS_KEY_ID="$R2_ACCESS_KEY_ID" +export AWS_SECRET_ACCESS_KEY="$R2_SECRET_ACCESS_KEY" +export AWS_DEFAULT_REGION="auto" +export AWS_REGION="auto" + +ID=1764285228498 +path="/home/ubuntu/download_aria" + +# s5cmd --endpoint-url "$R2_ENDPOINT_URL" cp "s3://rldb/raw_v2/aria/${ID}.json" $path +# s5cmd --endpoint-url "$R2_ENDPOINT_URL" cp "s3://rldb/raw_v2/aria/${ID}.vrs" $path +# s5cmd --endpoint-url "$R2_ENDPOINT_URL" cp "s3://rldb/raw_v2/aria/${ID}_metadata.json" $path + +s5cmd --endpoint-url "$R2_ENDPOINT_URL" sync \ + "s3://rldb/processed_v3/proc_test_aria/${ID}.zarr/**" \ + "${path}/${ID}.zarr/" + +# s5cmd --endpoint-url "$R2_ENDPOINT_URL" 
sync \ +# "s3://rldb/processed_v3/test_eva/1764215784190.zarr/**" \ +# "${path}/1764215784190.zarr/" \ No newline at end of file diff --git a/egomimic/utils/aws/aws_data_utils.py b/egomimic/utils/aws/aws_data_utils.py index e1318c82..981c6820 100644 --- a/egomimic/utils/aws/aws_data_utils.py +++ b/egomimic/utils/aws/aws_data_utils.py @@ -1,11 +1,71 @@ from __future__ import annotations +import cloudpathlib +from pathlib import Path import os from pathlib import Path import boto3 from boto3.s3.transfer import TransferConfig +from botocore.config import Config +def load_env(path="/home/ubuntu/.egoverse_env"): + p = Path(path) + if not p.exists(): + return + for line in p.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + k, v = line.split("=", 1) + os.environ.setdefault(k, v.strip().strip("'").strip('"')) + +def get_cloudpathlib_s3_client(): + load_env() + endpoint_url = "https://1beb594fb475d71c4420f7b693524e19.r2.cloudflarestorage.com" + r2_access_key_id = os.environ.get("R2_ACCESS_KEY_ID") or os.environ.get( + "AWS_ACCESS_KEY_ID" + ) + r2_secret_access_key = os.environ.get("R2_SECRET_ACCESS_KEY") or os.environ.get( + "AWS_SECRET_ACCESS_KEY" + ) + r2_session_token = os.environ.get("R2_SESSION_TOKEN") or os.environ.get( + "AWS_SESSION_TOKEN" + ) + s3_boto3_session = boto3.session.Session( + region_name="auto", + aws_access_key_id=r2_access_key_id, + aws_secret_access_key=r2_secret_access_key, + aws_session_token=r2_session_token, + ) + + s3_client = cloudpathlib.S3Client( + endpoint_url=endpoint_url, + boto3_session=s3_boto3_session, + ) + for key in ( + "AWS_ACCESS_KEY_ID", + "AWS_SECRET_ACCESS_KEY", + "AWS_SESSION_TOKEN", + "AWS_SECURITY_TOKEN", + ): + os.environ.pop(key, None) + return s3_client + +def get_boto3_s3_client(): + load_env() + endpoint_url = "https://1beb594fb475d71c4420f7b693524e19.r2.cloudflarestorage.com" + access_key_id = os.environ["R2_ACCESS_KEY_ID"] + secret_access_key = 
os.environ["R2_SECRET_ACCESS_KEY"] + s3 = boto3.client( + "s3", + endpoint_url=endpoint_url, + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + region_name="auto", # R2 ignores region; "auto" is common + config=Config(signature_version="s3v4"), + ) + return s3 def load_env(path="~/.egoverse_env"): p = Path(path).expanduser() @@ -35,7 +95,7 @@ def s3_sync_to_local(bucket: str, key_prefix: str, local_dir: str | Path) -> Non local_dir.mkdir(parents=True, exist_ok=True) # Create the client inside the function so this works cleanly on Ray workers. - s3 = boto3.client("s3") + s3 = get_boto3_s3_client() config = TransferConfig( max_concurrency=16, @@ -60,11 +120,8 @@ def s3_sync_to_local(bucket: str, key_prefix: str, local_dir: str | Path) -> Non s3.download_file(bucket, key, str(dest), Config=config) - -def upload_dir_to_s3( - local_dir: str, bucket: str, prefix: str = "", concurrency: int = 32 -): - s3 = boto3.client("s3") +def upload_dir_to_s3(local_dir: str, bucket: str, prefix: str = "", concurrency: int = 32): + s3 = get_boto3_s3_client() cfg = TransferConfig( max_concurrency=concurrency, multipart_threshold=64 * 1024 * 1024, # 64MB @@ -82,3 +139,4 @@ def upload_dir_to_s3( rel = lp.relative_to(local_dir).as_posix() key = f"{prefix}/{rel}" if prefix else rel s3.upload_file(str(lp), bucket, key, Config=cfg) + print(f"Uploaded files to S3 {prefix}") diff --git a/egomimic/utils/aws/sql_tutorial.ipynb b/egomimic/utils/aws/sql_tutorial.ipynb index 92211951..99a37616 100644 --- a/egomimic/utils/aws/sql_tutorial.ipynb +++ b/egomimic/utils/aws/sql_tutorial.ipynb @@ -1,202 +1,244 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": 4, - "id": "75093106", - "metadata": {}, - "outputs": [], - "source": [ - "from sqlalchemy import (\n", - " Boolean,\n", - " Column,\n", - " Float,\n", - " Integer,\n", - " MetaData,\n", - " String,\n", - " Table,\n", - " text,\n", - " update,\n", - ")\n", - "\n", - "from egomimic.utils.aws.aws_sql import 
(\n", - " TableRow,\n", - " create_default_engine,\n", - " episode_table_to_df,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "1bc257dd", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Tables in schema 'app': ['episodes']\n" - ] - } - ], - "source": [ - "engine = create_default_engine()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a40642da", - "metadata": {}, - "outputs": [], - "source": [ - "df = episode_table_to_df(engine)\n", - "df" - ] - }, - { - "cell_type": "markdown", - "id": "21ee3b2d", - "metadata": {}, - "source": [ - "## Example Useful Functions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "55b01865", - "metadata": {}, - "outputs": [], - "source": [ - "# Add Episode Test\n", - "episode = TableRow(\n", - " episode_hash=1761408819,\n", - " operator=\"test\",\n", - " lab=\"lab-x\",\n", - " num_frames=-1,\n", - " task=\"bimanual_test\",\n", - " task_description=\"Dummy row for table inspection\",\n", - " scene=\"kitchen\",\n", - " objects=\"cup,plate,spoon\",\n", - " processed_path=\"\",\n", - " zarr_processed_path=\"\",\n", - " zarr_processing_error=\"\",\n", - " zarr_mp4_path=\"\",\n", - " mp4_path=\"\",\n", - " embodiment=\"aria\",\n", - " robot_name=\"\",\n", - " is_eval=False,\n", - " eval_score=0.0,\n", - " eval_success=False,\n", - ")\n", - "# Adding an episode\n", - "# add_episode(engine, episode)\n", - "\n", - "# Update Episode Test\n", - "# episode.operator = \"simar\"\n", - "# update_episode(engine, episode)\n", - "\n", - "# Get Table Row from Episode Hash\n", - "# episode_hash_to_table_row(engine, 123456)\n", - "\n", - "# Delete episodes by hash\n", - "# delete_episodes(engine, [123456, 123457])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fdad1d0f", - "metadata": {}, - "outputs": [], - "source": [ - "def drop_table(table_name):\n", - " with engine.connect() as connection:\n", - " 
connection.execute(text(f\"DROP TABLE IF EXISTS app.{table_name} CASCADE;\"))\n", - " connection.commit()\n", - " print(f\"Dropped table '{table_name}' from schema 'app' if it existed.\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "770a6964", - "metadata": {}, - "outputs": [], - "source": [ - "def create_table():\n", - " metadata = MetaData(schema=\"app\")\n", - "\n", - " Table(\n", - " \"episodes\",\n", - " metadata,\n", - " Column(\"episode_hash\", String, primary_key=True),\n", - " Column(\"operator\", String),\n", - " Column(\"lab\", String),\n", - " Column(\"num_frames\", Integer),\n", - " Column(\"task\", String),\n", - " Column(\"task_description\", String),\n", - " Column(\"scene\", String),\n", - " Column(\n", - " \"objects\", String\n", - " ), # Store as JSON or comma-separated list of object names\n", - " Column(\"processed_path\", String),\n", - " Column(\"zarr_processed_path\", String),\n", - " Column(\"zarr_processing_error\", String),\n", - " Column(\"zarr_mp4_path\", String),\n", - " Column(\"mp4_path\", String),\n", - " Column(\"is_deleted\", Boolean),\n", - " Column(\"embodiment\", String),\n", - " Column(\"robot_name\", String),\n", - " Column(\"is_eval\", Boolean),\n", - " Column(\"eval_score\", Float),\n", - " Column(\"eval_success\", Boolean),\n", - " )\n", - "\n", - " metadata.create_all(engine)\n", - " print(\"Created table 'episodes' in schema 'app'.\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cbb7d2a8", - "metadata": {}, - "outputs": [], - "source": [ - "def delete_episodes_by_task(task_name: str):\n", - " episodes_tbl = Table(\"episodes\", MetaData(), autoload_with=engine, schema=\"app\")\n", - " stmt = (\n", - " update(episodes_tbl)\n", - " .where(episodes_tbl.c.task == task_name)\n", - " .values(is_deleted=True)\n", - " )\n", - " with engine.begin() as conn:\n", - " conn.execute(stmt)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "emimic (3.11.14)", - 
"language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.14" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "75093106", + "metadata": {}, + "outputs": [], + "source": [ + "from sqlalchemy import (\n", + " MetaData,\n", + " Table,\n", + " Column,\n", + " Integer,\n", + " String,\n", + " Boolean,\n", + " Float,\n", + " text,\n", + " update,\n", + " select\n", + ")\n", + "from egomimic.utils.aws.aws_sql import (\n", + " TableRow,\n", + " add_episode,\n", + " update_episode,\n", + " create_default_engine,\n", + " episode_hash_to_table_row,\n", + " delete_episodes,\n", + " episode_table_to_df,\n", + " delete_all_episodes,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1bc257dd", + "metadata": {}, + "outputs": [], + "source": [ + "engine = create_default_engine()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a40642da", + "metadata": {}, + "outputs": [], + "source": [ + "df = episode_table_to_df(engine)\n", + "df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c73f9b8", + "metadata": {}, + "outputs": [], + "source": [ + "from egomimic.utils.aws.aws_sql import timestamp_ms_to_episode_hash\n", + "hash = timestamp_ms_to_episode_hash(1764285228498)\n", + "row = episode_hash_to_table_row(engine, hash)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "670c0977", + "metadata": {}, + "outputs": [], + "source": [ + "print(row.zarr_mp4_path)\n", + "print(row.zarr_processed_path)\n", + "print(row.zarr_processing_error)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2e734cd4", + "metadata": {}, + "outputs": [], + "source": [ + 
"row.description" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "440cc850", + "metadata": {}, + "outputs": [], + "source": [ + "row.zarr_processing_error = ''\n", + "update_episode(engine, row)\n" + ] + }, + { + "cell_type": "markdown", + "id": "21ee3b2d", + "metadata": {}, + "source": [ + "## Example Useful Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "55b01865", + "metadata": {}, + "outputs": [], + "source": [ + "# Add Episode Test\n", + "episode = TableRow(\n", + " episode_hash=1761408819,\n", + " operator=\"test\",\n", + " lab=\"lab-x\",\n", + " num_frames=-1,\n", + " task=\"bimanual_test\",\n", + " task_description=\"Dummy row for table inspection\",\n", + " scene=\"kitchen\",\n", + " objects=\"cup,plate,spoon\",\n", + " processed_path=\"\",\n", + " zarr_processed_path=\"\",\n", + " zarr_processing_error=\"\",\n", + " zarr_mp4_path=\"\",\n", + " mp4_path=\"\",\n", + " embodiment=\"aria\",\n", + " robot_name=\"\",\n", + " is_eval=False,\n", + " eval_score=0.0,\n", + " eval_success=False,\n", + ")\n", + "# Adding an episode\n", + "# add_episode(engine, episode)\n", + "\n", + "# Update Episode Test\n", + "# episode.operator = \"simar\"\n", + "# update_episode(engine, episode)\n", + "\n", + "# Get Table Row from Episode Hash\n", + "# episode_hash_to_table_row(engine, 123456)\n", + "\n", + "# Delete episodes by hash\n", + "# delete_episodes(engine, [123456, 123457])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdad1d0f", + "metadata": {}, + "outputs": [], + "source": [ + "def drop_table(table_name):\n", + " with engine.connect() as connection:\n", + " connection.execute(text(f\"DROP TABLE IF EXISTS app.{table_name} CASCADE;\"))\n", + " connection.commit()\n", + " print(f\"Dropped table '{table_name}' from schema 'app' if it existed.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "770a6964", + "metadata": {}, + "outputs": [], + "source": [ + "def 
create_table():\n", + " metadata = MetaData(schema=\"app\")\n", + "\n", + " Table(\n", + " \"episodes\",\n", + " metadata,\n", + " Column(\"episode_hash\", String, primary_key=True),\n", + " Column(\"operator\", String),\n", + " Column(\"lab\", String),\n", + " Column(\"num_frames\", Integer),\n", + " Column(\"task\", String),\n", + " Column(\"task_description\", String),\n", + " Column(\"scene\", String),\n", + " Column(\n", + " \"objects\", String\n", + " ), # Store as JSON or comma-separated list of object names\n", + " Column(\"processed_path\", String),\n", + " Column(\"zarr_processed_path\", String),\n", + " Column(\"zarr_processing_error\", String),\n", + " Column(\"zarr_mp4_path\", String),\n", + " Column(\"mp4_path\", String),\n", + " Column(\"is_deleted\", Boolean),\n", + " Column(\"embodiment\", String),\n", + " Column(\"robot_name\", String),\n", + " Column(\"is_eval\", Boolean),\n", + " Column(\"eval_score\", Float),\n", + " Column(\"eval_success\", Boolean),\n", + " )\n", + "\n", + " metadata.create_all(engine)\n", + " print(\"Created table 'episodes' in schema 'app'.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cbb7d2a8", + "metadata": {}, + "outputs": [], + "source": [ + "def delete_episodes_by_task(task_name: str):\n", + " episodes_tbl = Table(\"episodes\", MetaData(), autoload_with=engine, schema=\"app\")\n", + " stmt = (\n", + " update(episodes_tbl)\n", + " .where(episodes_tbl.c.task == task_name)\n", + " .values(is_deleted=True)\n", + " )\n", + " with engine.begin() as conn:\n", + " conn.execute(stmt)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "emimic", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.14" + } + }, + "nbformat": 4, + 
"nbformat_minor": 5 } diff --git a/egomimic/utils/egomimicUtils.py b/egomimic/utils/egomimicUtils.py index d5f07c2c..670c3c43 100644 --- a/egomimic/utils/egomimicUtils.py +++ b/egomimic/utils/egomimicUtils.py @@ -1,5 +1,12 @@ -import argparse -import math +import subprocess +import numpy as np +import cv2 +import matplotlib.pyplot as plt +import torchvision.transforms.functional as TF +import torch +from scipy.spatial.transform import Rotation +import pytorch_kinematics as pk +import egomimic import os from numbers import Number from pathlib import Path @@ -545,10 +552,269 @@ def draw_rotation_text( return image +def render_3d_traj_frames( + trajs, + labels=None, + idx=0, + stride=1, + mode="time", # "time" or "rotate" + elev=20, + azim_start=0, + azim_end=360, + tail=None, + equal_axes=True, + figsize=(6, 6), + dpi=150, +): + """ + Render frames (as uint8 RGB images) for multiple 3D trajectories. + + Args: + trajs: list of array-like, each shape (N, T, 3) + labels: list[str] legend labels, same length as trajs + idx: which chunk to render + stride: timestep stride + mode: "time" animates over timesteps; "rotate" rotates camera around full traj + elev/azim_start/azim_end: camera params + tail: in "time" mode, show only last `tail` points (int) or full history (None) + equal_axes: lock x/y/z ranges to be equal-ish + figsize, dpi: output frame size -def draw_actions( - im, type, color, actions, extrinsics, intrinsics, arm="both", kinematics_solver=None + Returns: + frames: list[np.ndarray] each shape (H, W, 3), dtype uint8 (RGB) + """ + if not isinstance(trajs, (list, tuple)) or len(trajs) == 0: + raise ValueError("trajs must be a non-empty list of arrays shaped (N, T, 3).") + + trajs_np = [np.asarray(a) for a in trajs] + base = trajs_np[0] + if base.ndim != 3 or base.shape[-1] != 3: + raise ValueError(f"Expected (N, T, 3). 
Got {base.shape}") + + n, t, _ = base.shape + for k, a in enumerate(trajs_np): + if a.shape != base.shape: + raise ValueError(f"All trajs must share the same shape. traj0={base.shape}, traj{k}={a.shape}") + + if not (0 <= idx < n): + raise IndexError(f"idx {idx} out of range for N={n}") + + if labels is None: + labels = [f"traj{k}" for k in range(len(trajs_np))] + if len(labels) != len(trajs_np): + raise ValueError("labels must have the same length as trajs.") + + if mode not in ("time", "rotate"): + raise ValueError('mode must be "time" or "rotate".') + + ts = np.arange(0, t, stride) + t_len = len(ts) + + fig = plt.figure(figsize=figsize, dpi=dpi) + ax = fig.add_subplot(111, projection="3d") + ax.set_xlabel("x") + ax.set_ylabel("y") + ax.set_zlabel("z") + ax.view_init(elev=elev, azim=azim_start) + + # Fixed axis limits (prevents jitter) + if equal_axes: + all_pts = np.concatenate([a[idx, ts] for a in trajs_np], axis=0) + mins = all_pts.min(axis=0) + maxs = all_pts.max(axis=0) + center = (mins + maxs) / 2.0 + span = (maxs - mins).max() + half = span / 2.0 if span > 0 else 1.0 + ax.set_xlim(center[0] - half, center[0] + half) + ax.set_ylim(center[1] - half, center[1] + half) + ax.set_zlim(center[2] - half, center[2] + half) + + # Create lines + lines = [] + for lab in labels: + (ln,) = ax.plot([], [], [], label=lab) + lines.append(ln) + ax.legend() + + # Precompute trajectories for rotate mode + if mode == "rotate": + full_xyz = [a[idx, ts] for a in trajs_np] + for ln, xyz in zip(lines, full_xyz): + x, y, z = xyz.T + ln.set_data(x, y) + ln.set_3d_properties(z) + + n_frames = int(abs(azim_end - azim_start)) + 1 if azim_end != azim_start else 360 + azims = np.linspace(azim_start, azim_end, n_frames) + + def draw_frame(fi): + ax.view_init(elev=elev, azim=float(azims[fi])) + + else: # time mode + def draw_frame(fi): + end = fi + 1 + start = 0 if tail is None else max(0, end - int(tail)) + for ln, a in zip(lines, trajs_np): + seg = a[idx, ts[start:end]] + x, y, z = seg.T 
+ ln.set_data(x, y) + ln.set_3d_properties(z) + + # Render frames into RGB arrays + frames = [] + canvas = fig.canvas + for fi in range((len(np.linspace(azim_start, azim_end, int(abs(azim_end - azim_start)) + 1)) if mode == "rotate" else t_len)): + draw_frame(fi) + canvas.draw() + rgba = np.asarray(canvas.buffer_rgba()) + rgb = rgba[..., :3].copy() + frames.append(rgb) + + plt.close(fig) + return frames + +def render_3d_traj_frames_NT3( + trajs, + labels=None, + stride=1, + mode="time", # "time" or "rotate" + elev=20, + azim_start=0, + azim_end=360, + tail=None, + equal_axes=True, + figsize=(6, 6), + dpi=150, ): + """ + Render frames (uint8 RGB) for multiple 3D trajectories where each traj is shaped (N, T, 3): + - N: number of frames (global time) + - T: action chunk length (per-frame mini-trajectory) + - 3: xyz + + Semantics: + - In "time" mode: frame i shows the chunk traj[i, ...] (optionally with tail over i). + - In "rotate" mode: renders a single camera rotation using the chunk at the last frame (or + an aggregated chunk if tail is set). + + Args: + trajs: list of array-like, each shape (N, T, 3) + labels: list[str], same length as trajs + stride: stride over N (frames) + mode: "time" or "rotate" + elev, azim_start, azim_end: camera params + tail: if set in "time" mode, overlay last `tail` chunks (over N) on each frame + equal_axes: lock x/y/z ranges + figsize, dpi: output frame size + + Returns: + frames: list[np.ndarray] each (H, W, 3) uint8 + """ + if not isinstance(trajs, (list, tuple)) or len(trajs) == 0: + raise ValueError("trajs must be a non-empty list of arrays shaped (N, T, 3).") + + trajs_np = [np.asarray(a) for a in trajs] + base = trajs_np[0] + if base.ndim != 3 or base.shape[-1] != 3: + raise ValueError(f"Expected (N, T, 3). Got {base.shape}") + + n_frames, t_chunk, _ = base.shape + for k, a in enumerate(trajs_np): + if a.shape != base.shape: + raise ValueError(f"All trajs must share shape. 
traj0={base.shape}, traj{k}={a.shape}") + + if labels is None: + labels = [f"traj{k}" for k in range(len(trajs_np))] + if len(labels) != len(trajs_np): + raise ValueError("labels must have the same length as trajs.") + + if mode not in ("time", "rotate"): + raise ValueError('mode must be "time" or "rotate".') + + ns = np.arange(0, n_frames, stride) + n_len = len(ns) + + fig = plt.figure(figsize=figsize, dpi=dpi) + ax = fig.add_subplot(111, projection="3d") + ax.set_xlabel("x") + ax.set_ylabel("y") + ax.set_zlabel("z") + ax.view_init(elev=elev, azim=azim_start) + + # Fixed axis limits to prevent jitter + if equal_axes: + all_pts = np.concatenate([a[ns].reshape(-1, 3) for a in trajs_np], axis=0) + mins = all_pts.min(axis=0) + maxs = all_pts.max(axis=0) + center = (mins + maxs) / 2.0 + span = (maxs - mins).max() + half = span / 2.0 if span > 0 else 1.0 + ax.set_xlim(center[0] - half, center[0] + half) + ax.set_ylim(center[1] - half, center[1] + half) + ax.set_zlim(center[2] - half, center[2] + half) + + lines = [] + for lab in labels: + (ln,) = ax.plot([], [], [], label=lab) + lines.append(ln) + ax.legend() + + def set_line_from_xyz(ln, xyz): + x, y, z = xyz.T + ln.set_data(x, y) + ln.set_3d_properties(z) + + if mode == "rotate": + # Choose which chunk to show while rotating: + # - if tail is None: show the last chunk + # - else: show an aggregated polyline by concatenating last `tail` chunks + if tail is None: + xyzs = [a[ns[-1]] for a in trajs_np] # (T, 3) + else: + start = max(0, len(ns) - int(tail)) + xyzs = [a[ns[start:]].reshape(-1, 3) for a in trajs_np] # (tail*T, 3) + + for ln, xyz in zip(lines, xyzs): + set_line_from_xyz(ln, xyz) + + n_frames_rot = int(abs(azim_end - azim_start)) + 1 if azim_end != azim_start else 360 + azims = np.linspace(azim_start, azim_end, n_frames_rot) + + def draw_frame(fi): + ax.view_init(elev=elev, azim=float(azims[fi])) + + out_len = len(azims) + + else: + # time mode: frame i shows chunk at ns[i]; optionally overlay last `tail` 
chunks + def draw_frame(fi): + n_idx = ns[fi] + if tail is None: + for ln, a in zip(lines, trajs_np): + set_line_from_xyz(ln, a[n_idx]) # (T, 3) + else: + start_f = max(0, fi + 1 - int(tail)) + for ln, a in zip(lines, trajs_np): + xyz = a[ns[start_f:fi + 1]].reshape(-1, 3) # (tail*T, 3) + set_line_from_xyz(ln, xyz) + + out_len = n_len + + frames = [] + canvas = fig.canvas + for fi in range(out_len): + draw_frame(fi) + canvas.draw() + rgba = np.asarray(canvas.buffer_rgba()) + frames.append(rgba[..., :3].copy()) + + plt.close(fig) + return frames + +def xyzw_to_wxyz(xyzw): + return np.concatenate([xyzw[..., 3:4], xyzw[..., :3]], axis=-1) + +def draw_actions(im, type, color, actions, extrinsics, intrinsics, arm="both", kinematics_solver=None): """ args: im: (H, W, C) @@ -748,6 +1014,41 @@ def ee_orientation_to_cam_frame(ee_orientation_base, T_cam_base): ) return ee_orientation_cam, batched_ypr +def prep_frame(frame: np.ndarray, H: int, W: int) -> np.ndarray | None: + if frame is None: + print("Frame is None") + return None + + if frame.dtype != np.uint8: + frame = frame.astype(np.uint8, copy=False) + + if frame.shape != (H, W, 3): + print(f"Frame shape {frame.shape} does not match expected shape {H}x{W}x3") + return None + + frame = np.ascontiguousarray(frame) + return frame + +def start_ffmpeg_mp4(out_path: str, width: int, height: int, fps: int, pix_fmt: str = "rgb24"): + # pix_fmt: "rgb24" if frames are RGB uint8; use "bgr24" if OpenCV frames + cmd = [ + "ffmpeg", + "-y", + "-f", "rawvideo", + "-vcodec", "rawvideo", + "-pix_fmt", pix_fmt, + "-s", f"{width}x{height}", + "-r", str(fps), + "-i", "pipe:0", + "-an", + "-c:v", "libx264", + "-preset", "veryfast", + "-crf", "23", + "-pix_fmt", "yuv420p", # best compatibility + out_path, + ] + return subprocess.Popen(cmd, stdin=subprocess.PIPE) + def batched_rotation_matrices_to_euler_angles(batch_R): """ diff --git a/requirements-ray.txt b/requirements-ray.txt index df87d9d3..562ab8b9 100644 --- a/requirements-ray.txt +++ 
b/requirements-ray.txt @@ -9,3 +9,4 @@ opencv-python-headless click==8.2.1 tqdm Pillow +mediapy \ No newline at end of file