Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 47 additions & 15 deletions egomimic/algo/pi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import OrderedDict
import logging

import torch
import torch.nn as nn
Expand All @@ -10,10 +11,16 @@
import einops
from torchmetrics import MeanSquaredError

logger = logging.getLogger(__name__)
# Ensure logger propagates to root logger and has appropriate level
# Child loggers inherit from parent, but we explicitly set level to ensure INFO messages appear
logger.setLevel(logging.INFO)
logger.propagate = True # Explicitly enable propagation (default, but ensures it works)

from egomimic.models.hpt_nets import *
from egomimic.algo.algo import Algo

from egomimic.utils.egomimicUtils import draw_actions, draw_rotation_text
from egomimic.utils.egomimicUtils import draw_actions, draw_rotation_text, draw_annotation_text

from egomimic.utils.action_utils import *
import egomimic.utils.memory_utils as memutils
Expand Down Expand Up @@ -89,12 +96,9 @@ def __init__(
self.ac_keys = ac_keys

self.domains = domains

local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0")))
self.device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
torch.cuda.set_device(self.device)


self.device = None

self.camera_keys = {}
self.proprio_keys = {}
self.lang_keys = {}
Expand Down Expand Up @@ -134,7 +138,8 @@ def __init__(
fb_obj = arcfg.fallback
self.action_registry.register("*", default_ac_key, fb_obj)
self.action_registry.register("*", "*", fb_obj)


# Create the model
model_cfg = openpi.models.pi0_config.Pi0Config(
dtype=self.config.pytorch_training_precision,
action_dim=self.config.model.action_dim,
Expand All @@ -144,8 +149,9 @@ def __init__(
action_expert_variant=getattr(self.config.model, "action_expert_variant", "gemma_300m"),
pi05=getattr(config.model, "pi05", False),
)
self.model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(self.device)


self.model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg)

if self.config.pytorch_weight_path is not None:
model_path = os.path.join(self.config.pytorch_weight_path, "model.safetensors")
safetensors.torch.load_model(
Expand All @@ -154,6 +160,7 @@ def __init__(
self.nets = nn.ModuleDict()
self.nets["policy"] = self.model


@override
def process_batch_for_training(self, batch):
"""
Expand All @@ -173,15 +180,27 @@ def process_batch_for_training(self, batch):
key_name = self.data_schematic.lerobot_key_to_keyname(key, embodiment_id)
if key_name is not None:
processed_batch[embodiment_id][key_name] = value

# Carry through language tokenization tensors produced by collate_fn
for tk in ("tokenized_prompt", "tokenized_mask", "token_loss_mask", "token_ar_mask"):
if tk in _batch:
processed_batch[embodiment_id][tk] = _batch[tk]

ac_key = self.ac_keys[embodiment_id]
if len(processed_batch[embodiment_id][ac_key].shape) != 3:
raise ValueError("Action shape in batch is not 2")

B, S, _ = processed_batch[embodiment_id][ac_key].shape
device = processed_batch[embodiment_id][ac_key].device
processed_batch[embodiment_id]["pad_mask"] = torch.ones(B, S, 1, device=device)
processed_batch[embodiment_id]["pad_mask"] = torch.ones(B, S, 1, device=device)
processed_batch[embodiment_id] = self.data_schematic.normalize_data(processed_batch[embodiment_id], embodiment_id)

if not processed_batch:
raise ValueError(
f"No valid embodiments found in batch. Batch contained: {list(batch.keys())}, "
f"but ac_keys only has: {list(self.ac_keys.keys())}"
)

return processed_batch


Expand Down Expand Up @@ -210,7 +229,7 @@ def forward_training(self, batch):
if isinstance(losses, list | tuple):
losses = torch.stack(losses)
elif not isinstance(losses, torch.Tensor):
losses = torch.tensor(losses, device=device, dtype=torch.float32)
losses = torch.tensor(losses, device=action.device, dtype=torch.float32)

loss = losses.mean()

Expand Down Expand Up @@ -339,7 +358,11 @@ def visualize_preds(self, predictions, batch):

if self.is_6dof and ac_key == "actions_cartesian":
ims[b] = draw_rotation_text(ims[b], gt_rot[b][0], preds_rot[b][0], position=(340, 20))


if "annotations" in batch:
annotation = batch["annotations"][b]
ims[b] = draw_annotation_text(ims[b], annotation)

return ims


Expand All @@ -356,11 +379,13 @@ def compute_losses(self, predictions, batch):
loss_key_name: torch.Tensor (1)
"""
loss_dict = OrderedDict()
total_action_loss = torch.tensor(0.0, device=self.device)
total_action_loss = None

for embodiment_id, _batch in batch.items():
embodiment_name = get_embodiment(embodiment_id).lower()
bc_loss = predictions[f"{embodiment_name}_loss"]
if total_action_loss is None:
total_action_loss = torch.tensor(0.0, device=bc_loss.device)
total_action_loss += bc_loss
loss_dict[f"{embodiment_name}_loss"] = bc_loss # for logging

Expand Down Expand Up @@ -434,7 +459,14 @@ def _robomimic_to_pi_data(self, batch, cam_keys, proprio_keys, lang_keys, ac_key
for k in images.keys()
}

tokenized_prompt, tokenized_prompt_mask, token_ar_mask, token_loss_mask = _empty_lang_placeholders(B, device)
if not lang_keys:
tokenized_prompt, tokenized_prompt_mask, token_ar_mask, token_loss_mask = _empty_lang_placeholders(B, device)

else:
tokenized_prompt = batch["tokenized_prompt"].to(device)
tokenized_prompt_mask = batch["tokenized_mask"].to(device)
token_ar_mask = batch["token_ar_mask"].to(device)
token_loss_mask = batch["token_loss_mask"].to(device)

# ---- Wrap into simple observation (helpers) ----
observation = _SimpleObservation(
Expand Down
47 changes: 47 additions & 0 deletions egomimic/hydra_configs/data/cotrain_foldclothes.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Hydra data config: co-train on the "fold clothes" task across two
# embodiments — robot ("eva_bimanual") and human/Aria ("aria_bimanual") —
# both streamed from the "rldb" S3 bucket using only locally cached files.
_target_: egomimic.pl_utils.pl_data_utils.MultiDataModuleWrapper

train_datasets:
  # Robot bimanual demonstrations.
  dataset1:
    _target_: egomimic.rldb.utils.S3RLDBDataset
    bucket_name: "rldb"
    mode: train
    embodiment: "eva_bimanual"
    filters:
      task: "fold clothes"
    local_files_only: True
  # Human (Aria) bimanual demonstrations for co-training.
  dataset2:
    _target_: egomimic.rldb.utils.S3RLDBDataset
    bucket_name: "rldb"
    mode: train
    embodiment: "aria_bimanual"
    filters:
      task: "fold clothes"
    local_files_only: True

valid_datasets:
  dataset1:
    _target_: egomimic.rldb.utils.S3RLDBDataset
    bucket_name: "rldb"
    mode: valid
    embodiment: "eva_bimanual"
    filters:
      task: "fold clothes"
    local_files_only: True
  dataset2:
    _target_: egomimic.rldb.utils.S3RLDBDataset
    bucket_name: "rldb"
    mode: valid
    embodiment: "aria_bimanual"
    filters:
      task: "fold clothes"
    local_files_only: True

# NOTE(review): two datasets are declared above but dataloader params are
# given only for dataset1 — confirm MultiDataModuleWrapper supplies defaults
# (or reuses dataset1's params) for dataset2.
train_dataloader_params:
  dataset1:
    batch_size: 32
    num_workers: 8

valid_dataloader_params:
  dataset1:
    batch_size: 32
    num_workers: 8
47 changes: 47 additions & 0 deletions egomimic/hydra_configs/data/cotrain_obj_cont.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Hydra data config: co-train on the "object in container" task across two
# embodiments — robot ("eva_right_arm") and human/Aria ("aria_right_arm") —
# both streamed from the "rldb" S3 bucket using only locally cached files.
_target_: egomimic.pl_utils.pl_data_utils.MultiDataModuleWrapper

train_datasets:
  # Robot right-arm demonstrations.
  dataset1:
    _target_: egomimic.rldb.utils.S3RLDBDataset
    bucket_name: "rldb"
    mode: train
    embodiment: "eva_right_arm"
    filters:
      task: "object in container"
    local_files_only: True
  # Human (Aria) right-arm demonstrations for co-training.
  dataset2:
    _target_: egomimic.rldb.utils.S3RLDBDataset
    bucket_name: "rldb"
    mode: train
    embodiment: "aria_right_arm"
    filters:
      task: "object in container"
    local_files_only: True

valid_datasets:
  dataset1:
    _target_: egomimic.rldb.utils.S3RLDBDataset
    bucket_name: "rldb"
    mode: valid
    embodiment: "eva_right_arm"
    filters:
      task: "object in container"
    local_files_only: True
  dataset2:
    _target_: egomimic.rldb.utils.S3RLDBDataset
    bucket_name: "rldb"
    mode: valid
    embodiment: "aria_right_arm"
    filters:
      task: "object in container"
    local_files_only: True

# NOTE(review): two datasets are declared above but dataloader params are
# given only for dataset1 — confirm MultiDataModuleWrapper supplies defaults
# (or reuses dataset1's params) for dataset2.
train_dataloader_params:
  dataset1:
    batch_size: 32
    num_workers: 8

valid_dataloader_params:
  dataset1:
    batch_size: 32
    num_workers: 8
6 changes: 2 additions & 4 deletions egomimic/hydra_configs/data/eva_bc_s3.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ train_datasets:
dataset1:
_target_: egomimic.rldb.utils.S3RLDBDataset
bucket_name: "rldb"
temp_root: "/coc/flash7/rpunamiya6/Projects/local_egoverse"
mode: train
embodiment: "eva_right_arm"
filters:
Expand All @@ -15,7 +14,6 @@ valid_datasets:
dataset1:
_target_: egomimic.rldb.utils.S3RLDBDataset
bucket_name: "rldb"
temp_root: "/coc/flash7/rpunamiya6/Projects/local_egoverse"
mode: valid
embodiment: "eva_right_arm"
filters:
Expand All @@ -25,9 +23,9 @@ valid_datasets:
train_dataloader_params:
dataset1:
batch_size: 32
num_workers: 10
num_workers: 8

valid_dataloader_params:
dataset1:
batch_size: 32
num_workers: 10
num_workers: 8
32 changes: 32 additions & 0 deletions egomimic/hydra_configs/data/mecka_midtrain.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Hydra data config: mid-training on bimanual Mecka robot data from the
# "rldb" S3 bucket (filtered to lab == "mecka"), local cached files only.
_target_: egomimic.pl_utils.pl_data_utils.MultiDataModuleWrapper

train_datasets:
  dataset1:
    # Fully-qualified target path ("egomimic.rldb.utils...") to match every
    # other data config in the repo; the bare "rldb.utils.S3RLDBDataset"
    # module path would only resolve if "rldb" were importable as a
    # top-level package.
    _target_: egomimic.rldb.utils.S3RLDBDataset
    bucket_name: "rldb"
    main_prefix: "mecka"
    mode: train
    embodiment: "mecka_bimanual"
    filters:
      lab: "mecka"
    local_files_only: True

valid_datasets:
  dataset1:
    _target_: egomimic.rldb.utils.S3RLDBDataset
    bucket_name: "rldb"
    # NOTE(review): the train split sets main_prefix: "mecka" but valid does
    # not — confirm whether validation data lives under a different prefix.
    mode: valid
    embodiment: "mecka_bimanual"
    filters:
      lab: "mecka"
    local_files_only: True

train_dataloader_params:
  dataset1:
    batch_size: 32
    num_workers: 8

valid_dataloader_params:
  dataset1:
    batch_size: 32
    num_workers: 8
8 changes: 4 additions & 4 deletions egomimic/hydra_configs/data/mecka_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ train_datasets:
repo_id: "mecka_test"
mode: train
embodiment: "mecka_bimanual"
root: "/coc/flash7/acheluva3/EgoVerse/mecka_demo"
root: "/storage/cedar/cedar0/cedarp-dxu345-0/datasets/egoverse/mecka_tests"
local_files_only: True

valid_datasets:
Expand All @@ -15,15 +15,15 @@ valid_datasets:
repo_id: "mecka_test"
mode: valid
embodiment: "mecka_bimanual"
root: "/coc/flash7/acheluva3/EgoVerse/mecka_demo"
root: "/storage/cedar/cedar0/cedarp-dxu345-0/datasets/egoverse/mecka_tests"
local_files_only: True

train_dataloader_params:
dataset1:
batch_size: 32
num_workers: 10
num_workers: 8

valid_dataloader_params:
dataset1:
batch_size: 32
num_workers: 10
num_workers: 8
18 changes: 0 additions & 18 deletions egomimic/hydra_configs/hydra/launcher/submitit.yaml

This file was deleted.

19 changes: 19 additions & 0 deletions egomimic/hydra_configs/hydra/launcher/submitit_pace.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Hydra submitit launcher config for the PACE cluster: submits jobs to the
# H200 GPU partition via Slurm, sized by launch_params (nodes x gpus_per_node).
defaults:
  - submitit_slurm

_target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher

# Slurm configuration
name: ${hydra.job.name} # Default job name
partition: "gpu-h200" # Slurm partition
account: "gts-dxu345-rl2" # Slurm account
cpus_per_task: 8 # Number of CPUs per task (max 4:1 CPU:GPU ratio)
nodes: ${launch_params.nodes} # Number of nodes
tasks_per_node: ${launch_params.gpus_per_node} # Use variable for tasks per node
gres: "gpu:h200:${eval:'${launch_params.gpus_per_node} * ${launch_params.nodes}'}" # GPU type and count (h200 for H200 GPUs)
qos: "inferno" # Slurm QoS
mem_per_gpu: 250G
timeout_min: 2880 # Timeout in minutes (48 hours)
# exclude: "protocol, puma" # Nodes to exclude
additional_parameters:
  requeue: true # Requeue the job automatically if it is preempted or the node fails
Loading