Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 5 additions & 62 deletions egomimic/algo/hpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
STD_SCALE,
EinOpsRearrange,
download_from_huggingface,
draw_actions,
draw_rotation_text,
frechet_gaussian_over_time,
get_sinusoid_encoding_table,
reverse_kl_from_samples,
Expand Down Expand Up @@ -800,6 +798,7 @@ def __init__(
encoder_specs: dict = None,
domains: list = None,
auxiliary_ac_keys: dict = {},
viz_func: dict = None,
# ---------------------------
# Pretrained
# ---------------------------
Expand All @@ -812,6 +811,7 @@ def __init__(
):
self.nets = nn.ModuleDict()
self.data_schematic = data_schematic
self.viz_func = viz_func

self.camera_transforms = camera_transforms
self.train_image_augs = train_image_augs
Expand Down Expand Up @@ -1250,69 +1250,12 @@ def visualize_preds(self, predictions, batch):
Returns:
ims (np.ndarray): (B, H, W, 3) - images with actions drawn on top
"""

if self.viz_func is None:
raise ValueError("viz_func is not set")
embodiment_id = batch["embodiment"][0].item()
embodiment_name = get_embodiment(embodiment_id).lower()
ac_key = self.ac_keys[embodiment_id]

viz_img_key = self.data_schematic.viz_img_key()[embodiment_id]
ims = (batch[viz_img_key].cpu().numpy().transpose((0, 2, 3, 1)) * 255).astype(
np.uint8
)
for key in batch:
if f"{embodiment_name}_{key}" in predictions:
preds = predictions[f"{embodiment_name}_{key}"]
gt = batch[key]

if self.is_6dof and ac_key == "actions_cartesian":
gt, gt_rot = self._extract_xyz(gt)
preds, preds_rot = self._extract_xyz(preds)

for b in range(ims.shape[0]):
if preds.shape[-1] == 7 or preds.shape[-1] == 14:
ac_type = "joints"
elif preds.shape[-1] == 3 or preds.shape[-1] == 6:
ac_type = "xyz"
else:
raise ValueError(
f"Unknown action type with shape {preds.shape}"
)

# Determine arm from embodiment name, not action shape
if "bimanual" in embodiment_name:
arm = "both"
elif "left" in embodiment_name:
arm = "left"
elif "right" in embodiment_name:
arm = "right"
else:
raise ValueError(f"Unknown embodiment name: {embodiment_name}")
ims[b] = draw_actions(
ims[b],
ac_type,
"Purples",
preds[b].cpu().numpy(),
self.camera_transforms[embodiment_name].extrinsics,
self.camera_transforms[embodiment_name].intrinsics,
arm=arm,
kinematics_solver=self.kinematics_solver,
)
ims[b] = draw_actions(
ims[b],
ac_type,
"Greens",
gt[b].cpu().numpy(),
self.camera_transforms[embodiment_name].extrinsics,
self.camera_transforms[embodiment_name].intrinsics,
arm=arm,
kinematics_solver=self.kinematics_solver,
)

if self.is_6dof and ac_key == "actions_cartesian":
ims[b] = draw_rotation_text(
ims[b], gt_rot[b][0], preds_rot[b][0], position=(340, 20)
)
return ims
return self.viz_func[embodiment_name](predictions, batch)

@override
def compute_losses(self, predictions, batch):
Expand Down
6 changes: 5 additions & 1 deletion egomimic/hydra_configs/data/eva_human_cotrain.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ train_datasets:
folder_path: /coc/flash7/scratch/egoverseDebugDatasets/aria
key_map:
_target_: egomimic.rldb.embodiment.human.Aria.get_keymap
mode: cartesian
transform_list:
_target_: egomimic.rldb.embodiment.human.Aria.get_transform_list
mode: cartesian
filters:
episode_hash: "2025-09-20-17-47-54-000000"
mode: total
Expand All @@ -44,8 +46,10 @@ valid_datasets:
folder_path: /coc/flash7/scratch/egoverseDebugDatasets/aria
key_map:
_target_: egomimic.rldb.embodiment.human.Aria.get_keymap
mode: cartesian
transform_list:
_target_: egomimic.rldb.embodiment.human.Aria.get_transform_list
mode: cartesian
filters:
episode_hash: "2025-09-20-17-47-54-000000"
mode: total
Expand All @@ -62,4 +66,4 @@ valid_dataloader_params:
num_workers: 10
aria_bimanual:
batch_size: 32
num_workers: 10
num_workers: 10
9 changes: 1 addition & 8 deletions egomimic/hydra_configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ train: true
eval: false

eval_class:
_target_ : egomimic.scripts.evaluation.Eve
_target_: egomimic.scripts.evaluation.Eve
mode: real
arm: both
eval_path: "./logs/eval/${name}_${now:%Y-%m-%d_%H-%M-%S}"
Expand Down Expand Up @@ -93,10 +93,3 @@ data_schematic: # Dynamically fill in these shapes from the dataset
embodiment:
key_type: metadata_keys
lerobot_key: metadata.embodiment
viz_img_key:
eva_bimanual:
front_img_1
aria_bimanual:
front_img_1
mecka_bimanual:
front_img_1
15 changes: 3 additions & 12 deletions egomimic/hydra_configs/train_zarr.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
defaults:
- model: hpt_bc_flow_eva
- model: hpt_cotrain_flow_shared_head
- visualization: eva_cartesian_aria_cartesian
- paths: default
- trainer: ddp
- debug: null
- logger: wandb
- data: eva
- data: eva_human_cotrain
- callbacks: checkpoints
- override hydra/launcher: submitit
- _self_
Expand Down Expand Up @@ -32,7 +33,6 @@ launch_params:
gpus_per_node: 1
nodes: 1


data_schematic: # Dynamically fill in these shapes from the dataset
_target_: egomimic.rldb.zarr.utils.DataSchematic
norm_mode: quantile
Expand Down Expand Up @@ -101,14 +101,5 @@ data_schematic: # Dynamically fill in these shapes from the dataset
embodiment:
key_type: metadata_keys
zarr_key: metadata.embodiment
viz_img_key:
eva_bimanual:
front_img_1
aria_bimanual:
front_img_1
mecka_bimanual:
front_img_1
scale_bimanual:
front_img_1

seed: 42
4 changes: 2 additions & 2 deletions egomimic/hydra_configs/trainer/debug.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ defaults:

strategy: ddp_find_unused_parameters_true
limit_train_batches: 5
limit_val_batches: 3
limit_val_batches: 20
check_val_every_n_epoch: 2
profiler: simple
max_epochs: 4
min_epochs: 4
min_epochs: 4
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Per-embodiment visualization callables, instantiated by Hydra.
# Each entry is a `_partial_` target: `image_key`/`action_key` are bound at
# config time, and the model later calls the partial with (predictions, batch)
# (see HPT.visualize_preds, which dispatches on the embodiment name).
eva_bimanual:
  _target_: egomimic.rldb.embodiment.eva.Eva.viz_cartesian_gt_preds
  _partial_: true
  # Image drawn on, and ground-truth action key looked up in the batch.
  image_key: front_img_1
  action_key: actions_cartesian
aria_bimanual:
  # NOTE(review): targets the `Human` base class while other configs target
  # `Aria` directly — presumably intentional since viz_cartesian_gt_preds is
  # defined on the shared base; confirm Aria inherits the right `viz`.
  _target_: egomimic.rldb.embodiment.human.Human.viz_cartesian_gt_preds
  _partial_: true
  image_key: front_img_1
  action_key: actions_cartesian
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Per-embodiment visualization configuration: which action key to draw and
# which `viz` classmethod (bound as a Hydra `_partial_`) renders it.
eva_bimanual:
  action_keys: actions_cartesian
  viz_function:
    _target_: egomimic.rldb.embodiment.eva.Eva.viz
    _partial_: true
    # Eva draws cartesian actions as a trajectory overlay.
    mode: traj
    intrinsics_key: base_half
aria_bimanual:
  action_keys: actions_cartesian
  viz_function:
    _target_: egomimic.rldb.embodiment.human.Aria.viz
    _partial_: true
    # Aria (human hand data) renders finger keypoints instead of a trajectory.
    mode: keypoints
    intrinsics_key: base_half
25 changes: 25 additions & 0 deletions egomimic/rldb/embodiment/embodiment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from abc import ABC
from enum import Enum

import numpy as np

from egomimic.rldb.zarr.action_chunk_transforms import Transform
from egomimic.utils.type_utils import _to_numpy


class EMBODIMENT(Enum):
Expand Down Expand Up @@ -53,3 +56,25 @@ def viz_transformed_batch(batch):
def get_keymap():
"""Returns a dictionary mapping from the raw keys in the dataset to the canonical keys used by the model."""
raise NotImplementedError

@classmethod
def viz_cartesian_gt_preds(cls, predictions, batch, image_key, action_key):
embodiment_id = batch["embodiment"][0].item()
embodiment_name = get_embodiment(embodiment_id).lower()

images = batch[image_key]
actions = batch[action_key]
pred_actions = predictions[f"{embodiment_name}_{action_key}"]
ims_list = []
images = _to_numpy(images)
actions = _to_numpy(actions)
pred_actions = _to_numpy(pred_actions)
for i in range(images.shape[0]):
image = images[i]
action = actions[i]
pred_action = pred_actions[i]
ims = cls.viz(image, action, mode="traj", color="Reds")
ims = cls.viz(ims, pred_action, mode="traj", color="Greens")
ims_list.append(ims)
ims = np.stack(ims_list, axis=0)
return ims
11 changes: 10 additions & 1 deletion egomimic/rldb/embodiment/eva.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,19 +57,28 @@ def viz_transformed_batch(cls, batch, mode=""):
)

@classmethod
def viz(cls, images, actions, mode=Literal["traj", "axes"], intrinsics_key=None):
def viz(
cls,
images,
actions,
mode=Literal["traj", "axes"],
intrinsics_key=None,
**kwargs,
):
intrinsics_key = intrinsics_key or cls.VIZ_INTRINSICS_KEY
if mode == "traj":
return _viz_traj(
images=images,
actions=actions,
intrinsics_key=intrinsics_key,
**kwargs,
)
if mode == "axes":
return _viz_axes(
images=images,
actions=actions,
intrinsics_key=intrinsics_key,
**kwargs,
)
raise ValueError(
f"Unsupported mode '{mode}'. Expected one of: " f"('traj', 'axes')."
Expand Down
4 changes: 4 additions & 0 deletions egomimic/rldb/embodiment/human.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,22 @@ def viz(
actions,
mode=Literal["traj", "axes", "keypoints"],
intrinsics_key=None,
**kwargs,
):
intrinsics_key = intrinsics_key or cls.VIZ_INTRINSICS_KEY
if mode == "traj":
return _viz_traj(
images=images,
actions=actions,
intrinsics_key=intrinsics_key,
**kwargs,
)
if mode == "axes":
return _viz_axes(
images=images,
actions=actions,
intrinsics_key=intrinsics_key,
**kwargs,
)
if mode == "keypoints":
return _viz_keypoints(
Expand All @@ -91,6 +94,7 @@ def viz(
edges=cls.FINGER_EDGES,
colors=cls.FINGER_COLORS,
edge_ranges=cls.FINGER_EDGE_RANGES,
**kwargs,
)
raise ValueError(
f"Unsupported mode '{mode}'. Expected one of: "
Expand Down
9 changes: 1 addition & 8 deletions egomimic/rldb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,7 +1043,7 @@ def sync_from_filters(


class DataSchematic(object):
def __init__(self, schematic_dict, viz_img_key, norm_mode="zscore"):
def __init__(self, schematic_dict, norm_mode="zscore"):
"""
Initialize with a schematic dictionary and create a DataFrame.

Expand Down Expand Up @@ -1091,7 +1091,6 @@ def __init__(self, schematic_dict, viz_img_key, norm_mode="zscore"):
)

self.df = pd.DataFrame(rows)
self._viz_img_key = {get_embodiment_id(k): v for k, v in viz_img_key.items()}
self.shapes_infered = False
self.norm_mode = norm_mode
self.norm_stats = {emb: {} for emb in self.embodiments}
Expand Down Expand Up @@ -1298,12 +1297,6 @@ def get_zarr_data(ds, col):

logger.info("[NormStats] Finished norm inference")

def viz_img_key(self):
"""
Get the key that should be used for offline visualization
"""
return self._viz_img_key

def all_keys(self):
"""
Get all key names.
Expand Down
5 changes: 2 additions & 3 deletions egomimic/rldb/zarr/test_norm_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,9 @@ def test_infer_norm_from_dataset_legacy_matches_current_on_dummy_dataset() -> No
},
}
}
viz_img_key = {"eva_bimanual": "observations.images.front_img_1"}

legacy_schematic = _LegacyDataSchematic(schematic_dict, viz_img_key)
current_schematic = DataSchematic(schematic_dict, viz_img_key)
legacy_schematic = _LegacyDataSchematic(schematic_dict)
current_schematic = DataSchematic(schematic_dict)

legacy_schematic.infer_norm_from_dataset_legacy(dataset)
current_schematic.infer_norm_from_dataset(
Expand Down
9 changes: 1 addition & 8 deletions egomimic/rldb/zarr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def set_global_seed(seed: int = 42):


class DataSchematic(object):
def __init__(self, schematic_dict, viz_img_key, norm_mode="zscore"):
def __init__(self, schematic_dict, norm_mode="zscore"):
"""
Initialize with a schematic dictionary and create a DataFrame.

Expand Down Expand Up @@ -76,7 +76,6 @@ def __init__(self, schematic_dict, viz_img_key, norm_mode="zscore"):
)

self.df = pd.DataFrame(rows)
self._viz_img_key = {get_embodiment_id(k): v for k, v in viz_img_key.items()}
self.shapes_infered = False
self.norm_mode = norm_mode
self.norm_stats = {emb: {} for emb in self.embodiments}
Expand Down Expand Up @@ -343,12 +342,6 @@ def infer_norm_from_dataset(
with open(benchmark_file, "w") as f:
json.dump(benchmark_stats, f, indent=4)

def viz_img_key(self):
"""
Get the key that should be used for offline visualization
"""
return self._viz_img_key

def all_keys(self):
"""
Get all key names.
Expand Down
Loading