diff --git a/egomimic/rldb/embodiment/eva.py b/egomimic/rldb/embodiment/eva.py index d77e7407..369b546c 100644 --- a/egomimic/rldb/embodiment/eva.py +++ b/egomimic/rldb/embodiment/eva.py @@ -1,8 +1,6 @@ from __future__ import annotations -import cv2 -import numpy as np -from scipy.spatial.transform import Rotation as R +from typing import Literal from egomimic.rldb.embodiment.embodiment import Embodiment from egomimic.rldb.zarr.action_chunk_transforms import ( @@ -11,23 +9,26 @@ DeleteKeys, InterpolateLinear, InterpolatePose, - PoseCoordinateFrameTransform, NumpyToTensor, + PoseCoordinateFrameTransform, Transform, XYZWXYZ_to_XYZYPR, ) from egomimic.utils.egomimicUtils import ( EXTRINSICS, - INTRINSICS, - cam_frame_to_cam_pixels, - draw_actions, ) from egomimic.utils.pose_utils import ( _matrix_to_xyzwxyz, ) +from egomimic.utils.type_utils import _to_numpy +from egomimic.utils.viz_utils import ( + _viz_axes, + _viz_traj, +) class Eva(Embodiment): + VIZ_INTRINSICS_KEY = "base" VIZ_IMAGE_KEY = "observations.images.front_img_1" @staticmethod @@ -40,32 +41,38 @@ def viz_transformed_batch(cls, batch, mode=""): Visualize one transformed EVA batch sample. Modes: - - palm_traj: draw left/right palm trajectories from actions_cartesian. - - palm_axes: draw local xyz axes at each palm anchor using ypr. + - traj: draw left/right trajectories from actions_cartesian. + - axes: draw local xyz axes at each anchor using ypr. """ image_key = cls.VIZ_IMAGE_KEY action_key = "actions_cartesian" intrinsics_key = "base" - mode = (mode or "palm_traj").lower() + mode = (mode or "traj").lower() + + images = _to_numpy(batch[image_key][0]) + actions = _to_numpy(batch[action_key][0]) + + return cls.viz( + images=images, actions=actions, mode=mode, intrinsics_key=intrinsics_key + ) - if mode == "palm_traj": - return _viz_batch_palm_traj( - batch=batch, - image_key=image_key, - action_key=action_key, + @classmethod + def viz(cls, images, actions, mode=Literal["traj", "axes"], intrinsics_key=None): + intrinsics_key = intrinsics_key or cls.VIZ_INTRINSICS_KEY + if mode == "traj": + return _viz_traj( + images=images, + actions=actions, intrinsics_key=intrinsics_key, ) - if mode == "palm_axes": - return _viz_batch_palm_axes( - batch=batch, - image_key=image_key, - action_key=action_key, + if mode == "axes": + return _viz_axes( + images=images, + actions=actions, intrinsics_key=intrinsics_key, ) - raise ValueError( - f"Unsupported mode '{mode}'. Expected one of: " - f"('palm_traj', 'palm_axes', 'keypoints')." + f"Unsupported mode '{mode}'. Expected one of: " f"('traj', 'axes')." ) @classmethod @@ -149,46 +156,48 @@ def _build_eva_bimanual_transform_list( right_extrinsics_pose = _matrix_to_xyzwxyz(extrinsics["right"][None, :])[0] left_extra_batch_key = {"left_extrinsics_pose": left_extrinsics_pose} right_extra_batch_key = {"right_extrinsics_pose": right_extrinsics_pose} + + mode = "xyzwxyz" if is_quat else "xyzypr" transform_list = [ ActionChunkCoordinateFrameTransform( target_world=left_target_world, chunk_world=left_cmd_world, transformed_key_name=left_cmd_camframe, extra_batch_key=left_extra_batch_key, - is_quat=is_quat, + mode=mode, ), ActionChunkCoordinateFrameTransform( target_world=right_target_world, chunk_world=right_cmd_world, transformed_key_name=right_cmd_camframe, extra_batch_key=right_extra_batch_key, - is_quat=is_quat, + mode=mode, ), PoseCoordinateFrameTransform( target_world=left_target_world, pose_world=left_obs_pose, transformed_key_name=left_obs_pose, - is_quat=is_quat, + mode=mode, ), PoseCoordinateFrameTransform( target_world=right_target_world, pose_world=right_obs_pose, transformed_key_name=right_obs_pose, - is_quat=is_quat, + mode=mode, ), InterpolatePose( new_chunk_length=chunk_length, action_key=left_cmd_camframe, output_action_key=left_cmd_camframe, stride=stride, - is_quat=is_quat, + mode=mode, ), InterpolatePose( new_chunk_length=chunk_length, action_key=right_cmd_camframe, output_action_key=right_cmd_camframe, stride=stride, - is_quat=is_quat, + mode=mode, ), InterpolateLinear( new_chunk_length=chunk_length, @@ -255,165 +264,3 @@ def _build_eva_bimanual_transform_list( ] ) return transform_list - - -def _to_numpy(arr): - if hasattr(arr, "detach"): - arr = arr.detach() - if hasattr(arr, "cpu"): - arr = arr.cpu() - if hasattr(arr, "numpy"): - return arr.numpy() - return np.asarray(arr) - - -def _split_action_pose(actions): - # 14D layout: [L xyz ypr g, R xyz ypr g] - # 12D layout: [L xyz ypr, R xyz ypr] - if actions.shape[-1] == 14: - left_xyz = actions[:, :3] - left_ypr = actions[:, 3:6] - right_xyz = actions[:, 7:10] - right_ypr = actions[:, 10:13] - elif actions.shape[-1] == 12: - left_xyz = actions[:, :3] - left_ypr = actions[:, 3:6] - right_xyz = actions[:, 6:9] - right_ypr = actions[:, 9:12] - else: - raise ValueError(f"Unsupported action dim {actions.shape[-1]}") - return left_xyz, left_ypr, right_xyz, right_ypr - - -def _prepare_viz_image(batch, image_key): - img = _to_numpy(batch[image_key][0]) - if img.ndim == 3 and img.shape[0] in (1, 3): - img = np.transpose(img, (1, 2, 0)) - - if img.dtype != np.uint8: - if img.max() <= 1.0: - img = (img * 255.0).clip(0, 255).astype(np.uint8) - else: - img = img.clip(0, 255).astype(np.uint8) - - if img.ndim == 2: - img = np.repeat(img[:, :, None], 3, axis=-1) - elif img.shape[-1] == 1: - img = np.repeat(img, 3, axis=-1) - - return img - - -def _viz_batch_palm_traj(batch, image_key, action_key, intrinsics_key): - img_np = _prepare_viz_image(batch, image_key) - intrinsics = INTRINSICS[intrinsics_key] - actions = _to_numpy(batch[action_key][0]) - left_xyz, _, right_xyz, _ = _split_action_pose(actions) - - vis = draw_actions( - img_np.copy(), - type="xyz", - color="Blues", - actions=left_xyz, - extrinsics=None, - intrinsics=intrinsics, - arm="left", - ) - vis = draw_actions( - vis, - type="xyz", - color="Reds", - actions=right_xyz, - extrinsics=None, - intrinsics=intrinsics, - arm="right", - ) - return vis - - -def _viz_batch_palm_axes(batch, image_key, action_key, intrinsics_key, axis_len_m=0.04): - img_np = _prepare_viz_image(batch, image_key) - intrinsics = INTRINSICS[intrinsics_key] - actions = _to_numpy(batch[action_key][0]) - left_xyz, left_ypr, right_xyz, right_ypr = _split_action_pose(actions) - vis = img_np.copy() - - def _draw_axis_color_legend(frame): - _, w = frame.shape[:2] - x_right = w - 12 - y_start = 14 - y_step = 12 - line_len = 24 - axis_legend = [ - ("x", (255, 0, 0)), - ("y", (0, 255, 0)), - ("z", (0, 0, 255)), - ] - for i, (name, color) in enumerate(axis_legend): - y = y_start + i * y_step - x0 = x_right - line_len - x1 = x_right - cv2.line(frame, (x0, y), (x1, y), color, 3) - cv2.putText( - frame, - name, - (x0 - 12, y + 4), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - color, - 1, - cv2.LINE_AA, - ) - return frame - - def _draw_rotation_at_palm(frame, xyz_seq, ypr_seq, label, anchor_color): - if len(xyz_seq) == 0 or len(ypr_seq) == 0: - return frame - - palm_xyz = xyz_seq[0] - palm_ypr = ypr_seq[0] - rot = R.from_euler("ZYX", palm_ypr, degrees=False).as_matrix() - - axis_points_cam = np.vstack( - [ - palm_xyz, - palm_xyz + rot[:, 0] * axis_len_m, - palm_xyz + rot[:, 1] * axis_len_m, - palm_xyz + rot[:, 2] * axis_len_m, - ] - ) - - px = cam_frame_to_cam_pixels(axis_points_cam, intrinsics)[:, :2] - if not np.isfinite(px).all(): - return frame - pts = np.round(px).astype(np.int32) - - h, w = frame.shape[:2] - x0, y0 = pts[0] - if not (0 <= x0 < w and 0 <= y0 < h): - return frame - - cv2.circle(frame, (x0, y0), 4, anchor_color, -1) - axis_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)] - for i, color in enumerate(axis_colors, start=1): - x1, y1 = pts[i] - if 0 <= x1 < w and 0 <= y1 < h: - cv2.line(frame, (x0, y0), (x1, y1), color, 2) - cv2.circle(frame, (x1, y1), 2, color, -1) - - cv2.putText( - frame, - label, - (x0 + 6, max(12, y0 - 8)), - cv2.FONT_HERSHEY_SIMPLEX, - 0.4, - anchor_color, - 1, - cv2.LINE_AA, - ) - return frame - - vis = _draw_rotation_at_palm(vis, left_xyz, left_ypr, "L rot", (255, 180, 80)) - vis = _draw_rotation_at_palm(vis, right_xyz, right_ypr, "R rot", (80, 180, 255)) - vis = _draw_axis_color_legend(vis) - return vis