From b6c8f33b3a022c42833ca6a23d9af6ad5c600b7a Mon Sep 17 00:00:00 2001 From: AnikethCheluva Date: Sun, 8 Mar 2026 22:23:28 -0400 Subject: [PATCH] adding support for single arm training --- .../hydra_configs/data/eva_right_arm.yaml | 41 +++++ .../model/hpt_bc_flow_eva_right_arm.yaml | 142 ++++++++++++++++++ egomimic/hydra_configs/train_zarr.yaml | 60 ++++++-- egomimic/rldb/embodiment/eva.py | 125 +++++++++++---- 4 files changed, 326 insertions(+), 42 deletions(-) create mode 100644 egomimic/hydra_configs/data/eva_right_arm.yaml create mode 100644 egomimic/hydra_configs/model/hpt_bc_flow_eva_right_arm.yaml diff --git a/egomimic/hydra_configs/data/eva_right_arm.yaml b/egomimic/hydra_configs/data/eva_right_arm.yaml new file mode 100644 index 00000000..67d6f530 --- /dev/null +++ b/egomimic/hydra_configs/data/eva_right_arm.yaml @@ -0,0 +1,41 @@ +_target_: egomimic.pl_utils.pl_data_utils.MultiDataModuleWrapper +train_datasets: + eva_right_arm: + _target_: egomimic.rldb.zarr.zarr_dataset_multi.MultiDataset._from_resolver + resolver: + _target_: egomimic.rldb.zarr.zarr_dataset_multi.S3EpisodeResolver + folder_path: /coc/flash7/scratch/egoverseDebugDatasets/egoverseS3DatasetTest/ + key_map: + _target_: egomimic.rldb.embodiment.eva.Eva.get_keymap + arm_mode: right_arm + transform_list: + _target_: egomimic.rldb.embodiment.eva.Eva.get_transform_list + arm_mode: right_arm + filters: + robot_name: "eva_right_arm" + mode: total + +valid_datasets: + eva_right_arm: + _target_: egomimic.rldb.zarr.zarr_dataset_multi.MultiDataset._from_resolver + resolver: + _target_: egomimic.rldb.zarr.zarr_dataset_multi.S3EpisodeResolver + folder_path: /coc/flash7/scratch/egoverseDebugDatasets/egoverseS3DatasetTest/ + key_map: + _target_: egomimic.rldb.embodiment.eva.Eva.get_keymap + arm_mode: right_arm + transform_list: + _target_: egomimic.rldb.embodiment.eva.Eva.get_transform_list + arm_mode: right_arm + filters: + robot_name: "eva_right_arm" + mode: total + +train_dataloader_params: + eva_right_arm: + batch_size: 32 + num_workers: 10 +valid_dataloader_params: + eva_right_arm: + batch_size: 32 + num_workers: 10 diff --git a/egomimic/hydra_configs/model/hpt_bc_flow_eva_right_arm.yaml b/egomimic/hydra_configs/model/hpt_bc_flow_eva_right_arm.yaml new file mode 100644 index 00000000..9a3e37c1 --- /dev/null +++ b/egomimic/hydra_configs/model/hpt_bc_flow_eva_right_arm.yaml @@ -0,0 +1,142 @@ +_target_: egomimic.pl_utils.pl_model.ModelWrapper +robomimic_model: + _target_: egomimic.algo.hpt.HPT + data_schematic: _${data.dataset.data_schematic} + camera_transforms: + eva_right_arm: + _target_: egomimic.utils.egomimicUtils.CameraTransforms + intrinsics_key: "base" # change to base_half if using half res + extrinsics_key: "x5Dec13_2" + + diffusion: true + 6dof: true + + ac_keys: + eva_right_arm: "actions_cartesian" + + kinematics_solver: + _target_: egomimic.robot.eva.eva_kinematics.EvaMinkKinematicsSolver + model_path: /coc/flash7/rpunamiya6/Projects/EgoVerse/egomimic/resources/model_x5.xml + + trunk: + embed_dim: 256 + num_blocks: 16 + num_heads: 8 + token_postprocessing: "action_token" + observation_horizon: 1 + action_horizon: 64 + no_trunk: false + use_domain_embedding: true + drop_path: 0.1 + weight_init_style: "pytorch" + + multitask: false + pretrained: false + pretrained_checkpoint: "" # TODO + reverse_kl_samples: 8 + + domains: ["eva_right_arm"] + shared_obs_keys: ["front_img_1"] + + shared_stem_specs: + front_img_1: + _target_: egomimic.models.hpt_nets.MLPPolicyStem + input_dim: 256 + output_dim: 256 + widths: [256] + specs: + random_horizon_masking: false + cross_attn: + crossattn_latent: 16 + crossattn_heads: 8 + crossattn_dim_head: 64 + crossattn_modality_dropout: 0.1 + modality_embed_dim: 256 + + stem_specs: + eva_right_arm: + right_wrist_img: + _target_: egomimic.models.hpt_nets.MLPPolicyStem + input_dim: 256 + output_dim: 256 + widths: [256] + specs: + random_horizon_masking: false + cross_attn: + crossattn_latent: 16 + crossattn_heads: 8 + crossattn_dim_head: 64 + crossattn_modality_dropout: 0.1 + modality_embed_dim: 256 + state_ee_pose: + _target_: egomimic.models.hpt_nets.MLPPolicyStem + input_dim: 7 # single arm: xyz + ypr + gripper (bimanual would be 14) + output_dim: 256 + widths: [256] + specs: + random_horizon_masking: false + cross_attn: + crossattn_latent: 16 + crossattn_heads: 8 + crossattn_dim_head: 64 + crossattn_modality_dropout: 0.1 + modality_embed_dim: 256 + + head_specs: + eva_right_arm: + _target_: egomimic.models.fm_policy.FMPolicy + action_horizon: 100 + num_inference_steps: 50 + pooling: null + time_dist: "beta" + infer_ac_dims: + eva_right_arm: 7 + model: + _target_: egomimic.models.denoising_nets.CrossTransformer + nblocks: 6 + cond_dim: 256 + hidden_dim: 128 + act_dim: 7 + act_seq: 100 + n_heads: 4 + dropout: 0.1 + mlp_layers: 4 + mlp_ratio: 4 + + encoder_specs: + front_img_1: + _target_: egomimic.models.hpt_nets.ResNet + output_dim: 256 + right_wrist_img: + _target_: egomimic.models.hpt_nets.ResNet + output_dim: 256 + + train_image_augs: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.ColorJitter + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + hue: 0.05 + - _target_: torchvision.transforms.Normalize + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + eval_image_augs: + _target_: torchvision.transforms.Compose + transforms: + - _target_: torchvision.transforms.Normalize + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 1e-4 + weight_decay: 0.0001 + +scheduler: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR + _partial_: true + T_max: 1400 + eta_min: 1e-5 diff --git a/egomimic/hydra_configs/train_zarr.yaml b/egomimic/hydra_configs/train_zarr.yaml index 0cc53a23..9f5a07af 100644 --- a/egomimic/hydra_configs/train_zarr.yaml +++ b/egomimic/hydra_configs/train_zarr.yaml @@ -6,7 +6,7 @@ defaults: - logger: wandb - data: eva - callbacks: checkpoints - - override hydra/launcher: submitit + - override hydra/launcher: submitit_skynet - _self_ name: test @@ -62,6 +62,50 @@ data_schematic: # Dynamically fill in these shapes from the dataset embodiment: key_type: metadata_keys zarr_key: metadata.embodiment + eva_right_arm: + front_img_1: #batch key + key_type: camera_keys # key type + zarr_key: observations.images.front_img_1 # dataset key + right_wrist_img: + key_type: camera_keys + zarr_key: observations.images.right_wrist_img + ee_pose: + key_type: proprio_keys + zarr_key: observations.state.ee_pose + joint_positions: + key_type: proprio_keys + zarr_key: observations.state.joint_positions + actions_joints: + key_type: action_keys + zarr_key: actions_joints + actions_cartesian: + key_type: action_keys + zarr_key: actions_cartesian + embodiment: + key_type: metadata_keys + zarr_key: metadata.embodiment + eva_left_arm: + front_img_1: #batch key + key_type: camera_keys # key type + zarr_key: observations.images.front_img_1 # dataset key + left_wrist_img: + key_type: camera_keys + zarr_key: observations.images.left_wrist_img + ee_pose: + key_type: proprio_keys + zarr_key: observations.state.ee_pose + joint_positions: + key_type: proprio_keys + zarr_key: observations.state.joint_positions + actions_joints: + key_type: action_keys + zarr_key: actions_joints + actions_cartesian: + key_type: action_keys + zarr_key: actions_cartesian + embodiment: + key_type: metadata_keys + zarr_key: metadata.embodiment aria_bimanual: front_img_1: key_type: camera_keys @@ -102,13 +146,11 @@ data_schematic: # Dynamically fill in these shapes from the dataset key_type: metadata_keys zarr_key: metadata.embodiment viz_img_key: - eva_bimanual: - front_img_1 - aria_bimanual: - front_img_1 - mecka_bimanual: - front_img_1 - scale_bimanual: - front_img_1 + eva_right_arm: front_img_1 + eva_left_arm: front_img_1 + eva_bimanual: front_img_1 + aria_bimanual: front_img_1 + mecka_bimanual: front_img_1 + scale_bimanual: front_img_1 seed: 42 diff --git a/egomimic/rldb/embodiment/eva.py b/egomimic/rldb/embodiment/eva.py index 369b546c..944dd6fb 100644 --- a/egomimic/rldb/embodiment/eva.py +++ b/egomimic/rldb/embodiment/eva.py @@ -32,8 +32,11 @@ class Eva(Embodiment): VIZ_IMAGE_KEY = "observations.images.front_img_1" @staticmethod - def get_transform_list() -> list[Transform]: - return _build_eva_bimanual_transform_list() + def get_transform_list( + arm_mode: Literal["bimanual", "left_arm", "right_arm"] = "bimanual", + **kwargs, + ) -> list[Transform]: + return _build_eva_bimanual_transform_list(arm_mode=arm_mode, **kwargs) @classmethod def viz_transformed_batch(cls, batch, mode=""): @@ -76,8 +79,12 @@ def viz(cls, images, actions, mode=Literal["traj", "axes"], intrinsics_key=None) ) @classmethod - def get_keymap(cls): - return { + def get_keymap( + cls, arm_mode: Literal["bimanual", "left_arm", "right_arm"] = "bimanual" + ): + """Return key_map for zarr loading. For single-arm modes, omits the other arm's keys so + datasets without right/left pose data load correctly.""" + base = { cls.VIZ_IMAGE_KEY: { "key_type": "camera_keys", "zarr_key": "images.front_1", @@ -127,6 +134,23 @@ def get_keymap(cls): "horizon": 45, }, } + if arm_mode == "left_arm": + drop = { + "right.obs_ee_pose", + "right.obs_gripper", + "right.gripper", + "right.cmd_ee_pose", + } + return {k: v for k, v in base.items() if k not in drop} + if arm_mode == "right_arm": + drop = { + "left.obs_ee_pose", + "left.obs_gripper", + "left.gripper", + "left.cmd_ee_pose", + } + return {k: v for k, v in base.items() if k not in drop} + return base def _build_eva_bimanual_transform_list( @@ -149,6 +173,7 @@ def _build_eva_bimanual_transform_list( stride: int = 1, extrinsics_key: str = "x5Dec13_2", is_quat: bool = True, + arm_mode: Literal["bimanual", "left_arm", "right_arm"] = "bimanual", ) -> list[Transform]: """Canonical EVA bimanual transform pipeline used by tests and notebooks.""" extrinsics = EXTRINSICS[extrinsics_key] @@ -158,7 +183,8 @@ def _build_eva_bimanual_transform_list( right_extra_batch_key = {"right_extrinsics_pose": right_extrinsics_pose} mode = "xyzwxyz" if is_quat else "xyzypr" - transform_list = [ + + action_transforms = [ ActionChunkCoordinateFrameTransform( target_world=left_target_world, chunk_world=left_cmd_world, @@ -173,6 +199,9 @@ def _build_eva_bimanual_transform_list( extra_batch_key=right_extra_batch_key, mode=mode, ), + ] + + pose_transforms = [ PoseCoordinateFrameTransform( target_world=left_target_world, pose_world=left_obs_pose, @@ -185,6 +214,9 @@ def _build_eva_bimanual_transform_list( transformed_key_name=right_obs_pose, mode=mode, ), + ] + + interpolate_transforms = [ InterpolatePose( new_chunk_length=chunk_length, action_key=left_cmd_camframe, @@ -199,6 +231,9 @@ def _build_eva_bimanual_transform_list( stride=stride, mode=mode, ), + ] + + gripper_transforms = [ InterpolateLinear( new_chunk_length=chunk_length, action_key=left_gripper, @@ -213,47 +248,71 @@ def _build_eva_bimanual_transform_list( ), ] + quat_keys = [left_cmd_camframe, left_obs_pose, right_cmd_camframe, right_obs_pose] + grip_cam_concat_keys = [ + left_cmd_camframe, + left_gripper, + right_cmd_camframe, + right_gripper, + ] + obs_concat_keys = [ + left_obs_pose, + left_obs_gripper, + right_obs_pose, + right_obs_gripper, + ] + delete_keys = [ + left_cmd_world, + left_target_world, + right_cmd_world, + right_target_world, + ] + + if arm_mode == "right_arm": + action_transforms = action_transforms[1:] + pose_transforms = pose_transforms[1:] + interpolate_transforms = interpolate_transforms[1:] + gripper_transforms = gripper_transforms[1:] + quat_keys = quat_keys[2:] + grip_cam_concat_keys = grip_cam_concat_keys[2:] + obs_concat_keys = obs_concat_keys[2:] + delete_keys = delete_keys[2:] + elif arm_mode == "left_arm": + action_transforms = action_transforms[:1] + pose_transforms = pose_transforms[:1] + interpolate_transforms = interpolate_transforms[:1] + gripper_transforms = gripper_transforms[:1] + quat_keys = quat_keys[:2] + grip_cam_concat_keys = grip_cam_concat_keys[:2] + obs_concat_keys = obs_concat_keys[:2] + delete_keys = delete_keys[:2] + elif arm_mode == "bimanual": + pass + + transform_list = ( + action_transforms + + pose_transforms + + interpolate_transforms + + gripper_transforms + ) + if is_quat: - transform_list.append( - XYZWXYZ_to_XYZYPR( - keys=[ - left_cmd_camframe, - right_cmd_camframe, - left_obs_pose, - right_obs_pose, - ] - ) - ) + transform_list.append(XYZWXYZ_to_XYZYPR(keys=quat_keys)) transform_list.extend( [ ConcatKeys( - key_list=[ - left_cmd_camframe, - left_gripper, - right_cmd_camframe, - right_gripper, - ], + key_list=grip_cam_concat_keys, new_key_name=actions_key, delete_old_keys=True, ), ConcatKeys( - key_list=[ - left_obs_pose, - left_obs_gripper, - right_obs_pose, - right_obs_gripper, - ], + key_list=obs_concat_keys, new_key_name=obs_key, delete_old_keys=True, ), DeleteKeys( - keys_to_delete=[ - left_cmd_world, - right_cmd_world, - left_target_world, - right_target_world, - ] + keys_to_delete=delete_keys, ), NumpyToTensor( keys=[