diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..c72d8540 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,9 @@ +# Repo Agent Rules + +## Shell / Command Execution +to run commands in the interactive shell make sure to source emimic/bin/activate + +Apply this before running project Python tooling (for example: `python`, `pytest`, `pip`). + +## Model settings +use plan mode for anything except extremely simple tasks \ No newline at end of file diff --git a/egomimic/hydra_configs/data/test_multi_zarr.yaml b/egomimic/hydra_configs/data/test_multi_zarr.yaml new file mode 100644 index 00000000..8b1953bf --- /dev/null +++ b/egomimic/hydra_configs/data/test_multi_zarr.yaml @@ -0,0 +1,67 @@ +_target_: egomimic.pl_utils.pl_data_utils.MultiDataModuleWrapper +train_datasets: + dataset1: + _target_: egomimic.rldb.zarr.zarr_dataset_multi.MultiDataset._from_resolver + resolver: + _target_: egomimic.rldb.zarr.zarr_dataset_multi.LocalEpisodeResolver + folder_path: /nethome/paphiwetsa3/flash/datasets/test_zarr/ + key_map: + front_img_1: #batch key + key_type: camera_keys # key type + zarr_key: observations.images.front_img_1 # dataset key + right_wrist_img: + key_type: camera_keys + zarr_key: observations.images.right_wrist_img + left_wrist_img: + key_type: camera_keys + zarr_key: observations.images.left_wrist_img + ee_pose: + key_type: proprio_keys + zarr_key: observations.state.ee_pose + horizon: 100 + joint_positions: + key_type: proprio_keys + zarr_key: observations.state.joint_positions + horizon: 100 + actions_cartesian: + key_type: action_keys + zarr_key: actions_cartesian + horizon: 100 + mode: total +valid_datasets: + dataset1: + _target_: egomimic.rldb.zarr.zarr_dataset_multi.MultiDataset._from_resolver + resolver: + _target_: egomimic.rldb.zarr.zarr_dataset_multi.LocalEpisodeResolver + folder_path: /nethome/paphiwetsa3/flash/datasets/test_zarr/ + key_map: + front_img_1: #batch key + key_type: camera_keys # key type + zarr_key: 
def infer_norm_from_dataset(self, dataset):
    """
    Compute normalization statistics for all proprio and action keys.

    dataset: object exposing ``embodiment`` plus either ``episode_reader``
        (single zarr dataset) or ``datasets`` (multi-dataset wrapper whose
        children are searched recursively and concatenated along axis 0)
    returns: nothing; for each processed key the statistics (mean, std, min,
        max, median, 1st and 99th percentile, all reduced over axis 0) are
        stored as float tensors in ``self.norm_stats[embodiment][key]``
    raises: ValueError if a column's data is not 2- or 3-dimensional
    """
    embodiment = dataset.embodiment
    if isinstance(embodiment, str):
        embodiment = get_embodiment_id(embodiment)

    norm_columns = [
        *self.keys_of_type("proprio_keys"),
        *self.keys_of_type("action_keys"),
    ]

    logger.info(
        f"[NormStats] Starting norm inference for embodiment={embodiment}, "
        f"{len(norm_columns)} columns"
    )

    def get_zarr_data(ds, col):
        # Single zarr dataset: read the full column straight from its store.
        if hasattr(ds, "episode_reader"):
            store = ds.episode_reader._store
            return store[col][:] if col in store else None
        # Anything that is neither a single dataset nor a wrapper has no data.
        if not hasattr(ds, "datasets"):
            return None
        # Multi-dataset wrapper: gather the column from every child that has it.
        parts = [get_zarr_data(child, col) for child in ds.datasets.values()]
        parts = [part for part in parts if part is not None]
        return np.concatenate(parts, axis=0) if parts else None

    # (stat name, reduction over the sample axis); order fixes the dict layout.
    reducers = (
        ("mean", lambda arr: np.mean(arr, axis=0)),
        ("std", lambda arr: np.std(arr, axis=0)),
        ("min", lambda arr: np.min(arr, axis=0)),
        ("max", lambda arr: np.max(arr, axis=0)),
        ("median", lambda arr: np.median(arr, axis=0)),
        ("quantile_1", lambda arr: np.percentile(arr, 1, axis=0)),
        ("quantile_99", lambda arr: np.percentile(arr, 99, axis=0)),
    )

    for column in norm_columns:
        if not self.is_key_with_embodiment(column, embodiment):
            continue

        column_name = self.keyname_to_lerobot_key(column, embodiment)
        logger.info(f"[NormStats] Processing column={column_name}")

        column_data = get_zarr_data(dataset, column_name)
        if column_data is None:
            logger.warning(f"Skipping {column_name}, data not found given dataset type")
            continue

        if column_data.ndim not in (2, 3):
            raise ValueError(
                f"Column {column} has shape {column_data.shape}, "
                "expected 2 or 3 dims"
            )

        reduced = [(stat, reduce_fn(column_data)) for stat, reduce_fn in reducers]
        self.norm_stats[embodiment][column] = {
            stat: torch.from_numpy(value).float() for stat, value in reduced
        }

    logger.info("[NormStats] Finished norm inference")
# The file header only imports ``abstractmethod``; ``ABC`` is needed so that the
# decorator below is actually enforced.
from abc import ABC


class Transform(ABC):
    """Abstract base class for all batch transforms.

    Subclasses must override :meth:`transform`. Deriving from ``ABC`` makes
    ``@abstractmethod`` effective: instantiating ``Transform`` itself, or a
    subclass that does not override ``transform``, raises ``TypeError``.
    """

    @abstractmethod
    def transform(self, batch: dict) -> dict:
        """Apply the transform to ``batch`` and return the resulting batch.

        args:
            batch: dictionary mapping batch keys to their values
        returns:
            the transformed batch dictionary
        """
        raise NotImplementedError
f"'{self.action_key}'" + ) + batch[self.output_action_key] = _interpolate_euler( + actions, self.new_chunk_length + ) + return batch + + +class InterpolateLinear(Transform): + """Interpolate any chunk of shape (T, D) with linear interpolation.""" + + def __init__( + self, + new_chunk_length: int, + action_key: str, + output_action_key: str, + stride: int = 1, + ): + if stride <= 0: + raise ValueError(f"stride must be positive, got {stride}") + self.new_chunk_length = new_chunk_length + self.action_key = action_key + self.output_action_key = output_action_key + self.stride = int(stride) + + def transform(self, batch: dict) -> dict: + actions = np.asarray(batch[self.action_key]) + if actions.ndim != 2: + raise ValueError( + f"InterpolateLinear expects (T, D), got {actions.shape} for key " + f"'{self.action_key}'" + ) + actions = actions[:: self.stride] + batch[self.output_action_key] = _interpolate_linear( + actions, self.new_chunk_length + ) + return batch + + +# --------------------------------------------------------------------------- +# Coordinate Transforms +# --------------------------------------------------------------------------- + + +def _xyzypr_to_matrix(xyzypr: np.ndarray) -> np.ndarray: + """ + args: + xyzypr: (B, 6) np.array of [[x, y, z, yaw, pitch, roll]] + returns: + (B, 4, 4) array of SE3 transformation matrices + """ + if xyzypr.ndim != 2 or xyzypr.shape[-1] != 6: + raise ValueError(f"Expected (B, 6) array, got shape {xyzypr.shape}") + + B = xyzypr.shape[0] + dtype = xyzypr.dtype if np.issubdtype(xyzypr.dtype, np.floating) else np.float64 + + mats = np.broadcast_to(np.eye(4, dtype=dtype), (B, 4, 4)).copy() + # Input is [yaw, pitch, roll], so use ZYX order (Rz @ Ry @ Rx). 
+ mats[:, :3, :3] = R.from_euler("ZYX", xyzypr[:, 3:6], degrees=False).as_matrix() + mats[:, :3, 3] = xyzypr[:, :3] + + return mats + + +class ActionChunkCoordinateFrameTransform(Transform): + def __init__( + self, + target_world: str, + chunk_world: str, + transformed_key_name: str, + extra_batch_key: dict = None, + is_quat: bool = False, + ): + """ + args: + target_world: + chunk_world: + transformed_key_name: + is_quat: if True, inputs are xyz + quat(wxyz); otherwise xyz + ypr. + """ + self.target_world = target_world + self.chunk_world = chunk_world + self.transformed_key_name = transformed_key_name + self.extra_batch_key = extra_batch_key + self.is_quat = is_quat + + def transform(self, batch): + """ + args: + batch: + if is_quat=False, inputs are xyz + ypr. + if is_quat=True, inputs are xyz + quat(wxyz). + Input shape validation is delegated to the selected to-matrix helper. + transformed_key_name: str, name of the new key to store the transformed chunk world in + + returns + batch with new key containing transformed chunk world in target frame: + if is_quat=False: (T, 6) xyz + ypr + if is_quat=True: (T, 7) xyz + quat(wxyz) + """ + batch.update(self.extra_batch_key or {}) + target_world = np.asarray(batch[self.target_world]) + chunk_world = np.asarray(batch[self.chunk_world]) + to_matrix_fn = _xyzwxyz_to_matrix if self.is_quat else _xyzypr_to_matrix + + # Convert to SE3 for transformation + target_se3 = SE3.from_matrix(to_matrix_fn(target_world[None, :])[0]) # (4, 4) + chunk_se3 = SE3.from_matrix(to_matrix_fn(chunk_world)) # (T, 4, 4) + + # Compute relative transform and apply to chunk + chunk_in_target_frame = target_se3.inverse() @ chunk_se3 + chunk_mats = chunk_in_target_frame.to_matrix() + if chunk_mats.ndim == 2: + chunk_mats = chunk_mats[None, ...] 
class QuaternionPoseToYPR(Transform):
    """Convert a single pose from xyz + quat(x,y,z,w) to xyz + ypr."""

    def __init__(self, pose_key: str, output_key: str):
        # pose_key: batch key holding the (7,) input pose
        # output_key: batch key that receives the (6,) xyz + ypr result
        self.pose_key, self.output_key = pose_key, output_key

    def transform(self, batch: dict) -> dict:
        """Write the xyz + ypr form of ``batch[pose_key]`` to ``batch[output_key]``.

        raises:
            ValueError: if the pose does not have shape (7,)
        """
        pose = np.asarray(batch[self.pose_key])
        if pose.shape == (7,):
            position, quaternion = pose[:3], pose[3:7]
            # "ZYX" yields the angles in yaw, pitch, roll order.
            euler_zyx = R.from_quat(quaternion).as_euler("ZYX", degrees=False)
            batch[self.output_key] = np.concatenate([position, euler_zyx], axis=0)
            return batch

        raise ValueError(
            f"QuaternionPoseToYPR expects shape (7,), got {pose.shape} for key "
            f"'{self.pose_key}'"
        )
class CartesianWithGripperCoordinateTransform(Transform):
    """Re-express a bimanual (T, 14) action chunk in per-arm target frames.

    The chunk layout is [left xyz+ypr+gripper, right xyz+ypr+gripper]. Each
    arm's 6D pose is transformed into that arm's target frame; the two gripper
    columns are copied through unchanged.
    """

    def __init__(
        self,
        left_target_world: str,
        right_target_world: str,
        chunk_world: str,
        transformed_key_name: str,
        extra_batch_key: dict | None = None,
    ):
        """
        args:
            left_target_world: string key for left target world pose in batch (6D: xyz + ypr)
            right_target_world: string key for right target world pose in batch (6D: xyz + ypr)
            chunk_world: string key for chunk world pose in batch (14D: xyz + ypr + gripper * 2 arms)
            transformed_key_name: string key to store transformed chunk world in batch (14D)
            extra_batch_key: optional entries merged into the batch before the
                transform runs
        """
        self.left_target_world = left_target_world
        self.right_target_world = right_target_world
        self.chunk_world = chunk_world
        self.transformed_key_name = transformed_key_name
        self.extra_batch_key = extra_batch_key

    @staticmethod
    def _poses_in_target_frame(target_world, poses_world):
        """Return (T, 6) xyz + ypr poses of ``poses_world`` relative to ``target_world``."""
        target_se3 = SE3.from_matrix(_xyzypr_to_matrix(target_world[None, :])[0])
        poses_se3 = SE3.from_matrix(_xyzypr_to_matrix(poses_world))
        mats = (target_se3.inverse() @ poses_se3).to_matrix()
        # Mirrors the guard in ActionChunkCoordinateFrameTransform: a length-1
        # chunk may come back as a single (4, 4) matrix — TODO confirm against
        # the SE3 API; the guard is a no-op when a batch axis is present.
        if mats.ndim == 2:
            mats = mats[None, ...]
        return _matrix_to_xyzypr(mats)

    def transform(self, batch):
        """
        args:
            batch:
                left_target_world: array-like (6,): xyz + ypr
                right_target_world: array-like (6,): xyz + ypr
                chunk_world: array-like (T, 14): [left xyz+ypr+gripper, right xyz+ypr+gripper]

        returns
            batch with new key ``transformed_key_name`` containing the chunk in
            the target frames: (T, 14)

        raises
            ValueError: if any of the three inputs has an unexpected shape
        """
        batch.update(self.extra_batch_key or {})
        # np.asarray keeps this consistent with the other transforms in this
        # module and lets plain lists pass the shape checks below.
        left_target_world = np.asarray(batch[self.left_target_world])
        right_target_world = np.asarray(batch[self.right_target_world])
        chunk_world = np.asarray(batch[self.chunk_world])

        if left_target_world.shape != (6,):
            raise ValueError(
                f"Expected left_target_world shape (6,), got {left_target_world.shape}"
            )
        if right_target_world.shape != (6,):
            raise ValueError(
                f"Expected right_target_world shape (6,), got {right_target_world.shape}"
            )
        if chunk_world.ndim != 2 or chunk_world.shape[1] != 14:
            raise ValueError(
                f"Expected chunk_world shape (T, 14), got {chunk_world.shape}"
            )

        # Chunk layout: [left xyz+ypr+gripper, right xyz+ypr+gripper]
        left_pose_in_target = self._poses_in_target_frame(
            left_target_world, chunk_world[:, :6]
        )
        right_pose_in_target = self._poses_in_target_frame(
            right_target_world, chunk_world[:, 7:13]
        )

        # Keep floating dtypes as-is; promote anything else so transformed
        # poses are not truncated when written into the output array.
        out_dtype = (
            chunk_world.dtype
            if np.issubdtype(chunk_world.dtype, np.floating)
            else np.float64
        )
        chunk_in_target_frame = np.empty(chunk_world.shape, dtype=out_dtype)
        chunk_in_target_frame[:, :6] = left_pose_in_target
        chunk_in_target_frame[:, 6] = chunk_world[:, 6]  # left gripper unchanged
        chunk_in_target_frame[:, 7:13] = right_pose_in_target
        chunk_in_target_frame[:, 13] = chunk_world[:, 13]  # right gripper unchanged

        batch[self.transformed_key_name] = chunk_in_target_frame
        return batch
def build_eva_bimanual_transform_list(
    *,
    left_target_world: str = "left_extrinsics_pose",
    right_target_world: str = "right_extrinsics_pose",
    left_cmd_world: str = "left.cmd_ee_pose",
    right_cmd_world: str = "right.cmd_ee_pose",
    left_obs_pose: str = "left.obs_ee_pose",
    right_obs_pose: str = "right.obs_ee_pose",
    left_obs_gripper: str = "left.obs_gripper",
    right_obs_gripper: str = "right.obs_gripper",
    left_gripper: str = "left.gripper",
    right_gripper: str = "right.gripper",
    left_cmd_camframe: str = "left.cmd_ee_pose_camframe",
    right_cmd_camframe: str = "right.cmd_ee_pose_camframe",
    actions_key: str = "actions_cartesian",
    obs_key: str = "observations.state.ee_pose",
    chunk_length: int = 100,
    stride: int = 1,
    is_quat: bool = True,
    left_extra_batch_key: dict | None = None,
    right_extra_batch_key: dict | None = None,
) -> list[Transform]:
    """Build the ordered EVA bimanual transform pipeline.

    Stages, each applied to the left arm and then the right arm:
    command chunks into the target frame, observed poses into the target frame
    (overwriting the observed-pose keys), pose interpolation to
    ``chunk_length``, and linear gripper interpolation. When ``is_quat`` is
    true a quaternion-to-ypr conversion follows. Finally the action and
    observation keys are concatenated and the world-frame inputs are deleted.
    """
    # Per-arm key bundles, left first so the emitted order is left, right.
    arms = (
        {
            "target": left_target_world,
            "cmd_world": left_cmd_world,
            "cmd_cam": left_cmd_camframe,
            "obs_pose": left_obs_pose,
            "gripper": left_gripper,
            "extra": left_extra_batch_key,
        },
        {
            "target": right_target_world,
            "cmd_world": right_cmd_world,
            "cmd_cam": right_cmd_camframe,
            "obs_pose": right_obs_pose,
            "gripper": right_gripper,
            "extra": right_extra_batch_key,
        },
    )

    pipeline: list[Transform] = []

    # 1) Commanded pose chunks -> target (camera) frame.
    pipeline += [
        ActionChunkCoordinateFrameTransform(
            target_world=arm["target"],
            chunk_world=arm["cmd_world"],
            transformed_key_name=arm["cmd_cam"],
            extra_batch_key=arm["extra"],
            is_quat=is_quat,
        )
        for arm in arms
    ]
    # 2) Observed poses -> target frame, written back under the same key.
    pipeline += [
        PoseCoordinateFrameTransform(
            target_world=arm["target"],
            pose_world=arm["obs_pose"],
            transformed_key_name=arm["obs_pose"],
            is_quat=is_quat,
        )
        for arm in arms
    ]
    # 3) Resample the transformed pose chunks in place.
    pipeline += [
        InterpolatePose(
            new_chunk_length=chunk_length,
            action_key=arm["cmd_cam"],
            output_action_key=arm["cmd_cam"],
            stride=stride,
            is_quat=is_quat,
        )
        for arm in arms
    ]
    # 4) Resample the gripper chunks in place.
    pipeline += [
        InterpolateLinear(
            new_chunk_length=chunk_length,
            action_key=arm["gripper"],
            output_action_key=arm["gripper"],
            stride=stride,
        )
        for arm in arms
    ]

    if is_quat:
        pipeline.append(
            XYZWXYZ_to_XYZYPR(
                keys=[left_cmd_camframe, right_cmd_camframe, left_obs_pose, right_obs_pose]
            )
        )

    action_parts = [left_cmd_camframe, left_gripper, right_cmd_camframe, right_gripper]
    obs_parts = [left_obs_pose, left_obs_gripper, right_obs_pose, right_obs_gripper]
    world_inputs = [left_cmd_world, right_cmd_world, left_target_world, right_target_world]

    pipeline.append(
        ConcatKeys(key_list=action_parts, new_key_name=actions_key, delete_old_keys=True)
    )
    pipeline.append(
        ConcatKeys(key_list=obs_parts, new_key_name=obs_key, delete_old_keys=True)
    )
    pipeline.append(DeleteKeys(keys_to_delete=world_inputs))
    return pipeline
"left.action_ee_pose", + right_action_world: str = "right.action_ee_pose", + left_obs_pose: str = "left.obs_ee_pose", + right_obs_pose: str = "right.obs_ee_pose", + left_action_headframe: str = "left.action_ee_pose_headframe", + right_action_headframe: str = "right.action_ee_pose_headframe", + left_obs_headframe: str = "left.obs_ee_pose_headframe", + right_obs_headframe: str = "right.obs_ee_pose_headframe", + actions_key: str = "actions_cartesian", + obs_key: str = "observations.state.ee_pose", + chunk_length: int = 100, + stride: int = 3, + delete_target_world: bool = True, +) -> list[Transform]: + """Canonical ARIA bimanual transform pipeline used by tests and notebooks. + + Aria human data does not have commanded ee poses; action chunks are built + from stacked observed ee poses (typically with a horizon on + ``left/right.action_ee_pose`` mapped from ``left/right.obs_ee_pose``). + """ + keys_to_delete = list( + { + left_action_world, + right_action_world, + left_obs_pose, + right_obs_pose, + } + ) + target_pose_key = target_world + if delete_target_world: + keys_to_delete.append(target_world) + if target_world_is_quat: + keys_to_delete.append(target_world_ypr) + + transform_list: list[Transform] = [ + ActionChunkCoordinateFrameTransform( + target_world=target_pose_key, + chunk_world=left_action_world, + transformed_key_name=left_action_headframe, + is_quat=target_world_is_quat, + ), + ActionChunkCoordinateFrameTransform( + target_world=target_pose_key, + chunk_world=right_action_world, + transformed_key_name=right_action_headframe, + is_quat=target_world_is_quat, + ), + PoseCoordinateFrameTransform( + target_world=target_pose_key, + pose_world=left_obs_pose, + transformed_key_name=left_obs_headframe, + is_quat=target_world_is_quat, + ), + PoseCoordinateFrameTransform( + target_world=target_pose_key, + pose_world=right_obs_pose, + transformed_key_name=right_obs_headframe, + is_quat=target_world_is_quat, + ), + InterpolatePose( + new_chunk_length=chunk_length, + 
action_key=left_action_headframe, + output_action_key=left_action_headframe, + stride=stride, + is_quat=target_world_is_quat, + ), + InterpolatePose( + new_chunk_length=chunk_length, + action_key=right_action_headframe, + output_action_key=right_action_headframe, + stride=stride, + is_quat=target_world_is_quat, + ), + ] + + if target_world_is_quat: + transform_list.append( + XYZWXYZ_to_XYZYPR( + keys=[ + left_action_headframe, + right_action_headframe, + left_obs_headframe, + right_obs_headframe, + ] + ) + ) + + transform_list.extend( + [ + ConcatKeys( + key_list=[left_action_headframe, right_action_headframe], + new_key_name=actions_key, + delete_old_keys=True, + ), + ConcatKeys( + key_list=[left_obs_headframe, right_obs_headframe], + new_key_name=obs_key, + delete_old_keys=True, + ), + DeleteKeys(keys_to_delete=keys_to_delete), + ] + ) + return transform_list diff --git a/egomimic/rldb/zarr/test_action_chunk_transforms_unit.py b/egomimic/rldb/zarr/test_action_chunk_transforms_unit.py new file mode 100644 index 00000000..c70c26ef --- /dev/null +++ b/egomimic/rldb/zarr/test_action_chunk_transforms_unit.py @@ -0,0 +1,616 @@ +import numpy as np +import pytest +from scipy.spatial.transform import Rotation as R + +from egomimic.rldb.zarr.action_chunk_transforms import ( + ActionChunkCoordinateFrameTransform, + ConcatKeys, + InterpolatePose, + XYZWXYZ_to_XYZYPR, + build_aria_bimanual_transform_list, + build_eva_bimanual_transform_list, +) +from egomimic.utils.pose_utils import _xyzwxyz_to_matrix + + +def _shape_map(batch: dict) -> dict[str, tuple]: + return {k: tuple(np.asarray(v).shape) for k, v in batch.items()} + + +def _run_and_capture(transform_list, batch: dict): + snapshots = [] + data = {k: np.asarray(v).copy() for k, v in batch.items()} + for transform in transform_list: + data = transform.transform(data) + snapshots.append( + (transform.__class__.__name__, set(data.keys()), _shape_map(data)) + ) + return snapshots + + +def _assert_snapshot( + snapshots, + idx: 
int, + expected_name: str, + expected_keys: set[str], + expected_shapes: dict[str, tuple], +) -> None: + name, keys, shapes = snapshots[idx] + assert name == expected_name, ( + f"step {idx}: transform mismatch; expected {expected_name}, got {name}" + ) + + missing = expected_keys - keys + extra = keys - expected_keys + assert not missing and not extra, ( + f"step {idx} ({name}): key set mismatch; missing={sorted(missing)}, " + f"extra={sorted(extra)}" + ) + + shape_mismatch = { + k: (expected_shapes[k], shapes.get(k)) + for k in expected_shapes + if shapes.get(k) != expected_shapes[k] + } + assert not shape_mismatch, ( + f"step {idx} ({name}): shape mismatch {shape_mismatch}" + ) + + +def test_xyzwxyz_to_matrix_converts_wxyz_quaternion() -> None: + yaw = np.pi / 2.0 + qw = np.cos(yaw / 2.0) + qz = np.sin(yaw / 2.0) + + poses = np.array( + [ + [1.0, 2.0, 3.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, qw, 0.0, 0.0, qz], + ], + dtype=np.float64, + ) + + mats = _xyzwxyz_to_matrix(poses) + + np.testing.assert_allclose(mats[0, :3, 3], np.array([1.0, 2.0, 3.0])) + np.testing.assert_allclose(mats[0, :3, :3], np.eye(3), atol=1e-7) + + expected_rot = R.from_euler("Z", yaw, degrees=False).as_matrix() + np.testing.assert_allclose(mats[1, :3, :3], expected_rot, atol=1e-7) + np.testing.assert_allclose(mats[1, :3, 3], np.zeros(3), atol=1e-7) + + +def test_action_chunk_coordinate_frame_transform_accepts_quat_wxyz_input() -> None: + yaw = np.pi / 2.0 + qw = np.cos(yaw / 2.0) + qz = np.sin(yaw / 2.0) + + transform = ActionChunkCoordinateFrameTransform( + target_world="target_world", + chunk_world="chunk_world", + transformed_key_name="chunk_target", + is_quat=True, + ) + + batch = { + "target_world": np.array([0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], dtype=np.float64), + "chunk_world": np.array( + [ + [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, qw, 0.0, 0.0, qz], + ], + dtype=np.float64, + ), + } + + out = transform.transform(batch) + chunk_target = 
def test_action_chunk_coordinate_frame_transform_invalid_target_shape_raises() -> None:
    """In ypr mode a (2, 6) target pose must be rejected with a ValueError."""
    zeros_2x6 = np.zeros((2, 6), dtype=np.float64)
    frame_transform = ActionChunkCoordinateFrameTransform(
        target_world="target_world",
        chunk_world="chunk_world",
        transformed_key_name="chunk_target",
        is_quat=False,
    )
    # The chunk itself is well-formed; only the target has an extra leading axis.
    sample = {"target_world": zeros_2x6, "chunk_world": zeros_2x6.copy()}

    with pytest.raises(ValueError, match=r"Expected \(B, 6\) array"):
        frame_transform.transform(sample)
def test_xyzwxyz_to_xyzypr_single_pose_conversion() -> None:
    """A single (7,) wxyz pose with a 90-degree yaw converts to (6,) xyz + ypr."""
    yaw = np.pi / 2.0
    half_yaw = yaw / 2.0
    # wxyz quaternion for a pure rotation about z: (cos(yaw/2), 0, 0, sin(yaw/2)).
    pose_wxyz = np.array(
        [1.0, 2.0, 3.0, np.cos(half_yaw), 0.0, 0.0, np.sin(half_yaw)],
        dtype=np.float64,
    )

    converter = XYZWXYZ_to_XYZYPR(keys=["pose"])
    converted = np.asarray(converter.transform({"pose": pose_wxyz})["pose"])

    assert converted.shape == (6,)
    position, angles = converted[:3], converted[3:]
    np.testing.assert_allclose(position, np.array([1.0, 2.0, 3.0]), atol=1e-7)
    np.testing.assert_allclose(angles, np.array([yaw, 0.0, 0.0]), atol=1e-6)
def test_aria_builder_orders_xyzwxyz_to_xyzypr_after_interpolate_before_concat() -> None:
    """Check where the Aria builder places the ``XYZWXYZ_to_XYZYPR`` step.

    The built list must contain exactly one converter, positioned after the
    last ``InterpolatePose`` step and before the first ``ConcatKeys`` step,
    and that converter must operate on the four head-frame pose keys.
    """
    transform_list = build_aria_bimanual_transform_list(target_world_is_quat=True)

    def indices_of(kind):
        # Positions of every transform in the list that is an instance of *kind*.
        return [i for i, t in enumerate(transform_list) if isinstance(t, kind)]

    converter_indices = indices_of(XYZWXYZ_to_XYZYPR)
    interpolate_indices = indices_of(InterpolatePose)
    concat_indices = indices_of(ConcatKeys)

    assert len(converter_indices) == 1
    # Guard the max()/min() calls below: on an empty list they raise a bare
    # ValueError, which reports as a test *error* with no hint about which
    # step the builder failed to emit.
    assert interpolate_indices, "builder produced no InterpolatePose step"
    assert concat_indices, "builder produced no ConcatKeys step"

    converter_idx = converter_indices[0]
    assert converter_idx > max(interpolate_indices)
    assert converter_idx < min(concat_indices)
    assert set(transform_list[converter_idx].keys) == {
        "left.action_ee_pose_headframe",
        "right.action_ee_pose_headframe",
        "left.obs_ee_pose_headframe",
        "right.obs_ee_pose_headframe",
    }
"left.obs_ee_pose": (7,), + "right.obs_ee_pose": (7,), + "left.gripper": (5, 1), + "right.gripper": (5, 1), + "left.obs_gripper": (1,), + "right.obs_gripper": (1,), + "left.cmd_ee_pose_camframe": (5, 7), + }, + ) + _assert_snapshot( + snapshots, + 1, + "ActionChunkCoordinateFrameTransform", + base_keys | {"left.cmd_ee_pose_camframe", "right.cmd_ee_pose_camframe"}, + { + "left.cmd_ee_pose_camframe": (5, 7), + "right.cmd_ee_pose_camframe": (5, 7), + "left.obs_ee_pose": (7,), + "right.obs_ee_pose": (7,), + }, + ) + _assert_snapshot( + snapshots, + 4, + "InterpolatePose", + base_keys | {"left.cmd_ee_pose_camframe", "right.cmd_ee_pose_camframe"}, + { + "left.cmd_ee_pose_camframe": (4, 7), + "right.cmd_ee_pose_camframe": (5, 7), + }, + ) + _assert_snapshot( + snapshots, + 5, + "InterpolatePose", + base_keys | {"left.cmd_ee_pose_camframe", "right.cmd_ee_pose_camframe"}, + { + "left.cmd_ee_pose_camframe": (4, 7), + "right.cmd_ee_pose_camframe": (4, 7), + }, + ) + _assert_snapshot( + snapshots, + 8, + "XYZWXYZ_to_XYZYPR", + base_keys | {"left.cmd_ee_pose_camframe", "right.cmd_ee_pose_camframe"}, + { + "left.cmd_ee_pose_camframe": (4, 6), + "right.cmd_ee_pose_camframe": (4, 6), + "left.obs_ee_pose": (6,), + "right.obs_ee_pose": (6,), + }, + ) + _assert_snapshot( + snapshots, + 9, + "ConcatKeys", + { + "actions_cartesian", + "left_extrinsics_pose", + "right_extrinsics_pose", + "left.cmd_ee_pose", + "right.cmd_ee_pose", + "left.obs_ee_pose", + "right.obs_ee_pose", + "left.obs_gripper", + "right.obs_gripper", + }, + { + "actions_cartesian": (4, 14), + "left.obs_ee_pose": (6,), + "right.obs_ee_pose": (6,), + }, + ) + _assert_snapshot( + snapshots, + 10, + "ConcatKeys", + { + "actions_cartesian", + "observations.state.ee_pose", + "left_extrinsics_pose", + "right_extrinsics_pose", + "left.cmd_ee_pose", + "right.cmd_ee_pose", + }, + { + "actions_cartesian": (4, 14), + "observations.state.ee_pose": (14,), + }, + ) + _assert_snapshot( + snapshots, + 11, + "DeleteKeys", + 
def test_aria_transform_list_stepwise_keys_and_shapes() -> None:
    """Run the Aria bimanual transform list and check keys/shapes after selected steps.

    The batch holds a 7-element head pose, two ``(6, 7)`` action chunks and
    two 7-element observation poses, each with element 3 set to 1.0.
    ``_run_and_capture`` yields one ``(name, ..., ...)`` snapshot per
    transform; the transform names are checked for every step, while keys
    and shapes are checked for steps 0, 1 and 4-9.
    """
    transform_list = build_aria_bimanual_transform_list(
        chunk_length=4,
        stride=2,
        target_world_is_quat=True,
        left_action_world="left.action_ee_pose",
        right_action_world="right.action_ee_pose",
    )

    def pose_with_unit_slot(shape):
        # Zero pose array whose element 3 along the last axis is 1.0.
        pose = np.zeros(shape, dtype=np.float64)
        pose[..., 3] = 1.0
        return pose

    batch = {
        "obs_head_pose": np.array([0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]),
        "left.action_ee_pose": pose_with_unit_slot((6, 7)),
        "right.action_ee_pose": pose_with_unit_slot((6, 7)),
        "left.obs_ee_pose": pose_with_unit_slot((7,)),
        "right.obs_ee_pose": pose_with_unit_slot((7,)),
    }
    snapshots = _run_and_capture(transform_list, batch)

    expected_names = (
        ["ActionChunkCoordinateFrameTransform"] * 2
        + ["PoseCoordinateFrameTransform"] * 2
        + ["InterpolatePose"] * 2
        + ["XYZWXYZ_to_XYZYPR"]
        + ["ConcatKeys"] * 2
        + ["DeleteKeys"]
    )
    assert [step_name for step_name, _, _ in snapshots] == expected_names

    base_keys = {
        "obs_head_pose",
        "left.action_ee_pose",
        "right.action_ee_pose",
        "left.obs_ee_pose",
        "right.obs_ee_pose",
    }
    left_action = "left.action_ee_pose_headframe"
    right_action = "right.action_ee_pose_headframe"
    left_obs = "left.obs_ee_pose_headframe"
    right_obs = "right.obs_ee_pose_headframe"
    all_headframe = {left_action, right_action, left_obs, right_obs}
    final_keys = {"actions_cartesian", "observations.state.ee_pose"}
    final_shapes = {
        "actions_cartesian": (4, 12),
        "observations.state.ee_pose": (12,),
    }

    # (snapshot index, transform name, expected key set, expected shapes)
    expectations = [
        (
            0,
            "ActionChunkCoordinateFrameTransform",
            base_keys | {left_action},
            {
                left_action: (6, 7),
                "left.obs_ee_pose": (7,),
                "right.obs_ee_pose": (7,),
            },
        ),
        (
            1,
            "ActionChunkCoordinateFrameTransform",
            base_keys | {left_action, right_action},
            {left_action: (6, 7), right_action: (6, 7)},
        ),
        (
            4,
            "InterpolatePose",
            base_keys | all_headframe,
            {left_action: (4, 7), right_action: (6, 7)},
        ),
        (
            5,
            "InterpolatePose",
            base_keys | all_headframe,
            {left_action: (4, 7), right_action: (4, 7)},
        ),
        (
            6,
            "XYZWXYZ_to_XYZYPR",
            base_keys | all_headframe,
            {
                left_action: (4, 6),
                right_action: (4, 6),
                left_obs: (6,),
                right_obs: (6,),
            },
        ),
        (
            7,
            "ConcatKeys",
            base_keys | {left_obs, right_obs, "actions_cartesian"},
            {
                "actions_cartesian": (4, 12),
                left_obs: (6,),
                right_obs: (6,),
            },
        ),
        (8, "ConcatKeys", base_keys | final_keys, dict(final_shapes)),
        (9, "DeleteKeys", set(final_keys), dict(final_shapes)),
    ]
    for step_idx, step_name, keys_after, shapes_after in expectations:
        _assert_snapshot(snapshots, step_idx, step_name, keys_after, shapes_after)
build_aria_bimanual_transform_list, + build_eva_bimanual_transform_list, + _matrix_to_xyzypr, +) +from egomimic.rldb.zarr.zarr_dataset_multi import MultiDataset, ZarrDataset +from egomimic.utils.egomimicUtils import EXTRINSICS + + +ZARR_EPISODE_PATH = Path( + "/coc/flash7/rco3/EgoVerse/egomimic/rldb/zarr/zarr/new/1769460905119.zarr" +) +LEROBOT_EPISODE_HASH = "2026-01-26-20-55-05-119000" +LEROBOT_CACHE_ROOT = "/coc/flash7/skareer6/CacheEgoVerse/.cache" +EMBODIMENT = "eva_bimanual" +ACTION_HORIZON_REAL = 45 +ACTION_CHUNK_LENGTH = 100 +KEYS_TO_COMPARE = ( + "observations.images.front_img_1", + "observations.images.right_wrist_img", + "observations.images.left_wrist_img", + "actions_cartesian", + "observations.state.ee_pose", +) +IMAGE_KEYS = { + "observations.images.front_img_1", + "observations.images.right_wrist_img", + "observations.images.left_wrist_img", +} + +ARIA_ZARR_EPISODE_PATH = Path( + "/coc/flash7/scratch/egoverseDebugDatasets/proc_zarr/1764285211791.zarr/" +) +ARIA_LEROBOT_EPISODE_HASH = "2025-11-27-23-13-31-791000" +ARIA_EMBODIMENT = "aria_bimanual" +ARIA_ACTION_HORIZON_REAL = 30 +ARIA_ACTION_CHUNK_LENGTH = 100 +ARIA_ACTION_STRIDE = 3 +ARIA_KEYS_TO_COMPARE = ( + "observations.images.front_img_1", + "actions_cartesian", + "observations.state.ee_pose", +) +ARIA_IMAGE_KEYS = {"observations.images.front_img_1"} + + +def _to_numpy(value): + if isinstance(value, torch.Tensor): + return value.detach().cpu().numpy() + if isinstance(value, np.ndarray): + return value + return value + + +def _check_equal_dict(left: dict, right: dict, path: str = "root") -> None: + assert set(left.keys()) == set(right.keys()), ( + f"{path}: key mismatch. 
def _build_zarr_dataset_eva() -> MultiDataset:
    """Build a one-episode ``MultiDataset`` over the EVA zarr episode.

    Observation keys are read without a horizon; gripper and commanded
    end-effector pose keys are read with ``ACTION_HORIZON_REAL``. The
    left/right extrinsics from ``EXTRINSICS["x5Dec13_2"]`` are converted with
    ``_matrix_to_xyzypr`` and handed to the EVA transform builder as extra
    batch keys.
    """
    per_frame_keys = {
        "observations.images.front_img_1": {"zarr_key": "images.front_1"},
        "observations.images.right_wrist_img": {"zarr_key": "images.right_wrist"},
        "observations.images.left_wrist_img": {"zarr_key": "images.left_wrist"},
        "right.obs_ee_pose": {"zarr_key": "right.obs_ee_pose"},
        "right.obs_gripper": {"zarr_key": "right.gripper"},
        "left.obs_ee_pose": {"zarr_key": "left.obs_ee_pose"},
        "left.obs_gripper": {"zarr_key": "left.gripper"},
    }
    # Horizon keys map to the zarr array of the same name.
    horizon_names = (
        "right.gripper",
        "left.gripper",
        "right.cmd_ee_pose",
        "left.cmd_ee_pose",
    )
    horizon_keys = {
        name: {"zarr_key": name, "horizon": ACTION_HORIZON_REAL}
        for name in horizon_names
    }
    key_map = {**per_frame_keys, **horizon_keys}

    extrinsics = EXTRINSICS["x5Dec13_2"]

    def extrinsics_pose(side):
        # Add a leading batch axis for the converter, then strip it again.
        return _matrix_to_xyzypr(extrinsics[side][None, :])[0]

    left_extrinsics_pose = extrinsics_pose("left")
    right_extrinsics_pose = extrinsics_pose("right")

    transform_list = build_eva_bimanual_transform_list(
        chunk_length=ACTION_CHUNK_LENGTH,
        stride=1,
        left_extra_batch_key={"left_extrinsics_pose": left_extrinsics_pose},
        right_extra_batch_key={"right_extrinsics_pose": right_extrinsics_pose},
    )

    episode_dataset = ZarrDataset(
        Episode_path=ZARR_EPISODE_PATH,
        key_map=key_map,
        transform_list=transform_list,
    )
    wrapped = {"single_episode": episode_dataset}
    return MultiDataset(datasets=wrapped, mode="total")
_build_lerobot_dataset_aria() -> S3RLDBDataset: + return S3RLDBDataset( + filters={"episode_hash": ARIA_LEROBOT_EPISODE_HASH}, + mode="total", + cache_root=LEROBOT_CACHE_ROOT, + embodiment=ARIA_EMBODIMENT, + ) + + +def _first_batch(dataset: torch.utils.data.Dataset) -> dict: + loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False) + return next(iter(loader)) + + +def test_zarr_batch_matches_lerobot_batch_eva() -> None: + if not ZARR_EPISODE_PATH.exists(): + pytest.skip(f"Zarr path not found: {ZARR_EPISODE_PATH}") + + zarr_batch = _first_batch(_build_zarr_dataset_eva()) + lerobot_batch = _first_batch(_build_lerobot_dataset()) + + missing_keys = [key for key in KEYS_TO_COMPARE if key not in lerobot_batch] + assert not missing_keys, ( + f"Lerobot batch missing keys required for comparison: {missing_keys}" + ) + + missing_zarr_keys = [key for key in KEYS_TO_COMPARE if key not in zarr_batch] + assert not missing_zarr_keys, ( + f"Zarr batch missing keys required for comparison: {missing_zarr_keys}" + ) + + zarr_subset = {key: zarr_batch[key] for key in KEYS_TO_COMPARE} + lerobot_subset = {key: lerobot_batch[key] for key in KEYS_TO_COMPARE} + + for key in IMAGE_KEYS: + lerobot_arr = _to_numpy(lerobot_subset[key]) + zarr_arr = _to_numpy(zarr_subset[key]) + assert isinstance(lerobot_arr, np.ndarray) and isinstance(zarr_arr, np.ndarray), ( + f"{key}: expected array/tensor values, got {type(lerobot_subset[key])} " + f"and {type(zarr_subset[key])}" + ) + assert lerobot_arr.shape == zarr_arr.shape, ( + f"{key}: image shape mismatch {lerobot_arr.shape} vs {zarr_arr.shape}" + ) + + non_image_lerobot = { + key: value for key, value in lerobot_subset.items() if key not in IMAGE_KEYS + } + non_image_zarr = { + key: value for key, value in zarr_subset.items() if key not in IMAGE_KEYS + } + + lerobot_actions = _to_numpy(non_image_lerobot.pop("actions_cartesian")) + zarr_actions = _to_numpy(non_image_zarr.pop("actions_cartesian")) + assert 
isinstance(lerobot_actions, np.ndarray) and isinstance(zarr_actions, np.ndarray), ( + "actions_cartesian must be tensors/arrays" + ) + assert lerobot_actions.shape == zarr_actions.shape, ( + f"actions_cartesian shape mismatch: {lerobot_actions.shape} vs {zarr_actions.shape}" + ) + + np.testing.assert_allclose( + lerobot_actions, + zarr_actions, + rtol=0.0, + atol=2e-3, + err_msg="actions_cartesian mismatch", + ) + + _check_equal_dict(non_image_lerobot, non_image_zarr) + + +def test_zarr_batch_matches_lerobot_batch_aria() -> None: + if not ARIA_ZARR_EPISODE_PATH.exists(): + pytest.skip(f"Aria zarr path not found: {ARIA_ZARR_EPISODE_PATH}") + + zarr_batch = _first_batch(_build_zarr_dataset_aria()) + lerobot_batch = _first_batch(_build_lerobot_dataset_aria()) + + missing_keys = [key for key in ARIA_KEYS_TO_COMPARE if key not in lerobot_batch] + assert not missing_keys, ( + f"Lerobot Aria batch missing keys required for comparison: {missing_keys}" + ) + + missing_zarr_keys = [key for key in ARIA_KEYS_TO_COMPARE if key not in zarr_batch] + assert not missing_zarr_keys, ( + f"Zarr Aria batch missing keys required for comparison: {missing_zarr_keys}" + ) + + zarr_subset = {key: zarr_batch[key] for key in ARIA_KEYS_TO_COMPARE} + lerobot_subset = {key: lerobot_batch[key] for key in ARIA_KEYS_TO_COMPARE} + + for key in ARIA_IMAGE_KEYS: + lerobot_arr = _to_numpy(lerobot_subset[key]) + zarr_arr = _to_numpy(zarr_subset[key]) + assert isinstance(lerobot_arr, np.ndarray) and isinstance(zarr_arr, np.ndarray), ( + f"{key}: expected array/tensor values, got {type(lerobot_subset[key])} " + f"and {type(zarr_subset[key])}" + ) + assert lerobot_arr.shape == zarr_arr.shape, ( + f"{key}: image shape mismatch {lerobot_arr.shape} vs {zarr_arr.shape}" + ) + + non_image_lerobot = { + key: value for key, value in lerobot_subset.items() if key not in ARIA_IMAGE_KEYS + } + non_image_zarr = { + key: value for key, value in zarr_subset.items() if key not in ARIA_IMAGE_KEYS + } + + 
def split_dataset_names(dataset_names, valid_ratio=0.2, seed=SEED):
    """
    Split a collection of dataset names into disjoint train/valid sets.

    Names are sorted, shuffled with ``random.Random(seed)`` and the first
    ``int(len(names) * valid_ratio)`` of them become the validation set.
    Whenever ``valid_ratio`` is positive at least one name goes to
    validation, so a single-name input with a positive ratio yields an empty
    train set.

    Args:
        dataset_names (Iterable[str]): names to split.
        valid_ratio (float): fraction of names to put in valid, in [0, 1].
        seed (int): seed for the deterministic shuffle.

    Returns:
        train_set (set[str]), valid_set (set[str])

    Raises:
        ValueError: if ``valid_ratio`` is outside [0, 1]. This is checked
            before anything else, so an invalid ratio is reported even when
            ``dataset_names`` is empty.
    """
    # Validate first: previously an empty input returned early and silently
    # accepted an out-of-range ratio.
    if not (0.0 <= valid_ratio <= 1.0):
        raise ValueError(f"valid_ratio must be in [0,1], got {valid_ratio}")

    # Sorting makes the shuffle independent of the caller's iteration order.
    names = sorted(dataset_names)
    if not names:
        return set(), set()

    random.Random(seed).shuffle(names)

    n_valid = int(len(names) * valid_ratio)
    if valid_ratio > 0.0:
        # A positive ratio always yields at least one validation name.
        n_valid = max(1, n_valid)

    valid = set(names[:n_valid])
    train = set(names[n_valid:])
    return train, valid
class S3EpisodeResolver(EpisodeResolver):
    """
    Resolves episodes via the SQL episode table and optionally syncs them from S3.

    ``resolve`` looks up matching episodes in the table, downloads the ones
    missing from ``folder_path`` with ``s5cmd`` (when syncing is enabled) and
    loads the matching local folders as ``ZarrDataset`` objects.
    """

    def __init__(
        self,
        folder_path: Path,
        bucket_name: str = "rldb",
        main_prefix: str = "processed_v2",
        key_map: dict | None = None,
        transform_list: list | None = None,
    ):
        """
        Args:
            folder_path: local directory episodes are synced into and loaded from.
            bucket_name: bucket used for ``processed_path`` values that are not
                already full ``s3://`` URIs.
            main_prefix: stored on the instance; not read by this class.
            key_map: forwarded to every constructed ``ZarrDataset``.
            transform_list: forwarded to every constructed ``ZarrDataset``.
        """
        self.bucket_name = bucket_name
        self.main_prefix = main_prefix
        super().__init__(folder_path, key_map=key_map, transform_list=transform_list)

    def resolve(
        self,
        filters: dict | None = None,
        sync_from_s3: bool = True,
    ) -> dict[str, ZarrDataset]:
        """
        Build the ZarrDatasets for every episode matching ``filters``.

        Args:
            filters: column -> value pairs matched against the SQL episode
                table. ``is_deleted`` is always forced to ``False``. The
                caller's dict is not modified.
            sync_from_s3: if True (default, the historical behaviour), missing
                episodes are downloaded into ``folder_path`` before loading.
                If False, the folders are assumed to exist locally already
                and only the SQL lookup is performed.

        Returns:
            dict mapping episode hash -> ``ZarrDataset``.

        Raises:
            ValueError: if the filters match no episode in the SQL table.
        """
        # Copy instead of mutating: the old ``filters={}`` default was shared
        # between calls and the caller's dict was modified in place.
        filters = {**(filters or {}), "is_deleted": False}

        if self.folder_path.is_dir():
            logger.info(f"Using existing directory: {self.folder_path}")
        else:
            # parents=True so a nested cache path does not fail on first use.
            self.folder_path.mkdir(parents=True, exist_ok=True)

        logger.info(f"Filters: {filters}")

        if sync_from_s3:
            filtered_paths = self.sync_from_filters(
                bucket_name=self.bucket_name,
                filters=filters,
                local_dir=self.folder_path,
            )
        else:
            filtered_paths = self._get_filtered_paths(filters)

        valid_hashes = {hashes for _, hashes in filtered_paths}
        if not valid_hashes:
            raise ValueError(
                "No valid collection names from _get_filtered_paths: "
                "filters matched no episodes in the SQL table."
            )

        datasets = self._load_zarr_datasets(
            search_path=self.folder_path,
            valid_folder_names=valid_hashes,
        )

        return datasets

    @staticmethod
    def _get_filtered_paths(filters: dict | None = None) -> list[tuple[str, str]]:
        """
        Select episodes from the SQL episode table that match ``filters``.

        Args:
            filters (dict): column -> value pairs; a row matches when every
                listed column equals its value. ``None`` or an empty dict
                matches every row.

        Returns:
            list[tuple[str, str]]: ``(processed_path, episode_hash)`` for each
            matching episode whose ``processed_path`` is not null.
        """
        filters = filters or {}
        engine = create_default_engine()
        df = episode_table_to_df(engine)
        series = pd.Series(filters)

        output = df.loc[
            (df[list(filters)] == series).all(axis=1),
            ["processed_path", "episode_hash"],
        ]
        # Computed over the whole table, so the logged list also includes
        # unprocessed episodes that did not match the filters.
        skipped = df[df["processed_path"].isnull()]["episode_hash"].tolist()
        logger.info(
            f"Skipped {len(skipped)} episodes with null processed_path: {skipped}"
        )
        output = output[~output["episode_hash"].isin(skipped)]

        paths = list(output.itertuples(index=False, name=None))
        logger.info(f"Paths: {paths}")
        return paths

    @classmethod
    def _sync_s3_to_local(
        cls, bucket_name: str, s3_paths: list[tuple[str, str]], local_dir: Path
    ):
        """
        Download episodes that are not yet present under ``local_dir``.

        One ``sync`` line per missing episode is written to a temporary batch
        file which is executed with a single ``s5cmd run``. The batch file is
        removed afterwards, also when ``s5cmd`` fails.

        Raises:
            subprocess.CalledProcessError: if ``s5cmd`` exits non-zero.
        """
        if not s3_paths:
            return

        # 0) Skip episodes already present locally
        to_sync = []
        already = []
        for processed_path, episode_hash in s3_paths:
            if cls._episode_already_present(local_dir, episode_hash):
                already.append(episode_hash)
            else:
                to_sync.append((processed_path, episode_hash))

        if already:
            logger.info("Skipping %d episodes already present locally.", len(already))

        if not to_sync:
            logger.info("Nothing to sync from S3 (all episodes already present).")
            return

        # 1) Build s5cmd batch script (one line per episode)
        local_dir.mkdir(parents=True, exist_ok=True)
        # Only reserve a unique file name here; the script is written below.
        with tempfile.NamedTemporaryFile(
            prefix="_s5cmd_sync_",
            suffix=".txt",
            delete=False,
        ) as tmp_file:
            batch_path = Path(tmp_file.name)

        lines = []
        for processed_path, episode_hash in to_sync:
            # processed_path is either a full s3:// URI or a key relative to
            # the bucket; both are normalised to "<prefix>/*".
            if processed_path.startswith("s3://"):
                src_prefix = processed_path.rstrip("/") + "/*"
            else:
                src_prefix = (
                    f"s3://{bucket_name}/{processed_path.lstrip('/').rstrip('/')}"
                    + "/*"
                )

            # Each episode is synced into its own <local_dir>/<episode_hash>/ folder.
            dst = local_dir / episode_hash
            lines.append(f'sync "{src_prefix}" "{str(dst)}/"')

        try:
            batch_path.write_text("\n".join(lines) + "\n")

            cmd = ["s5cmd", "run", str(batch_path)]
            logger.info("Running s5cmd batch (%d lines): %s", len(lines), " ".join(cmd))
            subprocess.run(cmd, check=True)

        finally:
            try:
                batch_path.unlink(missing_ok=True)
            except Exception as e:
                logger.warning("Failed to delete batch file %s: %s", batch_path, e)

    @classmethod
    def sync_from_filters(
        cls,
        *,
        bucket_name: str,
        filters: dict,
        local_dir: Path,
    ):
        """
        Public API:
        - resolves episodes from the SQL table using ``filters``
        - downloads the ones missing locally into ``local_dir`` with one
          ``s5cmd`` batch run

        Returns:
            List[(processed_path, episode_hash)]; empty when nothing matched.
        """

        # 1) Resolve episodes from DB
        filtered_paths = cls._get_filtered_paths(filters)
        if not filtered_paths:
            logger.warning("No episodes matched filters.")
            return []

        # 2) Logging
        logger.info(
            f"Syncing S3 datasets with filters {filters} to local directory {local_dir}..."
        )

        # 3) Sync
        cls._sync_s3_to_local(
            bucket_name=bucket_name,
            s3_paths=filtered_paths,
            local_dir=local_dir,
        )

        return filtered_paths
+ """ + def __init__( + self, + folder_path: Path, + key_map: dict | None = None, + transform_list: list | None = None, + ): + super().__init__(folder_path, key_map, transform_list) + + @staticmethod + def _local_filters_match(metadata: dict, episode_hash: str, filters: dict) -> bool: + for key, value in filters.items(): + if key == "episode_hash": + if episode_hash != value: + return False + continue + + if key == "robot_name": + meta_value = metadata.get("robot_name", metadata.get("robot_type")) + elif key == "is_deleted": + meta_value = metadata.get("is_deleted", False) + else: + meta_value = metadata.get(key) + + if meta_value is None: + return False + if meta_value != value: + return False + + return True + + @classmethod + def _get_local_filtered_paths(cls, search_path: Path, filters: dict): + if not search_path.is_dir(): + logger.warning("Local path does not exist: %s", search_path) + return [] + + filtered = [] + for p in sorted(search_path.iterdir()): + if not p.is_dir(): + continue + + episode_hash = p.name[:-5] if p.name.endswith(".zarr") else p.name + + try: + store = zarr.open_group(str(p), mode="r") + metadata = dict(store.attrs) + except Exception as e: + logger.warning("Failed to read metadata for %s: %s", p, e) + continue + + if cls._local_filters_match(metadata, episode_hash, filters): + filtered.append((str(p), episode_hash)) + + logger.info("Local filtered paths: %s", filtered) + return filtered + + def resolve( + self, + sync_from_s3=False, + filters={}, + ) -> list[tuple[str, str]]: + """ + Outputs a list of ZarrDatasets with relevant filters from local data. 
class MultiDataset(torch.utils.data.Dataset):
    """
    Map-style dataset that concatenates several named child datasets.

    Children may be ``ZarrDataset`` objects or other ``MultiDataset`` objects.
    A flat index is translated into ``(dataset_name, local_idx)`` through
    ``index_map``.
    """

    def __init__(
        self,
        datasets,
        mode="train",
        percent=0.1,
        valid_ratio=0.2,
        **kwargs,
    ):
        """
        Args:
            datasets (dict): maps unique dataset names (str) to dataset objects.
            mode (str, optional): which children to keep:
                "train" / "valid" use the split from ``split_dataset_names``,
                "total" keeps everything, "percent" keeps a seeded random
                ``percent`` fraction of the children. Defaults to "train".
            percent (float, optional): fraction of children kept in
                "percent" mode (at least one child when positive).
            valid_ratio (float, optional): validation ratio used for the
                train/valid split.
            **kwargs: accepted for call-site compatibility; not used here.

        Raises:
            ValueError: if ``mode`` is not one of the four supported values.
            AssertionError: if no dataset is left after applying the mode.
        """
        super().__init__()

        self.train_collections, self.valid_collections = split_dataset_names(
            datasets.keys(), valid_ratio=valid_ratio, seed=SEED
        )

        if mode == "train":
            chosen = self.train_collections
        elif mode == "valid":
            chosen = self.valid_collections
        elif mode == "total":
            chosen = set(datasets.keys())
        elif mode == "percent":
            all_names = sorted(datasets.keys())
            rng = random.Random(SEED)
            rng.shuffle(all_names)

            n_keep = int(len(all_names) * percent)
            if percent > 0.0:
                n_keep = max(1, n_keep)
            chosen = set(all_names[:n_keep])
        else:
            raise ValueError(f"Unknown mode: {mode}")

        selected = {rid: ds for rid, ds in datasets.items() if rid in chosen}
        assert selected, "No datasets left after applying mode split."

        # Store the *filtered* children and index only those. Previously the
        # index map was built from every dataset before the mode was applied
        # and the filtered dict was thrown away, so "train", "valid" and
        # "percent" all behaved like "total".
        self.datasets = selected
        self.index_map = [
            (dataset_name, local_idx)
            for dataset_name, dataset in selected.items()
            for local_idx in range(len(dataset))
        ]

    def __len__(self) -> int:
        """Total number of samples across the selected child datasets."""
        return len(self.index_map)

    def __getitem__(self, idx):
        """Return sample ``idx`` by delegating to the owning child dataset."""
        dataset_name, local_idx = self.index_map[idx]
        data = self.datasets[dataset_name][local_idx]

        return data

    @classmethod
    def _from_resolver(cls, resolver: EpisodeResolver, **kwargs):
        """
        Create a MultiDataset from an EpisodeResolver.

        Args:
            resolver (EpisodeResolver): resolver used to load the child datasets.
            **kwargs: ``filters`` and ``sync_from_s3`` are removed and passed
                to ``resolver.resolve``; everything else (e.g. ``mode``,
                ``percent``, ``valid_ratio``) goes to the MultiDataset
                constructor.
        Returns:
            MultiDataset: The constructed multi-dataset.
        """
        # TODO add key_map and transform pass to children
        # NOTE(review): resolver.resolve() must accept both the ``filters``
        # and ``sync_from_s3`` keywords for this call to succeed.

        sync_from_s3 = kwargs.pop("sync_from_s3", False)
        filters = kwargs.pop("filters", {}) or {}

        resolved = resolver.resolve(
            sync_from_s3=sync_from_s3,
            filters=filters,
        )

        return cls(datasets=resolved, **kwargs)
+ + Returns: + Set of keys containing JPEG data + """ + features = self.metadata.get("features", {}) + return {key for key, info in features.items() if info.get("dtype") == "jpeg"} + + def _detect_json_keys(self) -> set[str]: + """ + Detect keys containing JSON-encoded bytes from metadata. + + Returns: + Set of keys containing JSON payloads. + """ + features = self.metadata.get("features", {}) + return { + key for key, info in features.items() + if info.get("dtype") == "json" + } + + @staticmethod + def _decode_json_entry(value): + if isinstance(value, np.void): + value = value.item() + if isinstance(value, memoryview): + value = value.tobytes() + if isinstance(value, bytearray): + value = bytes(value) + if isinstance(value, bytes): + return json.loads(value.decode("utf-8")) + if isinstance(value, str): + return json.loads(value) + return value + + def _load_annotations(self) -> list[dict]: + """ + Load and cache decoded language annotations. + + Expected format per entry: + {"text": str, "start_idx": int, "end_idx": int} + """ + if self._annotations is not None: + return self._annotations + + raw = self.episode_reader._store["annotations"][:] + + decoded = [self._decode_json_entry(x) for x in raw] + self._annotations = [d for d in decoded if isinstance(d, dict)] + return self._annotations + + def _annotation_text_for_frame(self, frame_idx: int) -> str: + """ + Resolve language annotation text for a frame from span annotations. 
+ """ + annotations = self._load_annotations() + for ann in annotations: + start_idx = int(ann.get("start_idx", -1)) + end_idx = int(ann.get("end_idx", -1)) + if start_idx <= frame_idx <= end_idx: + return str(ann.get("text", "")) + return "" + + def __len__(self) -> int: + return self.total_frames + + def _pad_sequences(self, data, horizon: int | None) -> dict: + if horizon is None: + return data + + # Note that k is zarr key + for k in data: + if isinstance(data[k], np.ndarray): + seq_len = data[k].shape[0] + if seq_len < horizon: + # Pad by repeating the last frame + pad_len = horizon - seq_len + last_frame = data[k][-1:] # Keep dims: (1, action_dim) + padding = np.repeat(last_frame, pad_len, axis=0) + data[k] = np.concatenate([data[k], padding], axis=0) + + return data + + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: + # Build keys_dict with ranges based on whether action chunking is enabled + data = {} + for k in self.key_map: + zarr_key = self.key_map[k]["zarr_key"] + horizon = self.key_map[k].get("horizon", None) + + if zarr_key == "annotations": + data[k] = self._annotation_text_for_frame(idx) + continue + + if horizon is not None: + end_idx = min(idx + horizon, self.total_frames) + read_interval = (idx, end_idx) + else: + read_interval = (idx, None) + read_dict = {zarr_key: read_interval} + raw_data = self.episode_reader.read(read_dict) + self._pad_sequences(raw_data, horizon) # should be able to pad images + data[k] = raw_data[zarr_key] + + # Decode JPEG-encoded image data and normalize to [0, 1] + # print(f"Print the image_keys: {self._image_keys}") + if zarr_key in self._image_keys: + jpeg_bytes = data[k] + # Decode JPEG bytes to numpy array (H, W, 3) + decoded = simplejpeg.decode_jpeg(jpeg_bytes, colorspace='RGB') + # data[k] = torch.from_numpy(np.transpose(decoded, (2, 0, 1))).to(torch.float32) / 255.0 + data[k] = np.transpose(decoded, (2, 0, 1)) / 255.0 + elif zarr_key in self._json_keys: + if isinstance(data[k], np.ndarray): + data[k] 
= [self._decode_json_entry(v) for v in data[k]] + else: + data[k] = self._decode_json_entry(data[k]) + + # Convert all numpy arrays in data to torch tensors + + # TODO add the transform list code here + if self.transform: + for transform in self.transform or []: + data = transform.transform(data) + + for k, v in data.items(): + if isinstance(v, np.ndarray): + data[k] = torch.from_numpy(v).to(torch.float32) + + return data + + + + +class ZarrEpisode: + """ + Lightweight wrapper around a single Zarr episode store. + Designed for efficient PyTorch DataLoader usage with direct store access. + """ + __slots__ = ( + "_path", + "_store", + "metadata", + "keys", + ) + def __init__(self, path: str | Path): + """ + Initialize ZarrEpisode wrapper. + Args: + path: Path to the .zarr episode directory + """ + self._path = Path(path) + self._store = zarr.open_group(str(self._path), mode='r') + self.metadata = dict(self._store.attrs) + self.keys = self.metadata["features"] + + def read(self, keys_with_ranges: dict[str, tuple[int, int | None]]) -> dict[str, np.ndarray]: + """ + Read data for specified keys, each with their own index or range. + Args: + keys_with_ranges: Dictionary mapping keys to (start, end) tuples. + - start: Starting frame index + - end: Ending frame index (exclusive). If None, reads single frame at start. + Returns: + Dictionary mapping keys to numpy arrays + Example: + >>> episode.read({ + ... "obs/image": (0, 10), # Read frames 0-10 + ... "actions": (5, 15), # Read frames 5-15 + ... "rewards": (20, None), # Read single frame at index 20 + ... 
}) + """ + result = {} + for key, (start, end) in keys_with_ranges.items(): + arr = self._store[key] + if end is not None: + data = arr[start:end] + else: + # Single frame read - use slicing to avoid 0D array issues with VariableLengthBytes + # arr[start:start+1] gives us a 1D array, then [0] extracts the actual object + data = arr[start:start+1][0] + result[key] = data + return result + + def _collect_keys(self) -> list[str]: + """ + Collect all array keys from the store. + Returns: + List of array keys (flat structure with dot-separated names) + """ + if isinstance(self.keys, dict): + return list(self.keys.keys()) + return list(self.keys) + def __len__(self) -> int: + """ + Get total number of frames in the episode. + Returns: + Number of frames + """ + return self.metadata['total_frames'] + def __repr__(self) -> str: + """String representation of the episode.""" + return f"ZarrEpisode(path={self._path}, frames={len(self)})" + +if __name__ == '__main__': + from omegaconf import OmegaConf + import hydra + dataset_cfg_path = '/nethome/paphiwetsa3/flash/projects/EgoVerse/egomimic/hydra_configs/data/test_multi_zarr.yaml' + # Using Hydra to load the dataset config + dataset_cfg = OmegaConf.load(dataset_cfg_path) + datamodule = hydra.utils.instantiate(dataset_cfg) + dl = datamodule.train_dataloader() + batch = next(iter(dl)) + + breakpoint() diff --git a/egomimic/rldb/zarr/zarr_writer.py b/egomimic/rldb/zarr/zarr_writer.py new file mode 100644 index 00000000..fb4aa324 --- /dev/null +++ b/egomimic/rldb/zarr/zarr_writer.py @@ -0,0 +1,670 @@ +""" +ZarrWriter: General-purpose Zarr episode writer. + +This module provides a reusable writer for creating Zarr v3 episode stores +compatible with the ZarrEpisode reader. 
+""" + +import json +from pathlib import Path +from typing import Any + +import numpy as np +import simplejpeg +import zarr +from zarr.core.dtype import VariableLengthBytes + + +class _IncrementalHandle: + """Context manager handle for incremental frame-by-frame Zarr writing.""" + + def __init__( + self, + writer: "ZarrWriter", + total_frames: int | None, + metadata_override: dict[str, Any] | None, + ): + self._writer = writer + self._total_frames = total_frames + self._metadata_override = metadata_override + self._cursor = 0 + self._capacity = total_frames if total_frames is not None else 0 + self._store: zarr.Group | None = None + self._initialized = False + self._numeric_info: dict[str, dict] = {} + self._image_info: dict[str, dict] = {} + + @property + def frames_written(self) -> int: + return self._cursor + + @property + def _padded_frames(self) -> int: + w = self._writer + if self._total_frames is None: + return self._capacity + padded = self._total_frames + if w.enable_sharding and self._total_frames % w.chunk_timesteps != 0: + padded = ( + (self._total_frames + w.chunk_timesteps - 1) + // w.chunk_timesteps + * w.chunk_timesteps + ) + return padded + + def __enter__(self) -> "_IncrementalHandle": + if self._total_frames is not None: + self._writer.total_frames = self._total_frames + self._writer.episode_path.parent.mkdir(parents=True, exist_ok=True) + self._store = zarr.open( + str(self._writer.episode_path), mode="w", zarr_format=3 + ) + return self + + def _init_arrays( + self, numeric: dict[str, np.ndarray], images: dict[str, np.ndarray] + ) -> None: + """Create pre-allocated zarr arrays from first frame's schema.""" + dynamic_length = self._total_frames is None + if dynamic_length: + self._capacity = max(1, self._writer.chunk_timesteps) + + padded = self._padded_frames + w = self._writer + + for key, arr in numeric.items(): + frame_shape = arr.shape + self._numeric_info[key] = {"shape": frame_shape, "dtype": arr.dtype} + + full_shape = (padded,) + 
frame_shape + frames_per_chunk = max(1, min(w.chunk_timesteps, padded)) + chunk_shape = (frames_per_chunk,) + frame_shape + + if w.enable_sharding and not dynamic_length: + self._store.create_array( + key, + shape=full_shape, + chunks=chunk_shape, + shards=full_shape, + dtype=arr.dtype, + fill_value=0, + ) + else: + self._store.create_array( + key, + shape=full_shape, + chunks=chunk_shape, + dtype=arr.dtype, + fill_value=0, + ) + + dimension_names = [f"dim_{i}" for i in range(len(frame_shape))] + w._features[key] = { + "dtype": str(arr.dtype), + "shape": list(frame_shape), + "names": dimension_names, + } + + for key, img in images.items(): + if img.ndim != 3 or img.shape[-1] != 3: + raise ValueError( + f"Image '{key}' must have shape (H, W, 3), got {img.shape}" + ) + self._image_info[key] = {"shape": img.shape} + + shape = (padded,) + chunk_shape = (1,) + + if w.enable_sharding and not dynamic_length: + self._store.create_array( + key, + shape=shape, + chunks=chunk_shape, + shards=shape, + dtype=VariableLengthBytes(), + ) + else: + self._store.create_array( + key, + shape=shape, + chunks=chunk_shape, + dtype=VariableLengthBytes(), + ) + + w._features[key] = { + "dtype": "jpeg", + "shape": list(img.shape), + "names": ["height", "width", "channel"], + } + + self._initialized = True + + def _ensure_capacity(self, target_size: int) -> None: + if self._total_frames is not None: + if target_size > self._total_frames: + raise ValueError( + f"Write at frame {target_size} would exceed total_frames={self._total_frames}" + ) + return + + if target_size <= self._capacity: + return + + grown = max( + target_size, + self._capacity * 2 if self._capacity > 0 else self._writer.chunk_timesteps, + 1, + ) + for key, info in self._numeric_info.items(): + self._store[key].resize((grown,) + info["shape"]) + for key in self._image_info: + self._store[key].resize((grown,)) + self._capacity = grown + + def add_frame( + self, + numeric: dict[str, np.ndarray] | None = None, + images: 
dict[str, np.ndarray] | None = None, + ) -> None: + """ + Write a single frame. + + Args: + numeric: Dict of per-frame numeric arrays, each with shape matching + the feature dimensions (e.g. shape (D,) for a D-dim vector). + images: Dict of per-frame images, each with shape (H, W, 3) uint8. + """ + numeric = numeric or {} + images = images or {} + + if not self._initialized: + self._init_arrays(numeric, images) + + self._ensure_capacity(self._cursor + 1) + + for key, arr in numeric.items(): + self._store[key][self._cursor] = arr + + for key, img in images.items(): + jpeg_bytes = simplejpeg.encode_jpeg( + img, quality=ZarrWriter.JPEG_QUALITY, colorspace="RGB" + ) + self._store[key][self._cursor] = jpeg_bytes + + self._cursor += 1 + + def add_frames( + self, + numeric: dict[str, np.ndarray] | None = None, + images: dict[str, np.ndarray] | None = None, + ) -> None: + """ + Write a batch of frames. + + Args: + numeric: Dict of numeric arrays with shape (B, ...) where B is batch size. + images: Dict of image arrays with shape (B, H, W, 3) uint8. 
+ """ + numeric = numeric or {} + images = images or {} + + batch_size = None + for arr in (*numeric.values(), *images.values()): + if batch_size is None: + batch_size = len(arr) + elif len(arr) != batch_size: + raise ValueError("All arrays in a batch must have the same length") + + if batch_size is None: + return + + if not self._initialized: + first_numeric = {k: v[0] for k, v in numeric.items()} + first_images = {k: v[0] for k, v in images.items()} + self._init_arrays(first_numeric, first_images) + + end = self._cursor + batch_size + self._ensure_capacity(end) + + for key, arr in numeric.items(): + self._store[key][self._cursor:end] = arr + + for key, img_batch in images.items(): + encoded = np.empty((batch_size,), dtype=object) + for i in range(batch_size): + encoded[i] = simplejpeg.encode_jpeg( + img_batch[i], quality=ZarrWriter.JPEG_QUALITY, colorspace="RGB" + ) + self._store[key][self._cursor:end] = encoded + + self._cursor += batch_size + + def __exit__(self, exc_type, exc_val, exc_tb) -> bool: + if exc_type is not None: + return False + + if self._total_frames is not None and self._cursor != self._total_frames: + raise ValueError( + f"Expected {self._total_frames} frames but wrote {self._cursor}" + ) + + if self._cursor == 0: + raise ValueError("Expected at least one frame, but wrote 0") + + self._writer.total_frames = self._cursor + + if self._total_frames is None: + # Unknown-length mode grows arrays dynamically; trim slack at close. 
+ for key, info in self._numeric_info.items(): + self._store[key].resize((self._cursor,) + info["shape"]) + for key in self._image_info: + self._store[key].resize((self._cursor,)) + else: + # Pad image arrays for sharding alignment (numeric arrays use fill_value=0) + padded = self._padded_frames + if padded > self._total_frames and self._image_info: + pad_len = padded - self._total_frames + for key in self._image_info: + last_jpeg = self._store[key][self._total_frames - 1] + padding = np.empty((pad_len,), dtype=object) + padding[:] = last_jpeg + self._store[key][self._total_frames : padded] = padding + + # Write language annotations + if self._writer.annotations is not None: + self._writer._write_annotations( + self._store, self._writer.annotations + ) + + # Write metadata + metadata = self._writer._build_metadata(self._metadata_override) + self._store.attrs.update(metadata) + + return False + + +class ZarrWriter: + """ + General-purpose writer for Zarr v3 episode stores. + + Creates episodes compatible with the ZarrEpisode reader, handling both + numeric and image data with intelligent chunking and optional sharding. + """ + + JPEG_QUALITY = 85 # Fixed JPEG quality for image compression + + def __init__( + self, + episode_path: str | Path, + embodiment: str = "", + fps: int = 30, + task: str = "", + annotations: list[tuple[str, int, int]] | None = None, + chunk_timesteps: int = 100, + enable_sharding: bool = True, + ): + """ + Initialize ZarrWriter. + + Args: + episode_path: Path to episode .zarr directory. + embodiment: Robot type identifier (e.g., "eva_bimanual"). + fps: Frames per second for playback (default: 30). + task: Task description. + annotations: List of (text, start_idx, end_idx) tuples describing language annotations. + chunk_timesteps: Number of timesteps per chunk for numeric arrays (default: 100). + enable_sharding: Enable Zarr sharding for better cloud performance (default: True). 
+ """ + self.episode_path = Path(episode_path) + + # Store parameters + self.fps = fps + self.embodiment = embodiment + self.task = task + self.annotations = annotations if annotations is not None else [] + self.chunk_timesteps = chunk_timesteps + self.enable_sharding = enable_sharding + + # Track image shapes for metadata + self._features: dict[str, dict[str, Any]] = {} + + def write( + self, + numeric_data: dict[str, np.ndarray] | None = None, + image_data: dict[str, np.ndarray] | None = None, + metadata_override: dict[str, Any] | None = None, + ) -> None: + """ + Write episode data to Zarr store. + + Args: + numeric_data: Dictionary of numeric arrays (state, actions, etc.). + All arrays must have same length along axis 0. + image_data: Dictionary of image arrays with shape (T, H, W, 3). + Images will be JPEG-compressed. + metadata_override: Optional metadata overrides to apply after building metadata. + + Raises: + ValueError: If arrays have inconsistent frame counts. + ValueError: If total_frames was not set and cannot be inferred. 
+ """ + numeric_data = numeric_data or {} + image_data = image_data or {} + + if not numeric_data and not image_data: + raise ValueError("Must provide at least one of numeric_data or image_data") + + # Infer total_frames from data + all_lengths = [] + for key, arr in {**numeric_data, **image_data}.items(): + all_lengths.append(len(arr)) + + if len(set(all_lengths)) > 1: + raise ValueError( + f"Inconsistent frame counts across arrays: {dict(zip(numeric_data.keys() | image_data.keys(), all_lengths))}" + ) + + self.total_frames = all_lengths[0] + + # Calculate padded frame count if sharding is enabled + padded_frames = self.total_frames + if self.enable_sharding and self.total_frames % self.chunk_timesteps != 0: + padded_frames = ((self.total_frames + self.chunk_timesteps - 1) // self.chunk_timesteps) * self.chunk_timesteps + + # Create parent directory + self.episode_path.parent.mkdir(parents=True, exist_ok=True) + + # Open Zarr v3 store + mode = "w" if self.episode_path.exists() else "w" + store = zarr.open(str(self.episode_path), mode=mode, zarr_format=3) + + # Write numeric arrays + for key, arr in numeric_data.items(): + self._write_numeric_array(store, key, arr, padded_frames) + + # Write image arrays + for key, arr in image_data.items(): + self._write_image_array(store, key, arr, padded_frames) + + # Write language annotations if provided + if self.annotations is not None: + self._write_annotations(store, self.annotations) + + # Build and attach metadata + metadata = self._build_metadata(metadata_override) + store.attrs.update(metadata) + + def write_incremental( + self, + total_frames: int | None = None, + metadata_override: dict[str, Any] | None = None, + ) -> _IncrementalHandle: + """ + Begin incremental writing to avoid loading all data into memory. + + Use as a context manager. Array schemas are inferred automatically + from the first add_frame() or add_frames() call. + + Args: + total_frames: Optional total number of frames to write. 
If omitted, + arrays grow dynamically and are trimmed to final size at close. + metadata_override: Optional metadata overrides. + + Returns: + Context manager with add_frame() and add_frames() methods. + + Example:: + + writer = ZarrWriter(episode_path="ep.zarr", embodiment="eva_bimanual") + with writer.write_incremental() as inc: + for i in range(1000): + inc.add_frame( + numeric={"actions": actions[i]}, + images={"cam_left": images[i]}, + ) + """ + return _IncrementalHandle(self, total_frames, metadata_override) + + def _write_numeric_array(self, store: zarr.Group, key: str, arr: np.ndarray, padded_frames: int) -> None: + """ + Write a numeric array to the Zarr store. + + Args: + store: Zarr group to write to. + key: Array key name. + arr: Numeric array with shape (T, ...). + padded_frames: Target frame count after padding (for sharding alignment). + """ + num_frames = len(arr) + + # Store original shape and dtype before padding for metadata + original_shape = arr.shape[1:] # Shape excluding time dimension + dtype_str = str(arr.dtype) + + # Pad array if needed + if padded_frames > num_frames: + pad_len = padded_frames - num_frames + pad_shape = (pad_len,) + arr.shape[1:] + arr = np.concatenate([arr, np.zeros(pad_shape, dtype=arr.dtype)], axis=0) + + # Use chunk_timesteps for frames per chunk + frames_per_chunk = min(self.chunk_timesteps, padded_frames) + frames_per_chunk = max(1, frames_per_chunk) + + # Chunk shape: (frames, ...) 
- keep other dimensions intact + chunk_shape = (frames_per_chunk,) + arr.shape[1:] + + # Create array with or without sharding + if self.enable_sharding: + shard_shape = arr.shape + store.create_array( + key, + data=arr, + chunks=chunk_shape, + shards=shard_shape, + ) + else: + store.create_array( + key, + data=arr, + chunks=chunk_shape, + ) + + # Track shape and dtype for metadata + dimension_names = [f"dim_{i}" for i in range(len(original_shape))] + self._features[key] = { + "dtype": dtype_str, + "shape": list(original_shape), + "names": dimension_names, + } + + def _write_image_array(self, store: zarr.Group, key: str, image_arr: np.ndarray, padded_frames: int) -> None: + """ + Write an image array to the Zarr store with JPEG compression. + + Images are always chunked 1 per timestep for efficient random access, + regardless of chunk_timesteps setting. + + Args: + store: Zarr group to write to. + key: Array key name. + image_arr: Image array with shape (T, H, W, 3). + padded_frames: Target frame count after padding (for sharding alignment). 
+ """ + # Validate shape + if image_arr.ndim != 4 or image_arr.shape[-1] != 3: + raise ValueError( + f"Image array '{key}' must have shape (T, H, W, 3), got {image_arr.shape}" + ) + + # Encode each frame as JPEG + num_frames = len(image_arr) + + # Encode to padded_frames length (pad with duplicate of last frame if needed) + encoded = np.empty((padded_frames,), dtype=object) + for i in range(padded_frames): + # Use last frame for padding + frame_idx = min(i, num_frames - 1) + img = image_arr[frame_idx] + jpeg_bytes = simplejpeg.encode_jpeg(img, quality=self.JPEG_QUALITY, colorspace='RGB') + encoded[i] = jpeg_bytes + + # Images are always chunked 1 per timestep, regardless of chunk_timesteps + chunk_shape = (1,) + + # Create array with VariableLengthBytes dtype + if self.enable_sharding: + shard_shape = encoded.shape + store.create_array( + key, + shape=encoded.shape, + chunks=chunk_shape, + shards=shard_shape, + dtype=VariableLengthBytes(), + ) + else: + store.create_array( + key, + shape=encoded.shape, + chunks=chunk_shape, + dtype=VariableLengthBytes(), + ) + + # Assign data after creation (required for VariableLengthBytes) + store[key][:] = encoded + + # Track shape for metadata + self._features[key] = { + "dtype": "jpeg", + "shape": list(image_arr.shape[1:]), # [H, W, 3] + "names": ["height", "width", "channel"], + } + + def _write_annotations( + self, store: zarr.Group, annotations: list[tuple[str, int, int]] + ) -> None: + """ + Write language annotations as JSON-encoded bytes. + + Args: + store: Zarr group to write to. + annotations: List of (text, start_idx, end_idx) tuples. 
+ """ + encoded = np.array( + [ + json.dumps( + {"text": text, "start_idx": int(start_idx), "end_idx": int(end_idx)}, + ensure_ascii=False, + separators=(",", ":"), + ).encode("utf-8") + for text, start_idx, end_idx in annotations + ], + dtype=object, + ) + + n_annotations = len(annotations) + chunk_shape = (max(1, n_annotations),) + if self.enable_sharding: + shard_shape = (n_annotations,) + store.create_array( + "annotations", + shape=encoded.shape, + chunks=chunk_shape, + shards=shard_shape, + dtype=VariableLengthBytes(), + ) + else: + store.create_array( + "annotations", + shape=encoded.shape, + chunks=chunk_shape, + dtype=VariableLengthBytes(), + ) + if n_annotations > 0: + store["annotations"][:] = encoded + + # Track in features + self._features["annotations"] = { + "dtype": "json", + "shape": [n_annotations], + "names": ["json"], + "format": "annotation_v1", + } + + def _build_metadata(self, metadata_override: dict[str, Any] | None = None) -> dict[str, Any]: + """ + Build episode metadata dictionary. + + Args: + metadata_override: Optional overrides to apply. + + Returns: + Metadata dictionary. + """ + metadata = { + "embodiment": self.embodiment, + "total_frames": self.total_frames, + "fps": self.fps, + "task": self.task, + "features": self._features, + } + + # Apply overrides + if metadata_override: + metadata.update(metadata_override) + + return metadata + + @staticmethod + def create_and_write( + episode_path: str | Path, + numeric_data: dict[str, np.ndarray] | None = None, + image_data: dict[str, np.ndarray] | None = None, + embodiment: str = "", + fps: int = 30, + task: str = "", + annotations: list[tuple[str, int, int]] | None = None, + chunk_timesteps: int = 100, + enable_sharding: bool = True, + metadata_override: dict[str, Any] | None = None, + ) -> Path: + """ + Convenience method: create writer and write in one call. + + Args: + episode_path: Path to episode .zarr directory. + numeric_data: Dictionary of numeric arrays (state, actions, etc.). 
+ image_data: Dictionary of image arrays with shape (T, H, W, 3). + embodiment: Robot type identifier. + fps: Frames per second (default: 30). + task: Task description. + annotations: List of (text, start_idx, end_idx) tuples describing language annotations. + chunk_timesteps: Number of timesteps per chunk for numeric arrays (default: 100). + enable_sharding: Enable Zarr sharding (default: True). + metadata_override: Optional metadata overrides. + + Returns: + Path to created episode. + + Raises: + ValueError: If neither numeric_data nor image_data are provided. + """ + # Create writer + writer = ZarrWriter( + episode_path=episode_path, + embodiment=embodiment, + fps=fps, + task=task, + annotations=annotations, + chunk_timesteps=chunk_timesteps, + enable_sharding=enable_sharding, + ) + + # Write data + writer.write( + numeric_data=numeric_data, + image_data=image_data, + metadata_override=metadata_override, + ) + + return writer.episode_path diff --git a/egomimic/scripts/zarr_data_viz.ipynb b/egomimic/scripts/zarr_data_viz.ipynb new file mode 100644 index 00000000..5df1d6f6 --- /dev/null +++ b/egomimic/scripts/zarr_data_viz.ipynb @@ -0,0 +1,344 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "29aeeb40", + "metadata": {}, + "source": [ + "# Zarr batch viz (single episode in MultiDataset)\n", + "\n", + "This notebook builds a `MultiDataset` containing exactly one `ZarrDataset`, loads one batch, visualizes one image with `mediapy`, and prints the rest of the batch." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "32d9110f", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "import torch\n", + "import mediapy as mpy\n", + "\n", + "from egomimic.rldb.zarr.zarr_dataset_multi import ZarrDataset, MultiDataset\n", + "from egomimic.utils.egomimicUtils import EXTRINSICS\n", + "from egomimic.rldb.zarr.action_chunk_transforms import (\n", + " _matrix_to_xyzypr,\n", + " build_aria_bimanual_transform_list,\n", + " build_eva_bimanual_transform_list,\n", + ")\n", + "\n", + "import numpy as np\n", + "from egomimic.utils.egomimicUtils import INTRINSICS, draw_actions\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4aa1a05", + "metadata": {}, + "outputs": [], + "source": [ + "# Point this at a single episode directory, e.g. /path/to/episode_hash.zarr\n", + "EPISODE_PATH = Path(\"/coc/flash7/rco3/EgoVerse/egomimic/rldb/zarr/zarr/new/1769460905119.zarr\")\n", + "\n", + "key_map = {\n", + " \"images.front_1\": {\"zarr_key\": \"images.front_1\"},\n", + " \"images.right_wrist\": {\"zarr_key\": \"images.right_wrist\"},\n", + " \"images.left_wrist\": {\"zarr_key\": \"images.left_wrist\"},\n", + " \"right.obs_ee_pose\": {\"zarr_key\": \"right.obs_ee_pose\"},\n", + " \"right.obs_gripper\": {\"zarr_key\": \"right.gripper\"},\n", + " \"left.obs_ee_pose\": {\"zarr_key\": \"left.obs_ee_pose\"},\n", + " \"left.obs_gripper\": {\"zarr_key\": \"left.gripper\"},\n", + " \"right.gripper\": {\"zarr_key\": \"right.gripper\", \"horizon\": 45},\n", + " \"left.gripper\": {\"zarr_key\": \"left.gripper\", \"horizon\": 45},\n", + " \"right.cmd_ee_pose\": {\"zarr_key\": \"right.cmd_ee_pose\", \"horizon\": 45},\n", + " \"left.cmd_ee_pose\": {\"zarr_key\": \"left.cmd_ee_pose\", \"horizon\": 45},\n", + "}\n", + "\n", + "ACTION_CHUNK_LENGTH = 100\n", + "ACTION_STRIDE = 1 # set to 3 for Aria-style anchor sampling\n", + "\n", + "extrinsics = EXTRINSICS[\"x5Dec13_2\"]\n", + "left_extrinsics_pose 
= _matrix_to_xyzypr(extrinsics[\"left\"][None, :])[0]\n", + "right_extrinsics_pose = _matrix_to_xyzypr(extrinsics[\"right\"][None, :])[0]\n", + "\n", + "transform_list = build_eva_bimanual_transform_list(\n", + " chunk_length=ACTION_CHUNK_LENGTH,\n", + " stride=ACTION_STRIDE,\n", + " left_extra_batch_key={\"left_extrinsics_pose\": left_extrinsics_pose},\n", + " right_extra_batch_key={\"right_extrinsics_pose\": right_extrinsics_pose},\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c7fbf37", + "metadata": {}, + "outputs": [], + "source": [ + "# Build a MultiDataset with exactly one ZarrDataset inside\n", + "single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map, transform_list=transform_list)\n", + "# single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map)\n", + "multi_ds = MultiDataset(datasets={\"single_episode\": single_ds}, mode=\"total\")\n", + "\n", + "print(\"len(single_ds):\", len(single_ds))\n", + "print(\"len(multi_ds):\", len(multi_ds))\n", + "\n", + "loader = torch.utils.data.DataLoader(multi_ds, batch_size=1, shuffle=False)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e4626d8", + "metadata": {}, + "outputs": [], + "source": [ + "def viz_batch(batch, image_key, action_key):\n", + " # Image: (C,H,W) -> (H,W,C)\n", + " img = batch[image_key][0].detach().cpu()\n", + " if img.shape[0] in (1, 3):\n", + " img = img.permute(1, 2, 0)\n", + " img_np = img.numpy()\n", + "\n", + " # Make image drawable uint8\n", + " if img_np.dtype != np.uint8:\n", + " if img_np.max() <= 1.0:\n", + " img_np = (img_np * 255.0).clip(0, 255).astype(np.uint8)\n", + " else:\n", + " img_np = img_np.clip(0, 255).astype(np.uint8)\n", + " if img_np.shape[-1] == 1:\n", + " img_np = np.repeat(img_np, 3, axis=-1)\n", + "\n", + " intrinsics = INTRINSICS[\"base\"]\n", + " actions = batch[action_key][0].detach().cpu().numpy()\n", + "\n", + " # 14D layout: [L xyz ypr g, R xyz ypr g]\n", + " # 12D layout: [L 
xyz ypr, R xyz ypr]\n", + " if actions.shape[-1] == 14:\n", + " left_xyz = actions[:, :3]\n", + " right_xyz = actions[:, 7:10]\n", + " elif actions.shape[-1] == 12:\n", + " left_xyz = actions[:, :3]\n", + " right_xyz = actions[:, 6:9]\n", + " else:\n", + " raise ValueError(f\"Unsupported action dim {actions.shape[-1]} for key {action_key}\")\n", + "\n", + " vis = draw_actions(\n", + " img_np.copy(), type=\"xyz\", color=\"Blues\",\n", + " actions=left_xyz, extrinsics=None, intrinsics=intrinsics, arm=\"left\"\n", + " )\n", + " vis = draw_actions(\n", + " vis, type=\"xyz\", color=\"Reds\",\n", + " actions=right_xyz, extrinsics=None, intrinsics=intrinsics, arm=\"right\"\n", + " )\n", + "\n", + " return vis\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8eef8507", + "metadata": {}, + "outputs": [], + "source": [ + "image_key = \"images.front_1\"\n", + "action_key = \"actions_cartesian\"\n", + "\n", + "for batch in loader:\n", + " for k, v in batch.items():\n", + " print(f\"{k}: {tuple(v.shape)}\")\n", + "\n", + " vis = viz_batch(batch, image_key=image_key, action_key=action_key)\n", + " mpy.show_image(vis)\n", + " break\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b94b8c83", + "metadata": {}, + "outputs": [], + "source": [ + "batch[\"actions_cartesian\"][0, 0]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4b818cad", + "metadata": {}, + "outputs": [], + "source": [ + "batch[\"observations.state.ee_pose\"][0]\n" + ] + }, + { + "cell_type": "markdown", + "id": "1a3382f1", + "metadata": {}, + "source": [ + "## Aria Datasets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38100d31", + "metadata": {}, + "outputs": [], + "source": [ + "from egomimic.utils.aws.aws_sql import timestamp_ms_to_episode_hash\n", + "timestamp_ms_to_episode_hash(1764285211791)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8693d01c", + "metadata": {}, + 
"outputs": [], + "source": [ + "# Aria-style chunking example: horizon=30 contiguous frames, sample anchors every 3 -> 10 points, then interpolate to 100.\n", + "EPISODE_PATH = Path(\"/coc/flash7/scratch/egoverseDebugDatasets/proc_zarr/1764285211791.zarr/\")\n", + "\n", + "key_map = {\n", + " \"images.front_1\": {\"zarr_key\": \"images.front_1\"},\n", + " \"right.obs_ee_pose\": {\"zarr_key\": \"right.obs_ee_pose\"},\n", + " \"left.obs_ee_pose\": {\"zarr_key\": \"left.obs_ee_pose\"},\n", + " \"right.action_ee_pose\": {\"zarr_key\": \"right.obs_ee_pose\", \"horizon\": 30},\n", + " \"left.action_ee_pose\": {\"zarr_key\": \"left.obs_ee_pose\", \"horizon\": 30},\n", + " \"obs_head_pose\": {\"zarr_key\": \"obs_head_pose\"},\n", + "}\n", + "\n", + "ACTION_CHUNK_LENGTH = 100\n", + "ACTION_STRIDE = 3\n", + "\n", + "transform_list = build_aria_bimanual_transform_list(\n", + " chunk_length=ACTION_CHUNK_LENGTH,\n", + " stride=ACTION_STRIDE,\n", + " left_action_world=\"left.action_ee_pose\",\n", + " right_action_world=\"right.action_ee_pose\",\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1da784ea", + "metadata": {}, + "outputs": [], + "source": [ + "# Build a MultiDataset with exactly one ZarrDataset inside\n", + "# single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map, transform_list=transform_list)\n", + "single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map)\n", + "multi_ds = MultiDataset(datasets={\"single_episode\": single_ds}, mode=\"total\")\n", + "\n", + "print(\"len(single_ds):\", len(single_ds))\n", + "print(\"len(multi_ds):\", len(multi_ds))\n", + "\n", + "loader = torch.utils.data.DataLoader(multi_ds, batch_size=1, shuffle=False)\n", + "# batch = next(iter(loader))\n", + "\n", + "# print(\"Batch keys:\", list(batch.keys()))\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "507a1fd6", + "metadata": {}, + "outputs": [], + "source": [ + "batch = next(iter(loader))\n", + 
"print(\"Batch keys:\", list(batch.keys()))\n", + "print(batch[\"right.action_ee_pose\"][0, 0])\n", + "print(batch[\"left.action_ee_pose\"][0, 0])\n", + "print(batch[\"obs_head_pose\"][0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "af65095a", + "metadata": {}, + "outputs": [], + "source": [ + "image_key = \"images.front_1\"\n", + "action_key = \"actions_cartesian\"\n", + "\n", + "ims = []\n", + "for i, batch in enumerate(loader):\n", + " first_img = batch[image_key][0].detach().cpu().permute(1, 2, 0).numpy()\n", + " first_img = (first_img * 255.0).clip(0, 255).astype(np.uint8)\n", + "\n", + " vis = viz_batch(batch, image_key=image_key, action_key=action_key)\n", + " ims.append(vis)\n", + " # mpy.show_image(vis)\n", + "\n", + " # for k, v in batch.items():\n", + " # print(f\"{k}: {tuple(v.shape)}\")\n", + " \n", + " if i > 200:\n", + " break\n", + "\n", + "mpy.show_video(ims, fps=30)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36f120b8", + "metadata": {}, + "outputs": [], + "source": [ + "batch[\"actions_cartesian\"][0, 0]\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5bd2e1fc", + "metadata": {}, + "outputs": [], + "source": [ + "batch[\"\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb4a930c", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "emimic", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/egomimic/utils/pose_utils.py b/egomimic/utils/pose_utils.py new file mode 100644 index 00000000..395e3709 --- /dev/null +++ b/egomimic/utils/pose_utils.py @@ -0,0 +1,139 @@ +import 
import numpy as np
from scipy.interpolate import interp1d
from scipy.spatial.transform import Rotation as R, Slerp

# A sequence containing any value at or above the threshold is not interpolated;
# the interpolators return an array filled with the fill value instead.
# NOTE(review): presumably these mark missing/invalid poses upstream — confirm.
_SENTINEL_THRESHOLD = 1e8
_SENTINEL_FILL = 1e9


def xyzw_to_wxyz(xyzw):
    """Reorder quaternions from scalar-last (x, y, z, w) to scalar-first (w, x, y, z).

    Args:
        xyzw: array-like whose last axis holds (x, y, z, w); any number of
            leading axes is allowed.

    Returns:
        np.ndarray of the same shape with the last axis reordered to
        (w, x, y, z).
    """
    # Convert first so plain Python sequences support the `...` indexing below.
    xyzw = np.asarray(xyzw)
    return np.concatenate([xyzw[..., 3:4], xyzw[..., :3]], axis=-1)


def _interpolate_euler(seq: np.ndarray, chunk_length: int) -> np.ndarray:
    """Euler-aware interpolation for a single (T, 6) or (T, 7) sequence.

    Columns 0-2 are interpolated linearly, columns 3-5 are treated as angles in
    radians (unwrapped along time, interpolated, then wrapped back into
    [-pi, pi)), and an optional 7th column is interpolated linearly.

    Args:
        seq: (T, 6) or (T, 7) array.
        chunk_length: number of output samples.

    Returns:
        (chunk_length, D) array. If any input value is >= 1e8 the result is
        filled with 1e9 instead of being interpolated.

    Raises:
        ValueError: if D is not 6 or 7, or if the sequence is empty.
    """
    seq = np.asarray(seq)
    T, D = seq.shape
    # Raise instead of assert: asserts disappear under `python -O`.
    if D not in (6, 7):
        raise ValueError(f"Expected 6 or 7 dims, got {D}")

    if np.any(seq >= _SENTINEL_THRESHOLD):
        return np.full((chunk_length, D), _SENTINEL_FILL)

    trans_interp = _interpolate_linear(seq[:, :3], chunk_length)

    # Unwrap along time so a jump across +/-pi is interpolated the short way,
    # then fold the result back into [-pi, pi).
    rot_unwrapped = np.unwrap(seq[:, 3:6], axis=0)
    rot_interp = _interpolate_linear(rot_unwrapped, chunk_length)
    rot_interp = (rot_interp + np.pi) % (2 * np.pi) - np.pi

    if D == 6:
        return np.concatenate([trans_interp, rot_interp], axis=-1)

    grip_interp = _interpolate_linear(seq[:, 6:7], chunk_length)
    return np.concatenate([trans_interp, rot_interp, grip_interp], axis=-1)


def _interpolate_linear(seq: np.ndarray, chunk_length: int) -> np.ndarray:
    """Simple linear interpolation for arbitrary (T, D) arrays.

    The T input rows are placed uniformly on [0, 1] and resampled at
    ``chunk_length`` uniformly spaced points on the same interval.

    Args:
        seq: (T, D) array.
        chunk_length: number of output samples.

    Returns:
        (chunk_length, D) float array. A single-row input is repeated.

    Raises:
        ValueError: if the sequence is empty.
    """
    seq = np.asarray(seq)
    T, _ = seq.shape
    if T == 0:
        raise ValueError("Cannot interpolate an empty sequence.")
    if T == 1:
        # interp1d cannot be evaluated away from its only knot, so hold the
        # single sample constant (same policy as the quaternion T == 1 case).
        return np.repeat(seq.astype(np.float64), chunk_length, axis=0)

    old_time = np.linspace(0, 1, T)
    new_time = np.linspace(0, 1, chunk_length)
    return interp1d(old_time, seq, axis=0, kind="linear")(new_time)


def _interpolate_quat_wxyz(seq: np.ndarray, chunk_length: int) -> np.ndarray:
    """Quaternion-aware interpolation for a single (T, 7) sequence.

    Columns 0-2 are interpolated linearly; columns 3-6 are scalar-first
    (w, x, y, z) quaternions, normalized and interpolated with Slerp.

    Args:
        seq: (T, 7) array of [x, y, z, qw, qx, qy, qz].
        chunk_length: number of output samples.

    Returns:
        (chunk_length, 7) array in the same layout. If any input value is
        >= 1e8 the result is filled with 1e9 instead of being interpolated.
        Otherwise the dtype follows ``seq`` when it is floating, else float64.

    Raises:
        ValueError: if D is not 7, the sequence is empty, or a quaternion has
            zero norm.
    """
    seq = np.asarray(seq)
    T, D = seq.shape
    if D != 7:
        raise ValueError(f"Expected 7 dims for xyz+quat(wxyz), got {D}")

    if np.any(seq >= _SENTINEL_THRESHOLD):
        return np.full((chunk_length, D), _SENTINEL_FILL)

    # Handles T == 1 by repetition, so the single-sample branch below is
    # actually reachable.
    trans_interp = _interpolate_linear(seq[:, :3], chunk_length)

    # scipy's Rotation uses scalar-last (x, y, z, w) quaternions.
    quat_wxyz = np.asarray(seq[:, 3:7], dtype=np.float64)
    quat_xyzw = quat_wxyz[:, [1, 2, 3, 0]]

    norms = np.linalg.norm(quat_xyzw, axis=1, keepdims=True)
    if np.any(norms <= 0):
        raise ValueError("Found zero-norm quaternion in input sequence.")
    quat_xyzw = quat_xyzw / norms

    # Enforce sign continuity to avoid long-path interpolation. Each row is
    # compared with the already-corrected previous row, so this stays a loop.
    quat_contiguous = quat_xyzw.copy()
    for i in range(1, T):
        if np.dot(quat_contiguous[i - 1], quat_contiguous[i]) < 0:
            quat_contiguous[i] = -quat_contiguous[i]

    if T == 1:
        quat_interp_xyzw = np.repeat(quat_contiguous[:1], chunk_length, axis=0)
    else:
        old_time = np.linspace(0, 1, T)
        new_time = np.linspace(0, 1, chunk_length)
        slerp = Slerp(old_time, R.from_quat(quat_contiguous))
        quat_interp_xyzw = slerp(new_time).as_quat()

    quat_interp_wxyz = quat_interp_xyzw[:, [3, 0, 1, 2]]
    dtype = seq.dtype if np.issubdtype(seq.dtype, np.floating) else np.float64
    return np.concatenate([trans_interp, quat_interp_wxyz], axis=-1).astype(
        dtype, copy=False
    )


def _matrix_to_xyzypr(mats: np.ndarray) -> np.ndarray:
    """Convert SE3 matrices to [x, y, z, yaw, pitch, roll] rows.

    Args:
        mats: (B, 4, 4) array-like of SE3 transformation matrices.

    Returns:
        (B, 6) np.ndarray of [[x, y, z, yaw, pitch, roll]]; angles are in
        radians, taken from scipy's "ZYX" Euler decomposition. The dtype
        follows ``mats`` when it is floating, else float64.

    Raises:
        ValueError: if ``mats`` does not have shape (B, 4, 4).
    """
    # Convert before validating so nested lists are accepted too.
    mats = np.asarray(mats)
    if mats.ndim != 3 or mats.shape[-2:] != (4, 4):
        raise ValueError(f"Expected (B, 4, 4) array, got shape {mats.shape}")

    dtype = mats.dtype if np.issubdtype(mats.dtype, np.floating) else np.float64

    xyz = mats[:, :3, 3]
    ypr = R.from_matrix(mats[:, :3, :3]).as_euler("ZYX", degrees=False)

    return np.concatenate([xyz, ypr], axis=-1).astype(dtype, copy=False)


def _matrix_to_xyzwxyz(mats: np.ndarray) -> np.ndarray:
    """Convert SE3 matrices to [x, y, z, qw, qx, qy, qz] rows.

    Args:
        mats: (B, 4, 4) array-like of SE3 transformation matrices.

    Returns:
        (B, 7) np.ndarray of [[x, y, z, qw, qx, qy, qz]]. The dtype follows
        ``mats`` when it is floating, else float64.

    Raises:
        ValueError: if ``mats`` does not have shape (B, 4, 4).
    """
    # Convert before validating so nested lists are accepted too.
    mats = np.asarray(mats)
    if mats.ndim != 3 or mats.shape[-2:] != (4, 4):
        raise ValueError(f"Expected (B, 4, 4) array, got shape {mats.shape}")

    dtype = mats.dtype if np.issubdtype(mats.dtype, np.floating) else np.float64

    xyz = mats[:, :3, 3]
    # scipy returns scalar-last quaternions; move w to the front.
    quat_xyzw = R.from_matrix(mats[:, :3, :3]).as_quat()
    quat_wxyz = quat_xyzw[:, [3, 0, 1, 2]]

    return np.concatenate([xyz, quat_wxyz], axis=-1).astype(dtype, copy=False)


def _xyzwxyz_to_matrix(xyzwxyz: np.ndarray) -> np.ndarray:
    """Convert [x, y, z, qw, qx, qy, qz] rows to SE3 matrices.

    Args:
        xyzwxyz: (B, 7) array-like of [[x, y, z, qw, qx, qy, qz]].

    Returns:
        (B, 4, 4) np.ndarray of SE3 transformation matrices. The dtype follows
        the input when it is floating, else float64.

    Raises:
        ValueError: if the input does not have shape (B, 7).
    """
    xyzwxyz = np.asarray(xyzwxyz)
    if xyzwxyz.ndim != 2 or xyzwxyz.shape[-1] != 7:
        raise ValueError(f"Expected (B, 7) array, got shape {xyzwxyz.shape}")

    B = xyzwxyz.shape[0]
    dtype = (
        xyzwxyz.dtype if np.issubdtype(xyzwxyz.dtype, np.floating) else np.float64
    )

    # Start from B identity matrices; .copy() makes the broadcast view writable.
    mats = np.broadcast_to(np.eye(4, dtype=dtype), (B, 4, 4)).copy()
    # Reorder scalar-first input to scipy's scalar-last convention.
    quat_xyzw = xyzwxyz[:, [4, 5, 6, 3]]
    mats[:, :3, :3] = R.from_quat(quat_xyzw).as_matrix()
    mats[:, :3, 3] = xyzwxyz[:, :3]

    return mats


# NOTE: the patch that introduces this module also appends `mediapy` and
# `pytest` to requirements.txt (after the existing `s5cmd` entry).