diff --git a/egomimic/algo/hpt.py b/egomimic/algo/hpt.py
index 5de7e0b7..6ae47832 100644
--- a/egomimic/algo/hpt.py
+++ b/egomimic/algo/hpt.py
@@ -19,8 +19,6 @@
STD_SCALE,
EinOpsRearrange,
download_from_huggingface,
- draw_actions,
- draw_rotation_text,
frechet_gaussian_over_time,
get_sinusoid_encoding_table,
reverse_kl_from_samples,
@@ -800,6 +798,7 @@ def __init__(
encoder_specs: dict = None,
domains: list = None,
auxiliary_ac_keys: dict = {},
+ viz_func: dict = None,
# ---------------------------
# Pretrained
# ---------------------------
@@ -812,6 +811,7 @@ def __init__(
):
self.nets = nn.ModuleDict()
self.data_schematic = data_schematic
+ self.viz_func = viz_func
self.camera_transforms = camera_transforms
self.train_image_augs = train_image_augs
@@ -1250,69 +1250,12 @@ def visualize_preds(self, predictions, batch):
Returns:
ims (np.ndarray): (B, H, W, 3) - images with actions drawn on top
"""
-
+ if self.viz_func is None:
+ raise ValueError("viz_func is not set")
embodiment_id = batch["embodiment"][0].item()
embodiment_name = get_embodiment(embodiment_id).lower()
- ac_key = self.ac_keys[embodiment_id]
-
- viz_img_key = self.data_schematic.viz_img_key()[embodiment_id]
- ims = (batch[viz_img_key].cpu().numpy().transpose((0, 2, 3, 1)) * 255).astype(
- np.uint8
- )
- for key in batch:
- if f"{embodiment_name}_{key}" in predictions:
- preds = predictions[f"{embodiment_name}_{key}"]
- gt = batch[key]
-
- if self.is_6dof and ac_key == "actions_cartesian":
- gt, gt_rot = self._extract_xyz(gt)
- preds, preds_rot = self._extract_xyz(preds)
-
- for b in range(ims.shape[0]):
- if preds.shape[-1] == 7 or preds.shape[-1] == 14:
- ac_type = "joints"
- elif preds.shape[-1] == 3 or preds.shape[-1] == 6:
- ac_type = "xyz"
- else:
- raise ValueError(
- f"Unknown action type with shape {preds.shape}"
- )
- # Determine arm from embodiment name, not action shape
- if "bimanual" in embodiment_name:
- arm = "both"
- elif "left" in embodiment_name:
- arm = "left"
- elif "right" in embodiment_name:
- arm = "right"
- else:
- raise ValueError(f"Unknown embodiment name: {embodiment_name}")
- ims[b] = draw_actions(
- ims[b],
- ac_type,
- "Purples",
- preds[b].cpu().numpy(),
- self.camera_transforms[embodiment_name].extrinsics,
- self.camera_transforms[embodiment_name].intrinsics,
- arm=arm,
- kinematics_solver=self.kinematics_solver,
- )
- ims[b] = draw_actions(
- ims[b],
- ac_type,
- "Greens",
- gt[b].cpu().numpy(),
- self.camera_transforms[embodiment_name].extrinsics,
- self.camera_transforms[embodiment_name].intrinsics,
- arm=arm,
- kinematics_solver=self.kinematics_solver,
- )
-
- if self.is_6dof and ac_key == "actions_cartesian":
- ims[b] = draw_rotation_text(
- ims[b], gt_rot[b][0], preds_rot[b][0], position=(340, 20)
- )
- return ims
+ return self.viz_func[embodiment_name](predictions, batch)
@override
def compute_losses(self, predictions, batch):
diff --git a/egomimic/hydra_configs/data/eva_human_cotrain.yaml b/egomimic/hydra_configs/data/eva_human_cotrain.yaml
index cabf760d..ea70acc9 100644
--- a/egomimic/hydra_configs/data/eva_human_cotrain.yaml
+++ b/egomimic/hydra_configs/data/eva_human_cotrain.yaml
@@ -19,8 +19,10 @@ train_datasets:
folder_path: /coc/flash7/scratch/egoverseDebugDatasets/aria
key_map:
_target_: egomimic.rldb.embodiment.human.Aria.get_keymap
+ mode: cartesian
transform_list:
_target_: egomimic.rldb.embodiment.human.Aria.get_transform_list
+ mode: cartesian
filters:
episode_hash: "2025-09-20-17-47-54-000000"
mode: total
@@ -44,8 +46,10 @@ valid_datasets:
folder_path: /coc/flash7/scratch/egoverseDebugDatasets/aria
key_map:
_target_: egomimic.rldb.embodiment.human.Aria.get_keymap
+ mode: cartesian
transform_list:
_target_: egomimic.rldb.embodiment.human.Aria.get_transform_list
+ mode: cartesian
filters:
episode_hash: "2025-09-20-17-47-54-000000"
mode: total
@@ -62,4 +66,4 @@ valid_dataloader_params:
num_workers: 10
aria_bimanual:
batch_size: 32
- num_workers: 10
\ No newline at end of file
+ num_workers: 10
diff --git a/egomimic/hydra_configs/train.yaml b/egomimic/hydra_configs/train.yaml
index 828a21b6..c4c6e69b 100644
--- a/egomimic/hydra_configs/train.yaml
+++ b/egomimic/hydra_configs/train.yaml
@@ -16,7 +16,7 @@ train: true
eval: false
eval_class:
- _target_ : egomimic.scripts.evaluation.Eve
+ _target_: egomimic.scripts.evaluation.Eve
mode: real
arm: both
eval_path: "./logs/eval/${name}_${now:%Y-%m-%d_%H-%M-%S}"
@@ -93,10 +93,3 @@ data_schematic: # Dynamically fill in these shapes from the dataset
embodiment:
key_type: metadata_keys
lerobot_key: metadata.embodiment
- viz_img_key:
- eva_bimanual:
- front_img_1
- aria_bimanual:
- front_img_1
- mecka_bimanual:
- front_img_1
diff --git a/egomimic/hydra_configs/train_zarr.yaml b/egomimic/hydra_configs/train_zarr.yaml
index 0cc53a23..fe26c8a4 100644
--- a/egomimic/hydra_configs/train_zarr.yaml
+++ b/egomimic/hydra_configs/train_zarr.yaml
@@ -1,10 +1,11 @@
defaults:
- - model: hpt_bc_flow_eva
+ - model: hpt_cotrain_flow_shared_head
+ - visualization: eva_cartesian_aria_cartesian
- paths: default
- trainer: ddp
- debug: null
- logger: wandb
- - data: eva
+ - data: eva_human_cotrain
- callbacks: checkpoints
- override hydra/launcher: submitit
- _self_
@@ -32,7 +33,6 @@ launch_params:
gpus_per_node: 1
nodes: 1
-
data_schematic: # Dynamically fill in these shapes from the dataset
_target_: egomimic.rldb.zarr.utils.DataSchematic
norm_mode: quantile
@@ -101,14 +101,5 @@ data_schematic: # Dynamically fill in these shapes from the dataset
embodiment:
key_type: metadata_keys
zarr_key: metadata.embodiment
- viz_img_key:
- eva_bimanual:
- front_img_1
- aria_bimanual:
- front_img_1
- mecka_bimanual:
- front_img_1
- scale_bimanual:
- front_img_1
seed: 42
diff --git a/egomimic/hydra_configs/trainer/debug.yaml b/egomimic/hydra_configs/trainer/debug.yaml
index 966827ce..e3a9a1a5 100644
--- a/egomimic/hydra_configs/trainer/debug.yaml
+++ b/egomimic/hydra_configs/trainer/debug.yaml
@@ -3,8 +3,8 @@ defaults:
strategy: ddp_find_unused_parameters_true
limit_train_batches: 5
-limit_val_batches: 3
+limit_val_batches: 20
check_val_every_n_epoch: 2
profiler: simple
max_epochs: 4
-min_epochs: 4
\ No newline at end of file
+min_epochs: 4
diff --git a/egomimic/hydra_configs/visualization/eva_cartesian_aria_cartesian.yaml b/egomimic/hydra_configs/visualization/eva_cartesian_aria_cartesian.yaml
new file mode 100644
index 00000000..d4ed7a7a
--- /dev/null
+++ b/egomimic/hydra_configs/visualization/eva_cartesian_aria_cartesian.yaml
@@ -0,0 +1,10 @@
+eva_bimanual:
+ _target_: egomimic.rldb.embodiment.eva.Eva.viz_cartesian_gt_preds
+ _partial_: true
+ image_key: front_img_1
+ action_key: actions_cartesian
+aria_bimanual:
+ _target_: egomimic.rldb.embodiment.human.Human.viz_cartesian_gt_preds
+ _partial_: true
+ image_key: front_img_1
+ action_key: actions_cartesian
diff --git a/egomimic/hydra_configs/visualization/eva_cartesian_aria_keypoints.yaml b/egomimic/hydra_configs/visualization/eva_cartesian_aria_keypoints.yaml
new file mode 100644
index 00000000..8c4d1c91
--- /dev/null
+++ b/egomimic/hydra_configs/visualization/eva_cartesian_aria_keypoints.yaml
@@ -0,0 +1,14 @@
+eva_bimanual:
+ action_keys: actions_cartesian
+ viz_function:
+ _target_: egomimic.rldb.embodiment.eva.Eva.viz
+ _partial_: true
+ mode: traj
+ intrinsics_key: base_half
+aria_bimanual:
+ action_keys: actions_cartesian
+ viz_function:
+ _target_: egomimic.rldb.embodiment.human.Aria.viz
+ _partial_: true
+ mode: keypoints
+ intrinsics_key: base_half
diff --git a/egomimic/rldb/embodiment/embodiment.py b/egomimic/rldb/embodiment/embodiment.py
index 24fab039..462b3357 100644
--- a/egomimic/rldb/embodiment/embodiment.py
+++ b/egomimic/rldb/embodiment/embodiment.py
@@ -1,7 +1,10 @@
from abc import ABC
from enum import Enum
+import numpy as np
+
from egomimic.rldb.zarr.action_chunk_transforms import Transform
+from egomimic.utils.type_utils import _to_numpy
class EMBODIMENT(Enum):
@@ -53,3 +56,25 @@ def viz_transformed_batch(batch):
def get_keymap():
"""Returns a dictionary mapping from the raw keys in the dataset to the canonical keys used by the model."""
raise NotImplementedError
+
+ @classmethod
+ def viz_cartesian_gt_preds(cls, predictions, batch, image_key, action_key):
+ embodiment_id = batch["embodiment"][0].item()
+ embodiment_name = get_embodiment(embodiment_id).lower()
+
+ images = batch[image_key]
+ actions = batch[action_key]
+ pred_actions = predictions[f"{embodiment_name}_{action_key}"]
+ ims_list = []
+ images = _to_numpy(images)
+ actions = _to_numpy(actions)
+ pred_actions = _to_numpy(pred_actions)
+ for i in range(images.shape[0]):
+ image = images[i]
+ action = actions[i]
+ pred_action = pred_actions[i]
+ ims = cls.viz(image, action, mode="traj", color="Reds")
+ ims = cls.viz(ims, pred_action, mode="traj", color="Greens")
+ ims_list.append(ims)
+ ims = np.stack(ims_list, axis=0)
+ return ims
diff --git a/egomimic/rldb/embodiment/eva.py b/egomimic/rldb/embodiment/eva.py
index 369b546c..e3a29e77 100644
--- a/egomimic/rldb/embodiment/eva.py
+++ b/egomimic/rldb/embodiment/eva.py
@@ -57,19 +57,28 @@ def viz_transformed_batch(cls, batch, mode=""):
)
@classmethod
- def viz(cls, images, actions, mode=Literal["traj", "axes"], intrinsics_key=None):
+ def viz(
+ cls,
+ images,
+ actions,
+ mode=Literal["traj", "axes"],
+ intrinsics_key=None,
+ **kwargs,
+ ):
intrinsics_key = intrinsics_key or cls.VIZ_INTRINSICS_KEY
if mode == "traj":
return _viz_traj(
images=images,
actions=actions,
intrinsics_key=intrinsics_key,
+ **kwargs,
)
if mode == "axes":
return _viz_axes(
images=images,
actions=actions,
intrinsics_key=intrinsics_key,
+ **kwargs,
)
raise ValueError(
f"Unsupported mode '{mode}'. Expected one of: " f"('traj', 'axes')."
diff --git a/egomimic/rldb/embodiment/human.py b/egomimic/rldb/embodiment/human.py
index 08e0e868..bab7432b 100644
--- a/egomimic/rldb/embodiment/human.py
+++ b/egomimic/rldb/embodiment/human.py
@@ -69,6 +69,7 @@ def viz(
actions,
mode=Literal["traj", "axes", "keypoints"],
intrinsics_key=None,
+ **kwargs,
):
intrinsics_key = intrinsics_key or cls.VIZ_INTRINSICS_KEY
if mode == "traj":
@@ -76,12 +77,14 @@ def viz(
images=images,
actions=actions,
intrinsics_key=intrinsics_key,
+ **kwargs,
)
if mode == "axes":
return _viz_axes(
images=images,
actions=actions,
intrinsics_key=intrinsics_key,
+ **kwargs,
)
if mode == "keypoints":
return _viz_keypoints(
@@ -91,6 +94,7 @@ def viz(
edges=cls.FINGER_EDGES,
colors=cls.FINGER_COLORS,
edge_ranges=cls.FINGER_EDGE_RANGES,
+ **kwargs,
)
raise ValueError(
f"Unsupported mode '{mode}'. Expected one of: "
diff --git a/egomimic/rldb/utils.py b/egomimic/rldb/utils.py
index ff9bcee7..a8e3300a 100644
--- a/egomimic/rldb/utils.py
+++ b/egomimic/rldb/utils.py
@@ -1043,7 +1043,7 @@ def sync_from_filters(
class DataSchematic(object):
- def __init__(self, schematic_dict, viz_img_key, norm_mode="zscore"):
+ def __init__(self, schematic_dict, norm_mode="zscore"):
"""
Initialize with a schematic dictionary and create a DataFrame.
@@ -1091,7 +1091,6 @@ def __init__(self, schematic_dict, viz_img_key, norm_mode="zscore"):
)
self.df = pd.DataFrame(rows)
- self._viz_img_key = {get_embodiment_id(k): v for k, v in viz_img_key.items()}
self.shapes_infered = False
self.norm_mode = norm_mode
self.norm_stats = {emb: {} for emb in self.embodiments}
@@ -1298,12 +1297,6 @@ def get_zarr_data(ds, col):
logger.info("[NormStats] Finished norm inference")
- def viz_img_key(self):
- """
- Get the key that should be used for offline visualization
- """
- return self._viz_img_key
-
def all_keys(self):
"""
Get all key names.
diff --git a/egomimic/rldb/zarr/test_norm_stats.py b/egomimic/rldb/zarr/test_norm_stats.py
index b07735d1..92283814 100644
--- a/egomimic/rldb/zarr/test_norm_stats.py
+++ b/egomimic/rldb/zarr/test_norm_stats.py
@@ -167,10 +167,9 @@ def test_infer_norm_from_dataset_legacy_matches_current_on_dummy_dataset() -> No
},
}
}
- viz_img_key = {"eva_bimanual": "observations.images.front_img_1"}
- legacy_schematic = _LegacyDataSchematic(schematic_dict, viz_img_key)
- current_schematic = DataSchematic(schematic_dict, viz_img_key)
+ legacy_schematic = _LegacyDataSchematic(schematic_dict)
+ current_schematic = DataSchematic(schematic_dict)
legacy_schematic.infer_norm_from_dataset_legacy(dataset)
current_schematic.infer_norm_from_dataset(
diff --git a/egomimic/rldb/zarr/utils.py b/egomimic/rldb/zarr/utils.py
index 43bc8ab6..c37e7e33 100644
--- a/egomimic/rldb/zarr/utils.py
+++ b/egomimic/rldb/zarr/utils.py
@@ -28,7 +28,7 @@ def set_global_seed(seed: int = 42):
class DataSchematic(object):
- def __init__(self, schematic_dict, viz_img_key, norm_mode="zscore"):
+ def __init__(self, schematic_dict, norm_mode="zscore"):
"""
Initialize with a schematic dictionary and create a DataFrame.
@@ -76,7 +76,6 @@ def __init__(self, schematic_dict, viz_img_key, norm_mode="zscore"):
)
self.df = pd.DataFrame(rows)
- self._viz_img_key = {get_embodiment_id(k): v for k, v in viz_img_key.items()}
self.shapes_infered = False
self.norm_mode = norm_mode
self.norm_stats = {emb: {} for emb in self.embodiments}
@@ -343,12 +342,6 @@ def infer_norm_from_dataset(
with open(benchmark_file, "w") as f:
json.dump(benchmark_stats, f, indent=4)
- def viz_img_key(self):
- """
- Get the key that should be used for offline visualization
- """
- return self._viz_img_key
-
def all_keys(self):
"""
Get all key names.
diff --git a/egomimic/scripts/tutorials/zarr_data_viz.ipynb b/egomimic/scripts/tutorials/zarr_data_viz.ipynb
index 7e2aebec..368c9603 100644
--- a/egomimic/scripts/tutorials/zarr_data_viz.ipynb
+++ b/egomimic/scripts/tutorials/zarr_data_viz.ipynb
@@ -1,5 +1,16 @@
{
"cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "79d184b3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2"
+ ]
+ },
{
"cell_type": "markdown",
"id": "29aeeb40",
@@ -12,10 +23,19 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
"id": "32d9110f",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/coc/flash7/paphiwetsa3/projects/EgoVerse/.venv/lib/python3.11/site-packages/torch/cuda/__init__.py:61: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.\n",
+ " import pynvml # type: ignore[import]\n"
+ ]
+ }
+ ],
"source": [
"from pathlib import Path\n",
"\n",
@@ -42,25 +62,41 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 3,
"id": "cc9edba1",
"metadata": {},
"outputs": [],
"source": [
- "TEMP_DIR = \"/coc/flash7/scratch/egoverseDebugDatasets/egoverseS3DatasetTest\" # replace with your own temp directory for caching S3 data\n",
+ "TEMP_DIR = \"/coc/flash7/scratch/egoverseDebugDatasets/egoverseS3DatasetTest\"\n",
"load_env()"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 4,
"id": "a4aa1a05",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Tables in schema 'app': ['episodes']\n"
+ ]
+ }
+ ],
"source": [
+ "# Point this at a single episode directory, e.g. /path/to/episode_hash.zarr\n",
+ "# EPISODE_PATH = Path(\"/coc/flash7/scratch/egoverseDebugDatasets/1767495035712.zarr\")\n",
+ "\n",
"key_map = Eva.get_keymap()\n",
"transform_list = Eva.get_transform_list()\n",
"\n",
+ "# Build a MultiDataset with exactly one ZarrDataset inside\n",
+ "# single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map, transform_list=transform_list)\n",
+ "# single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map)\n",
+ "\n",
+ "# multi_ds = MultiDataset(datasets={\"single_episode\": single_ds}, mode=\"total\")\n",
"resolver = S3EpisodeResolver(\n",
" TEMP_DIR, key_map=key_map, transform_list=transform_list\n",
")\n",
@@ -76,28 +112,90 @@
},
{
"cell_type": "code",
- "execution_count": null,
- "id": "4b72f3bb",
+ "execution_count": 7,
+ "id": "58f0af00",
"metadata": {},
"outputs": [],
+ "source": [
+ "batch = next(iter(loader))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "67a60218",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([-0.2326, 0.1783, 0.3866, -0.0258, 0.0405, 0.8205, 0.0800, 0.3351,\n",
+ " 0.2074, 0.4526, -0.0582, -0.0042, 0.8754, 0.0000],\n",
+ " dtype=torch.float64)"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "batch['actions_cartesian'][0,0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "4b72f3bb",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
 |
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
"source": [
"# Separate YPR visualization preview\n",
"for batch in loader:\n",
- " vis_ypr = Eva.viz_transformed_batch(batch, mode=\"palm_axes\")\n",
+ " vis_ypr = Eva.viz_transformed_batch(batch, mode=\"axes\")\n",
" mpy.show_image(vis_ypr)\n",
" break"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 6,
"id": "0d8c3da2",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ " |
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
"source": [
"images = []\n",
"for i, batch in enumerate(loader):\n",
- " vis = Eva.viz_transformed_batch(batch, mode=\"palm_traj\")\n",
+ " vis = Eva.viz_transformed_batch(batch, mode=\"traj\")\n",
" images.append(vis)\n",
" if i > 10:\n",
" break\n",
@@ -116,18 +214,28 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 17,
"id": "b7384468",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Tables in schema 'app': ['episodes']\n"
+ ]
+ }
+ ],
"source": [
+ "temp_dir = \"/coc/flash7/scratch/egoverseDebugDatasets/egoverseS3DatasetTest\"\n",
+ "\n",
"intrinsics_key = \"base\"\n",
"\n",
- "key_map = Aria.get_keymap()\n",
- "transform_list = Aria.get_transform_list()\n",
+ "key_map = Aria.get_keymap(mode=\"keypoints\")\n",
+ "transform_list = Aria.get_transform_list(mode=\"keypoints\")\n",
"\n",
"resolver = S3EpisodeResolver(\n",
- " TEMP_DIR,\n",
+ " temp_dir,\n",
" key_map=key_map,\n",
" transform_list=transform_list,\n",
")\n",
@@ -144,20 +252,59 @@
},
{
"cell_type": "code",
- "execution_count": null,
- "id": "af65095a",
+ "execution_count": 18,
+ "id": "607f219c",
"metadata": {},
"outputs": [],
+ "source": [
+ "batch = next(iter(loader))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "1c57c9f0",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "dict_keys(['observations.images.front_img_1', 'actions_keypoints', 'observations.state.keypoints', 'metadata.robot_name', 'embodiment'])"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "batch.keys()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "af65095a",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "KeyError",
+ "evalue": "'actions_cartesian'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
+ "\u001b[31mKeyError\u001b[39m Traceback (most recent call last)",
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[13]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m ims = []\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m i, batch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(loader):\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m vis = \u001b[43mAria\u001b[49m\u001b[43m.\u001b[49m\u001b[43mviz_transformed_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m=\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mtraj\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 4\u001b[39m ims.append(vis)\n\u001b[32m 5\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m i > \u001b[32m10\u001b[39m:\n",
+ "\u001b[36mFile \u001b[39m\u001b[32m/coc/flash7/paphiwetsa3/projects/EgoVerse/egomimic/rldb/embodiment/human.py:9\u001b[39m, in \u001b[36mviz_transformed_batch\u001b[39m\u001b[34m(cls, batch, mode, action_key, image_key)\u001b[39m\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtyping\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Literal\n\u001b[32m 5\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01megomimic\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mrldb\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01membodiment\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01membodiment\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m Embodiment\n\u001b[32m 6\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01megomimic\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mrldb\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mzarr\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01maction_chunk_transforms\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[32m 7\u001b[39m ActionChunkCoordinateFrameTransform,\n\u001b[32m 8\u001b[39m ConcatKeys,\n\u001b[32m----> \u001b[39m\u001b[32m9\u001b[39m DeleteKeys,\n\u001b[32m 10\u001b[39m InterpolatePose,\n\u001b[32m 11\u001b[39m PoseCoordinateFrameTransform,\n\u001b[32m 12\u001b[39m Reshape,\n\u001b[32m 13\u001b[39m Transform,\n\u001b[32m 14\u001b[39m XYZWXYZ_to_XYZYPR,\n\u001b[32m 15\u001b[39m )\n\u001b[32m 16\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01megomimic\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mutils\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mviz_utils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[32m 17\u001b[39m _viz_axes,\n\u001b[32m 18\u001b[39m _viz_keypoints,\n\u001b[32m 19\u001b[39m _viz_traj,\n\u001b[32m 20\u001b[39m )\n\u001b[32m 21\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01megomimic\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mutils\u001b[39;00m\u001b[34;01m.\u001b[39;00m\u001b[34;01mtype_utils\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m _to_numpy\n",
+ "\u001b[31mKeyError\u001b[39m: 'actions_cartesian'"
+ ]
+ }
+ ],
"source": [
"ims = []\n",
"for i, batch in enumerate(loader):\n",
- " vis = Aria.viz_transformed_batch(batch, mode=\"palm_traj\")\n",
+ " vis = Aria.viz_transformed_batch(batch, mode=\"traj\")\n",
" ims.append(vis)\n",
- " # mpy.show_image(vis)\n",
- "\n",
- " # for k, v in batch.items():\n",
- " # print(f\"{k}: {tuple(v.shape)}\")\n",
- " \n",
" if i > 10:\n",
" break\n",
"\n",
@@ -174,7 +321,7 @@
"# Aria YPR video (same data loop, YPR overlay)\n",
"ims_ypr = []\n",
"for i, batch in enumerate(loader):\n",
- " vis_ypr = Aria.viz_transformed_batch(batch, mode=\"palm_axes\")\n",
+ " vis_ypr = Aria.viz_transformed_batch(batch, mode=\"axes\")\n",
" ims_ypr.append(vis_ypr)\n",
" if i > 20:\n",
" break\n",
@@ -182,6 +329,211 @@
"mpy.show_video(ims_ypr, fps=30)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "60723adf",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ " |
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "ims_keypoints = []\n",
+ "for i, batch in enumerate(loader):\n",
+ " vis_keypoints = Aria.viz_transformed_batch(batch, mode=\"keypoints\", action_key=\"actions_keypoints\")\n",
+ " ims_keypoints.append(vis_keypoints)\n",
+ " if i > 360:\n",
+ " break\n",
+ "\n",
+ "mpy.show_video(ims_keypoints, fps=20)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "efecaba7",
+ "metadata": {},
+ "source": [
+ "## Keypoint Visualization"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e39bca03",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Load Scale episode with raw keypoints (no action chunking needed)\n",
+ "\n",
+ "from egomimic.rldb.zarr.action_chunk_transforms import _xyzwxyz_to_matrix\n",
+ "\n",
+ "key_map_kp = {\n",
+ " \"images.front_1\": {\"zarr_key\": \"images.front_1\"},\n",
+ " \"left.obs_keypoints\": {\"zarr_key\": \"left.obs_keypoints\"},\n",
+ " \"right.obs_keypoints\": {\"zarr_key\": \"right.obs_keypoints\"},\n",
+ " \"obs_head_pose\": {\"zarr_key\": \"obs_head_pose\"},\n",
+ "}\n",
+ "\n",
+ "filters = {\"episode_hash\": \"2026-01-20-20-59-43-376000\"}\n",
+ "\n",
+ "resolver = S3EpisodeResolver(\n",
+ " temp_dir,\n",
+ " key_map=key_map\n",
+ ")\n",
+ "\n",
+ "cloudflare_ds = MultiDataset._from_resolver(\n",
+ " resolver, filters=filters, sync_from_s3=True, mode=\"total\"\n",
+ ")\n",
+ "\n",
+ "loader_kp = torch.utils.data.DataLoader(cloudflare_ds, batch_size=1, shuffle=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "848c6d74",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ARIA Keypoint Viz\n",
+ "# MANO skeleton edges: (parent, child) for drawing bones\n",
+ "MANO_EDGES = [\n",
+ " (0, 1), (1, 2), (2, 3), (3, 4), # thumb\n",
+ " (0, 5), (5, 6), (6, 7), (7, 8), # index\n",
+ " (0, 9), (9, 10), (10, 11), (11, 12), # middle\n",
+ " (0, 13), (13, 14), (14, 15), (15, 16), # ring\n",
+ " (0, 17), (17, 18), (18, 19), (19, 20), # pinky\n",
+ "]\n",
+ "\n",
+ "# aria configuration\n",
+ "MANO_EDGES = [\n",
+ " (5, 6,), (6, 7), (7, 0), # thumb\n",
+ " (5, 8), (8, 9), (9, 10), (9, 1), # index\n",
+ " (5, 11), (11, 12), (12, 13), (13, 2), # middle\n",
+ " (5, 14), (14, 15), (15, 16), (16, 3), # ring\n",
+ " (5, 17), (17, 18), (18, 19), (19, 4), # pinky\n",
+ "]\n",
+ "\n",
+ "FINGER_COLORS = {\n",
+ " \"thumb\": (255, 100, 100), # red\n",
+ " \"index\": (100, 255, 100), # green\n",
+ " \"middle\": (100, 100, 255), # blue\n",
+ " \"ring\": (255, 255, 100), # yellow\n",
+ " \"pinky\": (255, 100, 255), # magenta\n",
+ "}\n",
+ "FINGER_EDGE_RANGES = [\n",
+ " (\"thumb\", 0, 3), (\"index\", 3, 6), (\"middle\", 6, 9),\n",
+ " (\"ring\", 9, 12), (\"pinky\", 12, 15),\n",
+ "]\n",
+ "\n",
+ "\n",
+ "def viz_keypoints(batch, image_key=\"observations.images.front_img_1\"):\n",
+ " \"\"\"Visualize all 21 MANO keypoints per hand, projected onto the image.\"\"\"\n",
+ " # Prepare image\n",
+ " img = batch[image_key][0].detach().cpu()\n",
+ " if img.shape[0] in (1, 3):\n",
+ " img = img.permute(1, 2, 0)\n",
+ " img_np = img.numpy()\n",
+ " if img_np.dtype != np.uint8:\n",
+ " if img_np.max() <= 1.0:\n",
+ " img_np = (img_np * 255.0).clip(0, 255).astype(np.uint8)\n",
+ " else:\n",
+ " img_np = img_np.clip(0, 255).astype(np.uint8)\n",
+ " if img_np.shape[-1] == 1:\n",
+ " img_np = np.repeat(img_np, 3, axis=-1)\n",
+ "\n",
+ " intrinsics = INTRINSICS[\"base\"]\n",
+ " head_pose = batch[\"obs_head_pose\"][0].detach().cpu().numpy() # (6,)\n",
+ "\n",
+ " # T_head_world: camera pose in world (camera-to-world)\n",
+ " # We need world-to-camera = inv(T_head_world)\n",
+ " T_head_world = _xyzwxyz_to_matrix(head_pose[None, :])[0] # (4, 4)\n",
+ " T_world_to_cam = np.linalg.inv(T_head_world)\n",
+ "\n",
+ " vis = img_np.copy()\n",
+ " h, w = vis.shape[:2]\n",
+ "\n",
+ " for hand, dot_color in [(\"left\", (0, 120, 255)), (\"right\", (255, 80, 0))]:\n",
+ " kps_key = f\"{hand}.obs_keypoints\"\n",
+ " if kps_key not in batch:\n",
+ " continue\n",
+ " kps_flat = batch[kps_key][0].detach().cpu().numpy() # (63,)\n",
+ " kps_world = kps_flat.reshape(21, 3)\n",
+ "\n",
+ " # Skip if keypoints are all zero (invalid, clamped from 1e9)\n",
+ " if np.allclose(kps_world, 0.0, atol=1e-3):\n",
+ " continue\n",
+ "\n",
+ " # World -> camera frame\n",
+ " kps_h = np.concatenate([kps_world, np.ones((21, 1))], axis=1) # (21, 4)\n",
+ " kps_cam = (T_world_to_cam @ kps_h.T).T[:, :3] # (21, 3)\n",
+ "\n",
+ " # Camera frame -> pixels\n",
+ " kps_px = cam_frame_to_cam_pixels(kps_cam, intrinsics) # (21, 3+)\n",
+ "\n",
+ " # Identify valid keypoints (z > 0 and in image bounds)\n",
+ " valid = (kps_cam[:, 2] > 0.01)\n",
+ " valid &= (kps_px[:, 0] >= 0) & (kps_px[:, 0] < w)\n",
+ " valid &= (kps_px[:, 1] >= 0) & (kps_px[:, 1] < h)\n",
+ "\n",
+ " # Draw skeleton edges (colored by finger)\n",
+ " for finger, start, end in FINGER_EDGE_RANGES:\n",
+ " color = FINGER_COLORS[finger]\n",
+ " for edge_idx in range(start, end):\n",
+ " i, j = MANO_EDGES[edge_idx]\n",
+ " if valid[i] and valid[j]:\n",
+ " p1 = (int(kps_px[i, 0]), int(kps_px[i, 1]))\n",
+ " p2 = (int(kps_px[j, 0]), int(kps_px[j, 1]))\n",
+ " cv2.line(vis, p1, p2, color, 2)\n",
+ "\n",
+ " # Draw keypoint dots on top\n",
+ " for k in range(21):\n",
+ " if valid[k]:\n",
+ " center = (int(kps_px[k, 0]), int(kps_px[k, 1]))\n",
+ " cv2.circle(vis, center, 4, dot_color, -1)\n",
+ " cv2.circle(vis, center, 4, (255, 255, 255), 1) # white border\n",
+ "\n",
+ " # Label wrist\n",
+ " if valid[0]:\n",
+ " wrist_px = (int(kps_px[0, 0]) + 6, int(kps_px[0, 1]) - 6)\n",
+ " cv2.putText(vis, f\"{hand[0].upper()}\", wrist_px,\n",
+ " cv2.FONT_HERSHEY_SIMPLEX, 0.5, dot_color, 2)\n",
+ "\n",
+ " return vis"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "75dbfa95",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Render keypoint video\n",
+ "ims_kp = []\n",
+ "for i, batch_kp in enumerate(loader_kp):\n",
+ " vis = viz_keypoints(batch_kp)\n",
+ " ims_kp.append(vis)\n",
+ " if i > 10:\n",
+ " break\n",
+ "\n",
+ "mpy.show_video(ims_kp, fps=30)"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -193,7 +545,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "emimic (3.11.14)",
+ "display_name": ".venv",
"language": "python",
"name": "python3"
},
diff --git a/egomimic/scripts/zarr_data_viz.ipynb b/egomimic/scripts/zarr_data_viz.ipynb
deleted file mode 100644
index 1a1e4302..00000000
--- a/egomimic/scripts/zarr_data_viz.ipynb
+++ /dev/null
@@ -1,420 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "79d184b3",
- "metadata": {},
- "outputs": [],
- "source": [
- "%load_ext autoreload\n",
- "%autoreload 2"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "29aeeb40",
- "metadata": {},
- "source": [
- "# Eva Data\n",
- "\n",
- "This notebook builds a `MultiDataset` containing exactly one `ZarrDataset`, loads one batch, visualizes one image with `mediapy`, and prints the rest of the batch."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "32d9110f",
- "metadata": {},
- "outputs": [],
- "source": [
- "from pathlib import Path\n",
- "\n",
- "import cv2\n",
- "import imageio_ffmpeg\n",
- "import mediapy as mpy\n",
- "import numpy as np\n",
- "import torch\n",
- "\n",
- "from egomimic.rldb.embodiment.eva import Eva\n",
- "from egomimic.rldb.embodiment.human import Aria\n",
- "from egomimic.rldb.zarr.zarr_dataset_multi import MultiDataset, ZarrDataset\n",
- "from egomimic.rldb.zarr.zarr_dataset_multi import S3EpisodeResolver\n",
- "from egomimic.utils.egomimicUtils import (\n",
- " INTRINSICS,\n",
- " cam_frame_to_cam_pixels,\n",
- " nds,\n",
- ")\n",
- "from egomimic.utils.aws.aws_data_utils import load_env\n",
- "\n",
- "# Ensure mediapy can find an ffmpeg executable in this environment\n",
- "mpy.set_ffmpeg(imageio_ffmpeg.get_ffmpeg_exe())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "cc9edba1",
- "metadata": {},
- "outputs": [],
- "source": [
- "TEMP_DIR = \"/coc/flash7/scratch/egoverseDebugDatasets/egoverseS3DatasetTest\"\n",
- "load_env()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a4aa1a05",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Point this at a single episode directory, e.g. /path/to/episode_hash.zarr\n",
- "# EPISODE_PATH = Path(\"/coc/flash7/scratch/egoverseDebugDatasets/1767495035712.zarr\")\n",
- "\n",
- "key_map = Eva.get_keymap()\n",
- "transform_list = Eva.get_transform_list()\n",
- "\n",
- "# Build a MultiDataset with exactly one ZarrDataset inside\n",
- "# single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map, transform_list=transform_list)\n",
- "# single_ds = ZarrDataset(Episode_path=EPISODE_PATH, key_map=key_map)\n",
- "\n",
- "# multi_ds = MultiDataset(datasets={\"single_episode\": single_ds}, mode=\"total\")\n",
- "resolver = S3EpisodeResolver(\n",
- " TEMP_DIR, key_map=key_map, transform_list=transform_list\n",
- ")\n",
- "filters = {\n",
- " \"episode_hash\": \"2025-12-26-18-07-46-296000\"\n",
- "}\n",
- "multi_ds = MultiDataset._from_resolver(\n",
- " resolver, filters=filters, sync_from_s3=True, mode=\"total\"\n",
- ")\n",
- "\n",
- "loader = torch.utils.data.DataLoader(multi_ds, batch_size=1, shuffle=False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4b72f3bb",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Separate YPR visualization preview\n",
- "for batch in loader:\n",
- " vis_ypr = Eva.viz_transformed_batch(batch, mode=\"axes\")\n",
- " mpy.show_image(vis_ypr)\n",
- " break"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0d8c3da2",
- "metadata": {},
- "outputs": [],
- "source": [
- "images = []\n",
- "for i, batch in enumerate(loader):\n",
- " vis = Eva.viz_transformed_batch(batch, mode=\"traj\")\n",
- " images.append(vis)\n",
- " if i > 10:\n",
- " break\n",
- "\n",
- "mpy.show_video(images, fps=30)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "1a3382f1",
- "metadata": {},
- "source": [
- "## Human Datasets\n",
- "Mecka, Scale and Aria should all run exactly the same"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b7384468",
- "metadata": {},
- "outputs": [],
- "source": [
- "temp_dir = \"/coc/flash7/scratch/egoverseDebugDatasets/egoverseS3DatasetTest\"\n",
- "\n",
- "intrinsics_key = \"base\"\n",
- "\n",
- "key_map = Aria.get_keymap(mode=\"keypoints\")\n",
- "transform_list = Aria.get_transform_list(mode=\"keypoints\")\n",
- "\n",
- "resolver = S3EpisodeResolver(\n",
- " temp_dir,\n",
- " key_map=key_map,\n",
- " transform_list=transform_list,\n",
- ")\n",
- "\n",
- "filters = {\"episode_hash\": \"2026-01-20-20-59-43-376000\"} #aria\n",
- "# filters = {\"episode_hash\": \"692ee048ef7557106e6c4b8d\"} # mecka\n",
- "\n",
- "cloudflare_ds = MultiDataset._from_resolver(\n",
- " resolver, filters=filters, sync_from_s3=True, mode=\"total\"\n",
- ")\n",
- "\n",
- "loader = torch.utils.data.DataLoader(cloudflare_ds, batch_size=1, shuffle=False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "af65095a",
- "metadata": {},
- "outputs": [],
- "source": [
- "ims = []\n",
- "for i, batch in enumerate(loader):\n",
- " vis = Aria.viz_transformed_batch(batch, mode=\"traj\")\n",
- " ims.append(vis)\n",
- " if i > 10:\n",
- " break\n",
- "\n",
- "mpy.show_video(ims, fps=30)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e6d8d872",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Aria YPR video (same data loop, YPR overlay)\n",
- "ims_ypr = []\n",
- "for i, batch in enumerate(loader):\n",
- " vis_ypr = Aria.viz_transformed_batch(batch, mode=\"axes\")\n",
- " ims_ypr.append(vis_ypr)\n",
- " if i > 20:\n",
- " break\n",
- "\n",
- "mpy.show_video(ims_ypr, fps=30)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "60723adf",
- "metadata": {},
- "outputs": [],
- "source": [
- "ims_keypoints = []\n",
- "for i, batch in enumerate(loader):\n",
- " vis_keypoints = Aria.viz_transformed_batch(batch, mode=\"keypoints\")\n",
- " ims_keypoints.append(vis_keypoints)\n",
- " if i > 360:\n",
- " break\n",
- "\n",
- "mpy.show_video(ims_keypoints, fps=20)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "efecaba7",
- "metadata": {},
- "source": [
- "## Keypoint Visualization"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e39bca03",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Load Scale episode with raw keypoints (no action chunking needed)\n",
- "\n",
- "from egomimic.rldb.zarr.action_chunk_transforms import _xyzwxyz_to_matrix\n",
- "\n",
- "key_map_kp = {\n",
- " \"images.front_1\": {\"zarr_key\": \"images.front_1\"},\n",
- " \"left.obs_keypoints\": {\"zarr_key\": \"left.obs_keypoints\"},\n",
- " \"right.obs_keypoints\": {\"zarr_key\": \"right.obs_keypoints\"},\n",
- " \"obs_head_pose\": {\"zarr_key\": \"obs_head_pose\"},\n",
- "}\n",
- "\n",
- "filters = {\"episode_hash\": \"2026-01-20-20-59-43-376000\"}\n",
- "\n",
- "resolver = S3EpisodeResolver(\n",
- " temp_dir,\n",
- " key_map=key_map\n",
- ")\n",
- "\n",
- "cloudflare_ds = MultiDataset._from_resolver(\n",
- " resolver, filters=filters, sync_from_s3=True, mode=\"total\"\n",
- ")\n",
- "\n",
- "loader_kp = torch.utils.data.DataLoader(cloudflare_ds, batch_size=1, shuffle=False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "848c6d74",
- "metadata": {},
- "outputs": [],
- "source": [
- "# ARIA Keypoint Viz\n",
- "# MANO skeleton edges: (parent, child) for drawing bones\n",
- "MANO_EDGES = [\n",
- " (0, 1), (1, 2), (2, 3), (3, 4), # thumb\n",
- " (0, 5), (5, 6), (6, 7), (7, 8), # index\n",
- " (0, 9), (9, 10), (10, 11), (11, 12), # middle\n",
- " (0, 13), (13, 14), (14, 15), (15, 16), # ring\n",
- " (0, 17), (17, 18), (18, 19), (19, 20), # pinky\n",
- "]\n",
- "\n",
- "# aria configuration\n",
- "MANO_EDGES = [\n",
- " (5, 6,), (6, 7), (7, 0), # thumb\n",
- " (5, 8), (8, 9), (9, 10), (9, 1), # index\n",
- " (5, 11), (11, 12), (12, 13), (13, 2), # middle\n",
- " (5, 14), (14, 15), (15, 16), (16, 3), # ring\n",
- " (5, 17), (17, 18), (18, 19), (19, 4), # pinky\n",
- "]\n",
- "\n",
- "FINGER_COLORS = {\n",
- " \"thumb\": (255, 100, 100), # red\n",
- " \"index\": (100, 255, 100), # green\n",
- " \"middle\": (100, 100, 255), # blue\n",
- " \"ring\": (255, 255, 100), # yellow\n",
- " \"pinky\": (255, 100, 255), # magenta\n",
- "}\n",
- "FINGER_EDGE_RANGES = [\n",
- " (\"thumb\", 0, 3), (\"index\", 3, 6), (\"middle\", 6, 9),\n",
- " (\"ring\", 9, 12), (\"pinky\", 12, 15),\n",
- "]\n",
- "\n",
- "\n",
- "def viz_keypoints(batch, image_key=\"observations.images.front_img_1\"):\n",
- " \"\"\"Visualize all 21 MANO keypoints per hand, projected onto the image.\"\"\"\n",
- " # Prepare image\n",
- " img = batch[image_key][0].detach().cpu()\n",
- " if img.shape[0] in (1, 3):\n",
- " img = img.permute(1, 2, 0)\n",
- " img_np = img.numpy()\n",
- " if img_np.dtype != np.uint8:\n",
- " if img_np.max() <= 1.0:\n",
- " img_np = (img_np * 255.0).clip(0, 255).astype(np.uint8)\n",
- " else:\n",
- " img_np = img_np.clip(0, 255).astype(np.uint8)\n",
- " if img_np.shape[-1] == 1:\n",
- " img_np = np.repeat(img_np, 3, axis=-1)\n",
- "\n",
- " intrinsics = INTRINSICS[\"base\"]\n",
- " head_pose = batch[\"obs_head_pose\"][0].detach().cpu().numpy() # (6,)\n",
- "\n",
- " # T_head_world: camera pose in world (camera-to-world)\n",
- " # We need world-to-camera = inv(T_head_world)\n",
- " T_head_world = _xyzwxyz_to_matrix(head_pose[None, :])[0] # (4, 4)\n",
- " T_world_to_cam = np.linalg.inv(T_head_world)\n",
- "\n",
- " vis = img_np.copy()\n",
- " h, w = vis.shape[:2]\n",
- "\n",
- " for hand, dot_color in [(\"left\", (0, 120, 255)), (\"right\", (255, 80, 0))]:\n",
- " kps_key = f\"{hand}.obs_keypoints\"\n",
- " if kps_key not in batch:\n",
- " continue\n",
- " kps_flat = batch[kps_key][0].detach().cpu().numpy() # (63,)\n",
- " kps_world = kps_flat.reshape(21, 3)\n",
- "\n",
- " # Skip if keypoints are all zero (invalid, clamped from 1e9)\n",
- " if np.allclose(kps_world, 0.0, atol=1e-3):\n",
- " continue\n",
- "\n",
- " # World -> camera frame\n",
- " kps_h = np.concatenate([kps_world, np.ones((21, 1))], axis=1) # (21, 4)\n",
- " kps_cam = (T_world_to_cam @ kps_h.T).T[:, :3] # (21, 3)\n",
- "\n",
- " # Camera frame -> pixels\n",
- " kps_px = cam_frame_to_cam_pixels(kps_cam, intrinsics) # (21, 3+)\n",
- "\n",
- " # Identify valid keypoints (z > 0 and in image bounds)\n",
- " valid = (kps_cam[:, 2] > 0.01)\n",
- " valid &= (kps_px[:, 0] >= 0) & (kps_px[:, 0] < w)\n",
- " valid &= (kps_px[:, 1] >= 0) & (kps_px[:, 1] < h)\n",
- "\n",
- " # Draw skeleton edges (colored by finger)\n",
- " for finger, start, end in FINGER_EDGE_RANGES:\n",
- " color = FINGER_COLORS[finger]\n",
- " for edge_idx in range(start, end):\n",
- " i, j = MANO_EDGES[edge_idx]\n",
- " if valid[i] and valid[j]:\n",
- " p1 = (int(kps_px[i, 0]), int(kps_px[i, 1]))\n",
- " p2 = (int(kps_px[j, 0]), int(kps_px[j, 1]))\n",
- " cv2.line(vis, p1, p2, color, 2)\n",
- "\n",
- " # Draw keypoint dots on top\n",
- " for k in range(21):\n",
- " if valid[k]:\n",
- " center = (int(kps_px[k, 0]), int(kps_px[k, 1]))\n",
- " cv2.circle(vis, center, 4, dot_color, -1)\n",
- " cv2.circle(vis, center, 4, (255, 255, 255), 1) # white border\n",
- "\n",
- " # Label wrist\n",
- " if valid[0]:\n",
- " wrist_px = (int(kps_px[0, 0]) + 6, int(kps_px[0, 1]) - 6)\n",
- " cv2.putText(vis, f\"{hand[0].upper()}\", wrist_px,\n",
- " cv2.FONT_HERSHEY_SIMPLEX, 0.5, dot_color, 2)\n",
- "\n",
- " return vis"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "75dbfa95",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Render keypoint video\n",
- "ims_kp = []\n",
- "for i, batch_kp in enumerate(loader_kp):\n",
- " vis = viz_keypoints(batch_kp)\n",
- " ims_kp.append(vis)\n",
- " if i > 10:\n",
- " break\n",
- "\n",
- "mpy.show_video(ims_kp, fps=30)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "8f4fbaec",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": ".venv",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.11.14"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/egomimic/trainHydra.py b/egomimic/trainHydra.py
index 8e4540e1..7f2644c5 100644
--- a/egomimic/trainHydra.py
+++ b/egomimic/trainHydra.py
@@ -120,10 +120,16 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
),
)
+ viz_func = cfg.visualization
+ viz_func_dict = {}
+ for embodiment_name, embodiment_viz_func in viz_func.items():
+ viz_func_dict[embodiment_name] = hydra.utils.instantiate(embodiment_viz_func)
+
# NOTE: We also pass the data_schematic_dict into the robomimic model's instatiation now that we've initialzied the shapes and norm stats. In theory, upon loading the PL checkpoint, it will remember this, but let's see.
log.info(f"Instantiating model <{cfg.model._target_}>")
model: LightningModule = hydra.utils.instantiate(
- cfg.model, robomimic_model={"data_schematic": data_schematic}
+ cfg.model,
+ robomimic_model={"data_schematic": data_schematic, "viz_func": viz_func_dict},
)
_log_dataset_frame_counts(train_datasets, valid_datasets)
diff --git a/egomimic/utils/viz_utils.py b/egomimic/utils/viz_utils.py
index b6ecc408..bff862bd 100644
--- a/egomimic/utils/viz_utils.py
+++ b/egomimic/utils/viz_utils.py
@@ -10,6 +10,19 @@
from egomimic.utils.pose_utils import _split_action_pose, _split_keypoints
+class ColorPalette:
+ Blues = "Blues"
+ Greens = "Greens"
+ Reds = "Reds"
+ Oranges = "Oranges"
+ Purples = "Purples"
+ Greys = "Greys"
+
+ @classmethod
+ def is_valid(cls, name: str) -> bool:
+ return name in vars(cls).values()
+
+
def _prepare_viz_image(img):
if img.ndim == 3 and img.shape[0] in (1, 3):
img = np.transpose(img, (1, 2, 0))
@@ -28,7 +41,11 @@ def _prepare_viz_image(img):
return img
-def _viz_traj(images, actions, intrinsics_key):
+def _viz_traj(images, actions, intrinsics_key, **kwargs):
+ color = kwargs.get("color", "Blues")
+ if not ColorPalette.is_valid(color):
+ raise ValueError(f"Invalid color palette: {color}")
+
images = _prepare_viz_image(images)
intrinsics = INTRINSICS[intrinsics_key]
left_xyz, _, right_xyz, _ = _split_action_pose(actions)
@@ -54,7 +71,7 @@ def _viz_traj(images, actions, intrinsics_key):
return vis
-def _viz_axes(images, actions, intrinsics_key, axis_len_m=0.04):
+def _viz_axes(images, actions, intrinsics_key, axis_len_m=0.04, **kwargs):
images = _prepare_viz_image(images)
intrinsics = INTRINSICS[intrinsics_key]
left_xyz, left_ypr, right_xyz, right_ypr = _split_action_pose(actions)
@@ -88,7 +105,9 @@ def _draw_axis_color_legend(frame):
)
return frame
- def _draw_rotation_at_anchor(frame, xyz_seq, ypr_seq, label, anchor_color):
+ def _draw_rotation_at_anchor(
+ frame, xyz_seq, ypr_seq, label, anchor_color, **kwargs
+ ):
if len(xyz_seq) == 0 or len(ypr_seq) == 0:
return frame
@@ -141,7 +160,9 @@ def _draw_rotation_at_anchor(frame, xyz_seq, ypr_seq, label, anchor_color):
return vis
-def _viz_keypoints(images, actions, intrinsics_key, edges, colors, edge_ranges):
+def _viz_keypoints(
+ images, actions, intrinsics_key, edges, colors, edge_ranges, **kwargs
+):
"""Visualize all 21 MANO keypoints per hand, projected onto the image."""
# Prepare image
images = _prepare_viz_image(images)
@@ -198,3 +219,7 @@ def _viz_keypoints(images, actions, intrinsics_key, edges, colors, edge_ranges):
)
return vis
+
+
+def save_image(image: np.ndarray, path: str) -> None:
+ cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))