8 changes: 5 additions & 3 deletions egomimic/algo/hpt.py
@@ -971,7 +971,9 @@ def process_batch_for_training(self, batch):
"""
processed_batch = {}

for embodiment_id, _batch in batch.items():

for dataset_name, _batch in batch.items():
embodiment_id = _batch["robot_name"][0].item()
processed_batch[embodiment_id] = {}
for key, value in _batch.items():
key_name = self.data_schematic.lerobot_key_to_keyname(
@@ -1250,7 +1252,7 @@ def visualize_preds(self, predictions, batch):
Returns:
ims (np.ndarray): (B, H, W, 3) - images with actions drawn on top
"""
embodiment_id = batch["embodiment"][0].item()
embodiment_id = batch["robot_name"][0].item()
embodiment_name = get_embodiment(embodiment_id).lower()
ac_key = self.ac_keys[embodiment_id]

@@ -1430,7 +1432,7 @@ def _robomimic_to_hpt_data(

data["is_6dof"] = self.is_6dof
data["pad_mask"] = batch["pad_mask"]
data["embodiment"] = batch["embodiment"]
data["embodiment"] = batch["robot_name"]

for aux_ac_key in aux_ac_keys:
data[aux_ac_key] = batch[aux_ac_key]
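A minimal sketch (an assumption, not the repository code) of the re-keying pattern this hunk introduces: incoming batches arrive keyed by dataset name, and each one is re-keyed by the embodiment id carried in its "robot_name" tensor.

import torch

def rekey_by_embodiment(batch: dict) -> dict:
    """batch maps dataset_name -> dict of tensors that include 'robot_name'."""
    processed = {}
    for dataset_name, _batch in batch.items():
        # All samples in one dataloader batch share a robot, so the first
        # element identifies the embodiment for the whole batch.
        embodiment_id = _batch["robot_name"][0].item()
        processed[embodiment_id] = dict(_batch)
    return processed

# Hypothetical usage with two datasets that share a made-up embodiment id 3:
example = {
    "dataset1": {"robot_name": torch.tensor([3, 3]), "obs": torch.zeros(2, 4)},
    "dataset2": {"robot_name": torch.tensor([3, 3]), "obs": torch.ones(2, 4)},
}
print(rekey_by_embodiment(example).keys())  # dict_keys([3])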
37 changes: 37 additions & 0 deletions egomimic/hydra_configs/data/zarr_multi_test.yaml
@@ -0,0 +1,37 @@
_target_: egomimic.pl_utils.pl_data_utils.MultiDataModuleWrapper
train_datasets:
dataset1:
_target_: egomimic.rldb.zarr.MultiDataset
datasets:
dataset1:
_target_: egomimic.rldb.zarr.ZarrDataset
Episode_path: /coc/flash7/rco3/EgoVerse/egomimic/rldb/zarr/zarr/bimanual/1769461098408.zarr
action_horizon: 100
dataset2:
_target_: egomimic.rldb.zarr.ZarrDataset
Episode_path: /coc/flash7/rco3/EgoVerse/egomimic/rldb/zarr/zarr/bimanual/1769461098408.zarr
action_horizon: 100
mode: train
valid_datasets:
dataset1:
_target_: egomimic.rldb.zarr.MultiDataset
datasets:
dataset1:
_target_: egomimic.rldb.zarr.ZarrDataset
Episode_path: /coc/flash7/rco3/EgoVerse/egomimic/rldb/zarr/zarr/bimanual/1769461098408.zarr
action_horizon: 100
dataset2:
_target_: egomimic.rldb.zarr.ZarrDataset
Episode_path: /coc/flash7/rco3/EgoVerse/egomimic/rldb/zarr/zarr/bimanual/1769461098408.zarr
action_horizon: 100
mode: valid

train_dataloader_params:
dataset1:
batch_size: 2
num_workers: 10

valid_dataloader_params:
dataset1:
batch_size: 2
num_workers: 10
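A hedged sketch of how a config like the one above could be built in code: Hydra's recursive instantiation resolves each nested _target_ entry. This is illustrative only and not verified against the repository's entry point.

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Load the YAML and let Hydra instantiate MultiDataModuleWrapper plus the
# nested MultiDataset/ZarrDataset objects declared under train/valid_datasets.
cfg = OmegaConf.load("egomimic/hydra_configs/data/zarr_multi_test.yaml")
datamodule = instantiate(cfg)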
40 changes: 40 additions & 0 deletions egomimic/hydra_configs/data/zarr_resolver.yaml
@@ -0,0 +1,40 @@
_target_: egomimic.pl_utils.pl_data_utils.MultiDataModuleWrapper
train_datasets:
dataset1:
_target_: egomimic.rldb.zarr.MultiDataset._from_resolver
resolver:
_target_: egomimic.rldb.zarr.zarr_dataset_multi.LocalEpisodeResolver
folder_path:
_target_: pathlib.Path
_args_: [/coc/flash7/rco3/EgoVerse/egomimic/rldb/zarr/zarr/bimanual]
embodiment: "eva_bimanual"
mode: "train"
valid_ratio: 0.2
sync_from_s3: false
filters:
is_deleted: false

valid_datasets:
dataset1:
_target_: egomimic.rldb.zarr.MultiDataset._from_resolver
resolver:
_target_: egomimic.rldb.zarr.zarr_dataset_multi.LocalEpisodeResolver
folder_path:
_target_: pathlib.Path
_args_: [/coc/flash7/rco3/EgoVerse/egomimic/rldb/zarr/zarr/bimanual]
embodiment: "eva_bimanual"
mode: "valid"
valid_ratio: 0.2
sync_from_s3: false
filters:
is_deleted: false

train_dataloader_params:
dataset1:
batch_size: 2
num_workers: 10

valid_dataloader_params:
dataset1:
batch_size: 2
num_workers: 10
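The resolver entries above pair a mode with a valid_ratio; below is a minimal sketch of the episode split this appears to describe (an assumption -- the real LocalEpisodeResolver may filter, order, or sync episodes differently).

from pathlib import Path

def split_episodes(folder_path: Path, mode: str, valid_ratio: float = 0.2):
    """Deterministically hold out the last `valid_ratio` fraction of episodes."""
    episodes = sorted(folder_path.glob("*.zarr"))
    n_train = len(episodes) - int(len(episodes) * valid_ratio)
    return episodes[:n_train] if mode == "train" else episodes[n_train:]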
23 changes: 23 additions & 0 deletions egomimic/hydra_configs/data/zarr_test.yaml
@@ -0,0 +1,23 @@
_target_: egomimic.pl_utils.pl_data_utils.MultiDataModuleWrapper

train_datasets:
dataset1:
_target_: egomimic.rldb.zarr.ZarrDataset
Episode_path: /coc/flash7/rco3/EgoVerse/egomimic/rldb/zarr/zarr/bimanual/1769461098408.zarr
action_horizon: 100

valid_datasets:
dataset1:
_target_: egomimic.rldb.zarr.ZarrDataset
Episode_path: /coc/flash7/rco3/EgoVerse/egomimic/rldb/zarr/zarr/bimanual/1769461098408.zarr
action_horizon: 100

train_dataloader_params:
dataset1:
batch_size: 32
num_workers: 10

valid_dataloader_params:
dataset1:
batch_size: 32
num_workers: 10
6 changes: 6 additions & 0 deletions egomimic/hydra_configs/train.yaml
@@ -61,6 +61,9 @@ data_schematic: # Dynamically fill in these shapes from the dataset
embodiment:
key_type: metadata_keys
lerobot_key: metadata.embodiment
robot_name:
key_type: metadata_keys
lerobot_key: robot_name
aria_bimanual:
front_img_1:
key_type: camera_keys
@@ -74,6 +77,9 @@ data_schematic: # Dynamically fill in these shapes from the dataset
embodiment:
key_type: metadata_keys
lerobot_key: metadata.embodiment
robot_name:
key_type: metadata_keys
lerobot_key: robot_name
mecka_bimanual:
front_img_1:
key_type: camera_keys
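An illustrative sketch (assumed structure, not the repository's loader) of how a schematic entry like the new robot_name block maps a key name to its lerobot key.

schematic_entry = {
    "robot_name": {"key_type": "metadata_keys", "lerobot_key": "robot_name"},
}

def lerobot_key_for(schematic: dict, keyname: str) -> str:
    # Mirrors the keyname -> lerobot_key lookup the data schematic provides.
    return schematic[keyname]["lerobot_key"]

assert lerobot_key_for(schematic_entry, "robot_name") == "robot_name"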
8 changes: 5 additions & 3 deletions egomimic/pl_utils/pl_data_utils.py
@@ -1,4 +1,4 @@
from torch.utils.data import DataLoader, random_split, default_collate
from torch.utils.data import DataLoader, random_split, default_collate, ConcatDataset
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from lightning import LightningDataModule
from transformers import AutoTokenizer
@@ -81,7 +81,8 @@ def train_dataloader(self):
iterables = dict()
for dataset_name, dataset in self.train_datasets.items():
dataset_params = self.train_dataloader_params.get(dataset_name, {})
iterables[dataset.embodiment] = DataLoader(
iterables[dataset_name] = DataLoader(

dataset,
shuffle=True,
collate_fn=self.collate_fn,
@@ -94,7 +95,8 @@ def val_dataloader(self):
iterables = dict()
for dataset_name, dataset in self.valid_datasets.items():
dataset_params = self.valid_dataloader_params.get(dataset_name, {})
iterables[dataset.embodiment] = DataLoader(
iterables[dataset_name] = DataLoader(

dataset,
shuffle=False,
collate_fn=self.collate_fn,
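A hedged sketch of how dataloaders keyed by dataset name (as changed above) can be combined with Lightning's CombinedLoader; the dataset contents and sizes are placeholders, and the iteration protocol shown assumes Lightning 2.x.

import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch.utilities.combined_loader import CombinedLoader

iterables = {
    "dataset1": DataLoader(TensorDataset(torch.zeros(8, 3)), batch_size=2),
    "dataset2": DataLoader(TensorDataset(torch.ones(4, 3)), batch_size=2),
}
combined = CombinedLoader(iterables, mode="max_size_cycle")
for batch, batch_idx, dataloader_idx in combined:
    # Each step yields a dict keyed by dataset name: {"dataset1": ..., "dataset2": ...}
    print(sorted(batch.keys()))
    break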
125 changes: 82 additions & 43 deletions egomimic/rldb/utils.py
@@ -1154,6 +1154,7 @@ def __init__(self, schematic_dict, viz_img_key, norm_mode="zscore"):
"lerobot_key": key_info["lerobot_key"],
"shape": None,
"embodiment": embodiment_id,
"robot_name": embodiment_id,
}
)

@@ -1222,7 +1223,7 @@ def infer_shapes_from_batch(self, batch):
Updates:
The 'shape' column in the DataFrame is updated to match the inferred shapes (stored as tuples).
"""
embodiment_id = int(batch["metadata.embodiment"])
embodiment_id = int(batch["metadata.robot_name"])
for key, tensor in batch.items():
if hasattr(tensor, "shape"):
shape = tuple(tensor.shape)
@@ -1295,71 +1296,109 @@ def infer_norm_from_dataset(self, dataset):
returns: dictionary of means and stds for proprio and action keys
"""
norm_columns = []

embodiment = dataset.embodiment
if isinstance(embodiment, str):
embodiment = get_embodiment_id(embodiment)

norm_columns.extend(self.keys_of_type("proprio_keys"))
norm_columns.extend(self.keys_of_type("action_keys"))

logger.info(
f"[NormStats] Starting norm inference for embodiment={embodiment}, "
f"{len(norm_columns)} columns"
)
def _normalize_embodiment(value):
if value is None:
return None
if isinstance(value, str):
return get_embodiment_id(value)
return int(value)

def _collect_embodiments(ds):
embodiments = set()

if hasattr(ds, "robot_name"):
emb = _normalize_embodiment(ds.robot_name)
if emb is not None:
embodiments.add(emb)

if hasattr(ds, "datasets"):
for sub_ds in ds.datasets.values():
embodiments.update(_collect_embodiments(sub_ds))

def get_zarr_data(ds, col):
if isinstance(ds, (list, tuple)):
for sub_ds in ds:
embodiments.update(_collect_embodiments(sub_ds))

return embodiments

def get_zarr_data(ds, col, embodiment):
if hasattr(ds, "episode_reader"):
# ZarrDataset
ds_emb = _normalize_embodiment(getattr(ds, "robot_name", None))
if ds_emb != embodiment:
return None
if col in ds.episode_reader._store:
return ds.episode_reader._store[col][:]
return None
elif hasattr(ds, "datasets"):
if hasattr(ds, "datasets"):
# MultiDataset wrapper
data_list = []
for d in ds.datasets.values():
res = get_zarr_data(d, col)
res = get_zarr_data(d, col, embodiment)
if res is not None:
data_list.append(res)
if data_list:
return np.concatenate(data_list, axis=0)
if isinstance(ds, (list, tuple)):
data_list = []
for d in ds:
res = get_zarr_data(d, col, embodiment)
if res is not None:
data_list.append(res)
if data_list:
return np.concatenate(data_list, axis=0)
return None

for column in norm_columns:
if not self.is_key_with_embodiment(column, embodiment):
continue
column_name = self.keyname_to_lerobot_key(column, embodiment)
logger.info(f"[NormStats] Processing column={column_name}")
embodiments = _collect_embodiments(dataset)
if not embodiments:
raise ValueError("Could not determine any embodiments from dataset to infer norms.")

column_data = get_zarr_data(dataset, column_name)
for embodiment in sorted(embodiments):
logger.info(
f"[NormStats] Starting norm inference for embodiment={embodiment}, "
f"{len(norm_columns)} columns"
)

if column_data is None:
logger.warning(f"Skipping {column_name}, data not found given dataset type")
continue
for column in norm_columns:
if not self.is_key_with_embodiment(column, embodiment):
continue
column_name = self.keyname_to_lerobot_key(column, embodiment)
logger.info(f"[NormStats] Processing column={column_name}")

if column_data.ndim not in (2, 3):
raise ValueError(
f"Column {column} has shape {column_data.shape}, "
"expected 2 or 3 dims"
)
column_data = get_zarr_data(dataset, column_name, embodiment)

mean = np.mean(column_data, axis=0)
std = np.std(column_data, axis=0)
minv = np.min(column_data, axis=0)
maxv = np.max(column_data, axis=0)
median = np.median(column_data, axis=0)
q1 = np.percentile(column_data, 1, axis=0)
q99 = np.percentile(column_data, 99, axis=0)
if column_data is None:
logger.warning(
f"Skipping {column_name}, data not found for embodiment={embodiment}"
)
continue

self.norm_stats[embodiment][column] = {
"mean": torch.from_numpy(mean).float(),
"std": torch.from_numpy(std).float(),
"min": torch.from_numpy(minv).float(),
"max": torch.from_numpy(maxv).float(),
"median": torch.from_numpy(median).float(),
"quantile_1": torch.from_numpy(q1).float(),
"quantile_99": torch.from_numpy(q99).float(),
}
if column_data.ndim not in (2, 3):
raise ValueError(
f"Column {column} has shape {column_data.shape}, "
"expected 2 or 3 dims"
)

mean = np.mean(column_data, axis=0)
std = np.std(column_data, axis=0)
minv = np.min(column_data, axis=0)
maxv = np.max(column_data, axis=0)
median = np.median(column_data, axis=0)
q1 = np.percentile(column_data, 1, axis=0)
q99 = np.percentile(column_data, 99, axis=0)

self.norm_stats[embodiment][column] = {
"mean": torch.from_numpy(mean).float(),
"std": torch.from_numpy(std).float(),
"min": torch.from_numpy(minv).float(),
"max": torch.from_numpy(maxv).float(),
"median": torch.from_numpy(median).float(),
"quantile_1": torch.from_numpy(q1).float(),
"quantile_99": torch.from_numpy(q99).float(),
}

logger.info("[NormStats] Finished norm inference")

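A short hedged sketch (not the repository's normalizer) of how the per-embodiment z-score stats computed above would typically be applied to a proprio or action tensor.

import torch

def zscore_normalize(x: torch.Tensor, stats: dict, eps: float = 1e-8) -> torch.Tensor:
    # stats is one entry of norm_stats[embodiment][column], as built above.
    return (x - stats["mean"]) / (stats["std"] + eps)

stats = {"mean": torch.zeros(7), "std": torch.ones(7)}
print(zscore_normalize(torch.randn(4, 7), stats).shape)  # torch.Size([4, 7])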
6 changes: 5 additions & 1 deletion egomimic/rldb/zarr/__init__.py
@@ -7,13 +7,17 @@
MultiDataset,
ZarrDataset,
ZarrEpisode,
LocalEpisodeResolver,
S3EpisodeResolver,
)
from egomimic.rldb.zarr.zarr_writer import ZarrWriter
#from egomimic.rldb.zarr.zarr_writer import ZarrWriter

__all__ = [
"EpisodeResolver",
"MultiDataset",
"ZarrDataset",
"ZarrEpisode",
"ZarrWriter",
"LocalEpisodeResolver",
"S3EpisodeResolver",
]
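A brief usage note: with the exports above, the resolver classes can be imported directly from the package (assuming egomimic is importable in the environment).

from egomimic.rldb.zarr import LocalEpisodeResolver, MultiDataset, ZarrDataset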