Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Repo Agent Rules

## Shell / Command Execution
To run commands in the interactive shell, first source `emimic/bin/activate`.

Apply this before running project Python tooling (for example: `python`, `pytest`, `pip`).

## Model settings
Use plan mode for anything except extremely simple tasks.
67 changes: 67 additions & 0 deletions egomimic/hydra_configs/data/test_multi_zarr.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
_target_: egomimic.pl_utils.pl_data_utils.MultiDataModuleWrapper
train_datasets:
dataset1:
_target_: egomimic.rldb.zarr.zarr_dataset_multi.MultiDataset._from_resolver
resolver:
_target_: egomimic.rldb.zarr.zarr_dataset_multi.LocalEpisodeResolver
folder_path: /nethome/paphiwetsa3/flash/datasets/test_zarr/
key_map:
front_img_1: #batch key
key_type: camera_keys # key type
zarr_key: observations.images.front_img_1 # dataset key
right_wrist_img:
key_type: camera_keys
zarr_key: observations.images.right_wrist_img
left_wrist_img:
key_type: camera_keys
zarr_key: observations.images.left_wrist_img
ee_pose:
key_type: proprio_keys
zarr_key: observations.state.ee_pose
horizon: 100
joint_positions:
key_type: proprio_keys
zarr_key: observations.state.joint_positions
horizon: 100
actions_cartesian:
key_type: action_keys
zarr_key: actions_cartesian
horizon: 100
mode: total
valid_datasets:
dataset1:
_target_: egomimic.rldb.zarr.zarr_dataset_multi.MultiDataset._from_resolver
resolver:
_target_: egomimic.rldb.zarr.zarr_dataset_multi.LocalEpisodeResolver
folder_path: /nethome/paphiwetsa3/flash/datasets/test_zarr/
key_map:
front_img_1: #batch key
key_type: camera_keys # key type
zarr_key: observations.images.front_img_1 # dataset key
right_wrist_img:
key_type: camera_keys
zarr_key: observations.images.right_wrist_img
left_wrist_img:
key_type: camera_keys
zarr_key: observations.images.left_wrist_img
ee_pose:
key_type: proprio_keys
zarr_key: observations.state.ee_pose
horizon: 100
joint_positions:
key_type: proprio_keys
zarr_key: observations.state.joint_positions
horizon: 100
actions_cartesian:
key_type: action_keys
zarr_key: actions_cartesian
horizon: 100
mode: total
train_dataloader_params:
dataset1:
batch_size: 2
num_workers: 10
valid_dataloader_params:
dataset1:
batch_size: 2
num_workers: 10
2 changes: 1 addition & 1 deletion egomimic/pl_utils/pl_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def train_dataloader(self):
iterables = dict()
for dataset_name, dataset in self.train_datasets.items():
dataset_params = self.train_dataloader_params.get(dataset_name, {})
iterables[dataset.embodiment] = DataLoader(
iterables[dataset_name] = DataLoader(
dataset,
shuffle=True,
collate_fn=self.collate_fn,
Expand Down
76 changes: 75 additions & 1 deletion egomimic/rldb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,7 @@ def infer_shapes_from_batch(self, batch):

self.shapes_infered = True

def infer_norm_from_dataset(self, dataset):
def infer_norm_from_dataset_lerobot(self, dataset):
"""
dataset: huggingface dataset backed by pyarrow
returns: dictionary of means and stds for proprio and action keys
Expand Down Expand Up @@ -1289,6 +1289,80 @@ def infer_norm_from_dataset(self, dataset):

logger.info("[NormStats] Finished norm inference")

def infer_norm_from_dataset(self, dataset):
    """Compute per-column normalization statistics from a zarr-backed dataset.

    Accepts either a single zarr dataset (detected by its ``episode_reader``
    attribute) or a MultiDataset wrapper (detected by its ``datasets``
    attribute); for the wrapper, per-member arrays are concatenated along
    axis 0 before reducing.

    Args:
        dataset: zarr-backed dataset (or MultiDataset of them) exposing an
            ``embodiment`` attribute (str name or numeric id).

    Side effects:
        Populates ``self.norm_stats[embodiment][column]`` with float tensors
        ``mean``/``std``/``min``/``max``/``median``/``quantile_1``/
        ``quantile_99``, each reduced over axis 0.
        NOTE(review): assumes ``self.norm_stats[embodiment]`` is already an
        initialized mapping — confirm against the caller.

    Raises:
        ValueError: if a column's data is not 2- or 3-dimensional.
    """
    # Only proprio and action columns are normalized; image keys are skipped.
    norm_columns = []

    embodiment = dataset.embodiment
    if isinstance(embodiment, str):
        embodiment = get_embodiment_id(embodiment)

    norm_columns.extend(self.keys_of_type("proprio_keys"))
    norm_columns.extend(self.keys_of_type("action_keys"))

    logger.info(
        f"[NormStats] Starting norm inference for embodiment={embodiment}, "
        f"{len(norm_columns)} columns"
    )

    def get_zarr_data(ds, col):
        # Resolve `col` to a full numpy array, or None when unavailable.
        if hasattr(ds, "episode_reader"):
            # Single zarr dataset: read the whole column from the store.
            if col in ds.episode_reader._store:
                return ds.episode_reader._store[col][:]
            return None
        elif hasattr(ds, "datasets"):
            # MultiDataset wrapper: recurse into members and concatenate
            # whatever was found along the sample axis.
            data_list = []
            for d in ds.datasets.values():
                res = get_zarr_data(d, col)
                if res is not None:
                    data_list.append(res)
            if data_list:
                return np.concatenate(data_list, axis=0)
            return None
        # Unknown dataset type: report "not found" explicitly instead of
        # falling through with an implicit None.
        return None

    for column in norm_columns:
        if not self.is_key_with_embodiment(column, embodiment):
            continue
        column_name = self.keyname_to_lerobot_key(column, embodiment)
        logger.info(f"[NormStats] Processing column={column_name}")

        column_data = get_zarr_data(dataset, column_name)

        if column_data is None:
            logger.warning(f"Skipping {column_name}, data not found given dataset type")
            continue

        if column_data.ndim not in (2, 3):
            # Include the resolved zarr key so the error is traceable to
            # the on-disk column, matching the log messages above.
            raise ValueError(
                f"Column {column} ({column_name}) has shape {column_data.shape}, "
                "expected 2 or 3 dims"
            )

        # All reductions are over axis 0 (the sample axis); for 3-dim
        # columns the remaining axes are kept per-element.
        mean = np.mean(column_data, axis=0)
        std = np.std(column_data, axis=0)
        minv = np.min(column_data, axis=0)
        maxv = np.max(column_data, axis=0)
        median = np.median(column_data, axis=0)
        q1 = np.percentile(column_data, 1, axis=0)
        q99 = np.percentile(column_data, 99, axis=0)

        self.norm_stats[embodiment][column] = {
            "mean": torch.from_numpy(mean).float(),
            "std": torch.from_numpy(std).float(),
            "min": torch.from_numpy(minv).float(),
            "max": torch.from_numpy(maxv).float(),
            "median": torch.from_numpy(median).float(),
            "quantile_1": torch.from_numpy(q1).float(),
            "quantile_99": torch.from_numpy(q99).float(),
        }

    logger.info("[NormStats] Finished norm inference")

def viz_img_key(self):
"""
Get the key that should be used for offline visualization
Expand Down
17 changes: 17 additions & 0 deletions egomimic/rldb/zarr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
Zarr-based dataset implementations for EgoVerse.
"""

from egomimic.rldb.zarr.zarr_dataset_multi import (
EpisodeResolver,
MultiDataset,
ZarrDataset,
ZarrEpisode,
)

__all__ = [
"EpisodeResolver",
"MultiDataset",
"ZarrDataset",
"ZarrEpisode",
]
Loading