Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 38 additions & 191 deletions egomimic/rldb/embodiment/eva.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from __future__ import annotations

import cv2
import numpy as np
from scipy.spatial.transform import Rotation as R
from typing import Literal

from egomimic.rldb.embodiment.embodiment import Embodiment
from egomimic.rldb.zarr.action_chunk_transforms import (
Expand All @@ -11,23 +9,26 @@
DeleteKeys,
InterpolateLinear,
InterpolatePose,
PoseCoordinateFrameTransform,
NumpyToTensor,
PoseCoordinateFrameTransform,
Transform,
XYZWXYZ_to_XYZYPR,
)
from egomimic.utils.egomimicUtils import (
EXTRINSICS,
INTRINSICS,
cam_frame_to_cam_pixels,
draw_actions,
)
from egomimic.utils.pose_utils import (
_matrix_to_xyzwxyz,
)
from egomimic.utils.type_utils import _to_numpy
from egomimic.utils.viz_utils import (
_viz_axes,
_viz_traj,
)


class Eva(Embodiment):
VIZ_INTRINSICS_KEY = "base"
VIZ_IMAGE_KEY = "observations.images.front_img_1"

@staticmethod
Expand All @@ -40,32 +41,38 @@ def viz_transformed_batch(cls, batch, mode=""):
Visualize one transformed EVA batch sample.

Modes:
- palm_traj: draw left/right palm trajectories from actions_cartesian.
- palm_axes: draw local xyz axes at each palm anchor using ypr.
- traj: draw left/right trajectories from actions_cartesian.
- axes: draw local xyz axes at each anchor using ypr.
"""
image_key = cls.VIZ_IMAGE_KEY
action_key = "actions_cartesian"
intrinsics_key = "base"
mode = (mode or "palm_traj").lower()
mode = (mode or "traj").lower()

images = _to_numpy(batch[image_key][0])
actions = _to_numpy(batch[action_key][0])

return cls.viz(
images=images, actions=actions, mode=mode, intrinsics_key=intrinsics_key
)

if mode == "palm_traj":
return _viz_batch_palm_traj(
batch=batch,
image_key=image_key,
action_key=action_key,
@classmethod
def viz(cls, images, actions, mode=Literal["traj", "axes"], intrinsics_key=None):
intrinsics_key = intrinsics_key or cls.VIZ_INTRINSICS_KEY
if mode == "traj":
return _viz_traj(
images=images,
actions=actions,
intrinsics_key=intrinsics_key,
)
if mode == "palm_axes":
return _viz_batch_palm_axes(
batch=batch,
image_key=image_key,
action_key=action_key,
if mode == "axes":
return _viz_axes(
images=images,
actions=actions,
intrinsics_key=intrinsics_key,
)

raise ValueError(
f"Unsupported mode '{mode}'. Expected one of: "
f"('palm_traj', 'palm_axes', 'keypoints')."
f"Unsupported mode '{mode}'. Expected one of: " f"('traj', 'axes')."
)

@classmethod
Expand Down Expand Up @@ -149,46 +156,48 @@ def _build_eva_bimanual_transform_list(
right_extrinsics_pose = _matrix_to_xyzwxyz(extrinsics["right"][None, :])[0]
left_extra_batch_key = {"left_extrinsics_pose": left_extrinsics_pose}
right_extra_batch_key = {"right_extrinsics_pose": right_extrinsics_pose}

mode = "xyzwxyz" if is_quat else "xyzypr"
transform_list = [
ActionChunkCoordinateFrameTransform(
target_world=left_target_world,
chunk_world=left_cmd_world,
transformed_key_name=left_cmd_camframe,
extra_batch_key=left_extra_batch_key,
is_quat=is_quat,
mode=mode,
),
ActionChunkCoordinateFrameTransform(
target_world=right_target_world,
chunk_world=right_cmd_world,
transformed_key_name=right_cmd_camframe,
extra_batch_key=right_extra_batch_key,
is_quat=is_quat,
mode=mode,
),
PoseCoordinateFrameTransform(
target_world=left_target_world,
pose_world=left_obs_pose,
transformed_key_name=left_obs_pose,
is_quat=is_quat,
mode=mode,
),
PoseCoordinateFrameTransform(
target_world=right_target_world,
pose_world=right_obs_pose,
transformed_key_name=right_obs_pose,
is_quat=is_quat,
mode=mode,
),
InterpolatePose(
new_chunk_length=chunk_length,
action_key=left_cmd_camframe,
output_action_key=left_cmd_camframe,
stride=stride,
is_quat=is_quat,
mode=mode,
),
InterpolatePose(
new_chunk_length=chunk_length,
action_key=right_cmd_camframe,
output_action_key=right_cmd_camframe,
stride=stride,
is_quat=is_quat,
mode=mode,
),
InterpolateLinear(
new_chunk_length=chunk_length,
Expand Down Expand Up @@ -255,165 +264,3 @@ def _build_eva_bimanual_transform_list(
]
)
return transform_list


def _to_numpy(arr):
if hasattr(arr, "detach"):
arr = arr.detach()
if hasattr(arr, "cpu"):
arr = arr.cpu()
if hasattr(arr, "numpy"):
return arr.numpy()
return np.asarray(arr)


def _split_action_pose(actions):
# 14D layout: [L xyz ypr g, R xyz ypr g]
# 12D layout: [L xyz ypr, R xyz ypr]
if actions.shape[-1] == 14:
left_xyz = actions[:, :3]
left_ypr = actions[:, 3:6]
right_xyz = actions[:, 7:10]
right_ypr = actions[:, 10:13]
elif actions.shape[-1] == 12:
left_xyz = actions[:, :3]
left_ypr = actions[:, 3:6]
right_xyz = actions[:, 6:9]
right_ypr = actions[:, 9:12]
else:
raise ValueError(f"Unsupported action dim {actions.shape[-1]}")
return left_xyz, left_ypr, right_xyz, right_ypr


def _prepare_viz_image(batch, image_key):
img = _to_numpy(batch[image_key][0])
if img.ndim == 3 and img.shape[0] in (1, 3):
img = np.transpose(img, (1, 2, 0))

if img.dtype != np.uint8:
if img.max() <= 1.0:
img = (img * 255.0).clip(0, 255).astype(np.uint8)
else:
img = img.clip(0, 255).astype(np.uint8)

if img.ndim == 2:
img = np.repeat(img[:, :, None], 3, axis=-1)
elif img.shape[-1] == 1:
img = np.repeat(img, 3, axis=-1)

return img


def _viz_batch_palm_traj(batch, image_key, action_key, intrinsics_key):
img_np = _prepare_viz_image(batch, image_key)
intrinsics = INTRINSICS[intrinsics_key]
actions = _to_numpy(batch[action_key][0])
left_xyz, _, right_xyz, _ = _split_action_pose(actions)

vis = draw_actions(
img_np.copy(),
type="xyz",
color="Blues",
actions=left_xyz,
extrinsics=None,
intrinsics=intrinsics,
arm="left",
)
vis = draw_actions(
vis,
type="xyz",
color="Reds",
actions=right_xyz,
extrinsics=None,
intrinsics=intrinsics,
arm="right",
)
return vis


def _viz_batch_palm_axes(batch, image_key, action_key, intrinsics_key, axis_len_m=0.04):
img_np = _prepare_viz_image(batch, image_key)
intrinsics = INTRINSICS[intrinsics_key]
actions = _to_numpy(batch[action_key][0])
left_xyz, left_ypr, right_xyz, right_ypr = _split_action_pose(actions)
vis = img_np.copy()

def _draw_axis_color_legend(frame):
_, w = frame.shape[:2]
x_right = w - 12
y_start = 14
y_step = 12
line_len = 24
axis_legend = [
("x", (255, 0, 0)),
("y", (0, 255, 0)),
("z", (0, 0, 255)),
]
for i, (name, color) in enumerate(axis_legend):
y = y_start + i * y_step
x0 = x_right - line_len
x1 = x_right
cv2.line(frame, (x0, y), (x1, y), color, 3)
cv2.putText(
frame,
name,
(x0 - 12, y + 4),
cv2.FONT_HERSHEY_SIMPLEX,
0.4,
color,
1,
cv2.LINE_AA,
)
return frame

def _draw_rotation_at_palm(frame, xyz_seq, ypr_seq, label, anchor_color):
if len(xyz_seq) == 0 or len(ypr_seq) == 0:
return frame

palm_xyz = xyz_seq[0]
palm_ypr = ypr_seq[0]
rot = R.from_euler("ZYX", palm_ypr, degrees=False).as_matrix()

axis_points_cam = np.vstack(
[
palm_xyz,
palm_xyz + rot[:, 0] * axis_len_m,
palm_xyz + rot[:, 1] * axis_len_m,
palm_xyz + rot[:, 2] * axis_len_m,
]
)

px = cam_frame_to_cam_pixels(axis_points_cam, intrinsics)[:, :2]
if not np.isfinite(px).all():
return frame
pts = np.round(px).astype(np.int32)

h, w = frame.shape[:2]
x0, y0 = pts[0]
if not (0 <= x0 < w and 0 <= y0 < h):
return frame

cv2.circle(frame, (x0, y0), 4, anchor_color, -1)
axis_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]
for i, color in enumerate(axis_colors, start=1):
x1, y1 = pts[i]
if 0 <= x1 < w and 0 <= y1 < h:
cv2.line(frame, (x0, y0), (x1, y1), color, 2)
cv2.circle(frame, (x1, y1), 2, color, -1)

cv2.putText(
frame,
label,
(x0 + 6, max(12, y0 - 8)),
cv2.FONT_HERSHEY_SIMPLEX,
0.4,
anchor_color,
1,
cv2.LINE_AA,
)
return frame

vis = _draw_rotation_at_palm(vis, left_xyz, left_ypr, "L rot", (255, 180, 80))
vis = _draw_rotation_at_palm(vis, right_xyz, right_ypr, "R rot", (80, 180, 255))
vis = _draw_axis_color_legend(vis)
return vis