Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions egomimic/algo/hpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,13 +970,13 @@ def process_batch_for_training(self, batch):
embodiment: torch.Size([])
"""
processed_batch = {}

for embodiment_name, _batch in batch.items():
embodiment_id = get_embodiment_id(embodiment_name)
processed_batch[embodiment_id] = {}
for key, value in _batch.items():
key_name = self.data_schematic.zarr_key_to_keyname(key, embodiment_id)
if key is not None:
processed_batch[embodiment_id][key] = value
processed_batch[embodiment_id][key_name] = value

ac_key = self.ac_keys[embodiment_id]
if len(processed_batch[embodiment_id][ac_key].shape) != 3:
Expand All @@ -987,11 +987,12 @@ def process_batch_for_training(self, batch):
processed_batch[embodiment_id]["pad_mask"] = torch.ones(
B, S, 1, device=device
)

processed_batch[embodiment_id] = self.data_schematic.normalize_data(
processed_batch[embodiment_id], embodiment_id
)
processed_batch[embodiment_id]["embodiment"] = torch.tensor(
[embodiment_id], device=self.device, dtype=torch.int64
[embodiment_id], device=self.device, dtype=torch.int64
)

return processed_batch
Expand All @@ -1010,7 +1011,10 @@ def forward_training(self, batch):
predictions = OrderedDict()
hpt_batches = {}
self.training_step += 1
for embodiment_id, _batch in batch.items(): # TODO why don't we use batch with embodiment_name to keep things consistent
for (
embodiment_id,
_batch,
) in batch.items():
embodiment_name = get_embodiment(embodiment_id).lower()
cam_keys = self.camera_keys[embodiment_id]
proprio_keys = self.proprio_keys[embodiment_id]
Expand Down Expand Up @@ -1251,7 +1255,7 @@ def visualize_preds(self, predictions, batch):
Returns:
ims (np.ndarray): (B, H, W, 3) - images with actions drawn on top
"""

embodiment_id = batch["embodiment"][0].item()
embodiment_name = get_embodiment(embodiment_id).lower()
ac_key = self.ac_keys[embodiment_id]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ train_datasets:
_target_: egomimic.rldb.zarr.zarr_dataset_multi.LocalEpisodeResolver
folder_path: /nethome/paphiwetsa3/flash/datasets/proc_zarr
key_map:
front_img_1: #batch key
observations.images.front_img_1: #batch key
key_type: camera_keys # key type
zarr_key: front_img_1
actions_cartesian:
Expand All @@ -24,7 +24,7 @@ valid_datasets:
_target_: egomimic.rldb.zarr.zarr_dataset_multi.LocalEpisodeResolver
folder_path: /nethome/paphiwetsa3/flash/datasets/proc_zarr
key_map:
front_img_1: #batch key
observations.images.front_img_1: #batch key
key_type: camera_keys # key type
zarr_key: front_img_1
actions_cartesian:
Expand Down
107 changes: 107 additions & 0 deletions egomimic/hydra_configs/data/eva_bc_zarr.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
_target_: egomimic.pl_utils.pl_data_utils.MultiDataModuleWrapper
train_datasets:
eva_bimanual:
_target_: egomimic.rldb.zarr.zarr_dataset_multi.MultiDataset
datasets:
single_episode:
_target_: egomimic.rldb.zarr.zarr_dataset_multi.ZarrDataset
Episode_path: /coc/flash7/scratch/egoverseDebugDatasets/eva/1767495035712.zarr
key_map:
observations.images.front_img_1:
key_type: camera_keys
zarr_key: images.front_1
observations.images.right_wrist_img:
key_type: camera_keys
zarr_key: images.right_wrist
observations.images.left_wrist_img:
key_type: camera_keys
zarr_key: images.left_wrist
right.obs_ee_pose:
key_type: proprio_keys
zarr_key: right.obs_ee_pose
right.obs_gripper:
key_type: proprio_keys
zarr_key: right.gripper
left.obs_ee_pose:
key_type: proprio_keys
zarr_key: left.obs_ee_pose
left.obs_gripper:
key_type: proprio_keys
zarr_key: left.gripper
right.gripper:
key_type: action_keys
zarr_key: right.gripper
horizon: 45
left.gripper:
key_type: action_keys
zarr_key: left.gripper
horizon: 45
right.cmd_ee_pose:
key_type: action_keys
zarr_key: right.cmd_ee_pose
horizon: 45
left.cmd_ee_pose:
key_type: action_keys
zarr_key: left.cmd_ee_pose
horizon: 45
transform_list:
_target_: egomimic.rldb.zarr.action_chunk_transforms.build_eva_bimanual_transform_list
mode: total

valid_datasets:
eva_bimanual:
_target_: egomimic.rldb.zarr.zarr_dataset_multi.MultiDataset
datasets:
single_episode:
_target_: egomimic.rldb.zarr.zarr_dataset_multi.ZarrDataset
Episode_path: /coc/flash7/scratch/egoverseDebugDatasets/eva/1767495035712.zarr
key_map:
observations.images.front_img_1 :
key_type: camera_keys
zarr_key: images.front_1
observations.images.right_wrist_img:
key_type: camera_keys
zarr_key: images.right_wrist
observations.images.left_wrist_img:
key_type: camera_keys
zarr_key: images.left_wrist
right.obs_ee_pose:
key_type: proprio_keys
zarr_key: right.obs_ee_pose
right.obs_gripper:
key_type: proprio_keys
zarr_key: right.gripper
left.obs_ee_pose:
key_type: proprio_keys
zarr_key: left.obs_ee_pose
left.obs_gripper:
key_type: proprio_keys
zarr_key: left.gripper
right.gripper:
key_type: action_keys
zarr_key: right.gripper
horizon: 45
left.gripper:
key_type: action_keys
zarr_key: left.gripper
horizon: 45
right.cmd_ee_pose:
key_type: action_keys
zarr_key: right.cmd_ee_pose
horizon: 45
left.cmd_ee_pose:
key_type: action_keys
zarr_key: left.cmd_ee_pose
horizon: 45
transform_list:
_target_: egomimic.rldb.zarr.action_chunk_transforms.build_eva_bimanual_transform_list
mode: total

train_dataloader_params:
eva_bimanual:
batch_size: 32
num_workers: 10
valid_dataloader_params:
eva_bimanual:
batch_size: 32
num_workers: 10
60 changes: 0 additions & 60 deletions egomimic/hydra_configs/data/zarr_test.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion egomimic/hydra_configs/model/hpt_bc_flow_eva.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ robomimic_model:
crossattn_dim_head: 64
crossattn_modality_dropout: 0.1
modality_embed_dim: 256
state_joint_positions:
state_ee_pose:
_target_: egomimic.models.hpt_nets.MLPPolicyStem
input_dim: 14
output_dim: 256
Expand Down
93 changes: 89 additions & 4 deletions egomimic/hydra_configs/train_zarr.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
defaults:
- model: hpt_bc_flow_aria
- model: hpt_bc_flow_eva
- paths: default
- trainer: ddp
- trainer: debug
- debug: null
- logger: wandb
- data: test_multi_zarr
- logger: debug
- data: eva_bc_zarr.yaml
- callbacks: checkpoints
- override hydra/launcher: submitit
- _self_
Expand All @@ -31,3 +31,88 @@ hydra:
launch_params:
gpus_per_node: 1
nodes: 1


data_schematic: # Dynamically fill in these shapes from the dataset
_target_: egomimic.rldb.zarr.utils.DataSchematic
norm_mode: quantile
schematic_dict:
eva_bimanual:
front_img_1: #batch key
key_type: camera_keys # key type
zarr_key: observations.images.front_img_1 # dataset key
right_wrist_img:
key_type: camera_keys
zarr_key: observations.images.right_wrist_img
left_wrist_img:
key_type: camera_keys
zarr_key: observations.images.left_wrist_img
ee_pose:
key_type: proprio_keys
zarr_key: observations.state.ee_pose
joint_positions:
key_type: proprio_keys
zarr_key: observations.state.joint_positions
actions_joints:
key_type: action_keys
zarr_key: actions_joints
actions_cartesian:
key_type: action_keys
zarr_key: actions_cartesian
embodiment:
key_type: metadata_keys
zarr_key: metadata.embodiment
aria_bimanual:
front_img_1:
key_type: camera_keys
zarr_key: observations.images.front_img_1
ee_pose:
key_type: proprio_keys
zarr_key: observations.state.ee_pose
actions_cartesian:
key_type: action_keys
zarr_key: actions_cartesian
embodiment:
key_type: metadata_keys
zarr_key: metadata.embodiment
mecka_bimanual:
front_img_1:
key_type: camera_keys
zarr_key: observations.images.front_img_1
ee_pose:
key_type: proprio_keys
zarr_key: observations.state.ee_pose_cam
actions_cartesian:
key_type: action_keys
zarr_key: actions_ee_cartesian_cam
actions_keypoints:
key_type: action_keys
zarr_key: actions_ee_keypoints_world
actions_head_cartesian:
key_type: action_keys
zarr_key: actions_head_cartesian_world
embodiment:
key_type: metadata_keys
zarr_key: metadata.embodiment
scale_bimanual:
front_img_1:
key_type: camera_keys
zarr_key: observations.images.front_img_1
ee_pose:
key_type: proprio_keys
zarr_key: observations.state.ee_pose
actions_cartesian:
key_type: action_keys
zarr_key: actions_cartesian
embodiment:
key_type: metadata_keys
zarr_key: metadata.embodiment
viz_img_key:
eva_bimanual:
front_img_1
aria_bimanual:
front_img_1
mecka_bimanual:
front_img_1
scale_bimanual:
front_img_1
Loading