Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ While our training pipeline automatically downloads data, you can manually downl
For example, to download all our flagship Aria fold clothes data...
```
python egomimic/scripts/data_download/sync_s3.py \
--local-dir <local director> \
--filters '{"embodiment":"aria", "task": "fold_clothes"}'
--local-dir <local directory> \
--filters aria-fold-clothes
```

### Training
Expand All @@ -99,4 +99,4 @@ python egomimic/trainHydra.py --config-name=train_zarr
For full instructions on training see [``training.md``](./training.md)

### Converting your own data
See [``embodiment_tutorial.ipynb``](./egomimic/scripts/tutorials/embodiment_tutorial.ipynb) as reference to write a conversion script for your own data.
See [``embodiment_tutorial.ipynb``](./egomimic/scripts/tutorials/embodiment_tutorial.ipynb) as reference to write a conversion script for your own data.
8 changes: 6 additions & 2 deletions egomimic/hydra_configs/data/aria.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ train_datasets:
transform_list:
_target_: egomimic.rldb.embodiment.human.Aria.get_transform_list
filters:
episode_hash: "2025-09-20-17-47-54-000000"
_target_: egomimic.rldb.filters.DatasetFilter
filter_lambdas:
- "lambda row: row['episode_hash'] == '2025-09-20-17-47-54-000000'"
mode: total
valid_datasets:
aria_bimanual:
Expand All @@ -24,7 +26,9 @@ valid_datasets:
transform_list:
_target_: egomimic.rldb.embodiment.human.Aria.get_transform_list
filters:
episode_hash: "2025-09-20-17-47-54-000000"
_target_: egomimic.rldb.filters.DatasetFilter
filter_lambdas:
- "lambda row: row['episode_hash'] == '2025-09-20-17-47-54-000000'"
mode: total
train_dataloader_params:
aria_bimanual:
Expand Down
8 changes: 6 additions & 2 deletions egomimic/hydra_configs/data/eva.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ train_datasets:
transform_list:
_target_: egomimic.rldb.embodiment.eva.Eva.get_transform_list
filters:
episode_hash: "2025-12-26-18-07-46-296000"
_target_: egomimic.rldb.filters.DatasetFilter
filter_lambdas:
- "lambda row: row['episode_hash'] == '2025-12-26-18-07-46-296000'"
mode: total

valid_datasets:
Expand All @@ -24,7 +26,9 @@ valid_datasets:
transform_list:
_target_: egomimic.rldb.embodiment.eva.Eva.get_transform_list
filters:
episode_hash: "2025-12-26-18-07-46-296000"
_target_: egomimic.rldb.filters.DatasetFilter
filter_lambdas:
- "lambda row: row['episode_hash'] == '2025-12-26-18-07-46-296000'"
mode: total

train_dataloader_params:
Expand Down
18 changes: 13 additions & 5 deletions egomimic/hydra_configs/data/eva_human_cotrain.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ train_datasets:
transform_list:
_target_: egomimic.rldb.embodiment.eva.Eva.get_transform_list
filters:
episode_hash: "2025-12-26-18-07-46-296000"
_target_: egomimic.rldb.filters.DatasetFilter
filter_lambdas:
- "lambda row: row['episode_hash'] == '2025-12-26-18-07-46-296000'"
mode: total
aria_bimanual:
_target_: egomimic.rldb.zarr.zarr_dataset_multi.MultiDataset._from_resolver
Expand All @@ -22,7 +24,9 @@ train_datasets:
transform_list:
_target_: egomimic.rldb.embodiment.human.Aria.get_transform_list
filters:
episode_hash: "2025-09-20-17-47-54-000000"
_target_: egomimic.rldb.filters.DatasetFilter
filter_lambdas:
- "lambda row: row['episode_hash'] == '2025-09-20-17-47-54-000000'"
mode: total
valid_datasets:
eva_bimanual:
Expand All @@ -35,7 +39,9 @@ valid_datasets:
transform_list:
_target_: egomimic.rldb.embodiment.eva.Eva.get_transform_list
filters:
episode_hash: "2025-12-26-18-07-46-296000"
_target_: egomimic.rldb.filters.DatasetFilter
filter_lambdas:
- "lambda row: row['episode_hash'] == '2025-12-26-18-07-46-296000'"
mode: total
aria_bimanual:
_target_: egomimic.rldb.zarr.zarr_dataset_multi.MultiDataset._from_resolver
Expand All @@ -47,7 +53,9 @@ valid_datasets:
transform_list:
_target_: egomimic.rldb.embodiment.human.Aria.get_transform_list
filters:
episode_hash: "2025-09-20-17-47-54-000000"
_target_: egomimic.rldb.filters.DatasetFilter
filter_lambdas:
- "lambda row: row['episode_hash'] == '2025-09-20-17-47-54-000000'"
mode: total
train_dataloader_params:
eva_bimanual:
Expand All @@ -62,4 +70,4 @@ valid_dataloader_params:
num_workers: 10
aria_bimanual:
batch_size: 32
num_workers: 10
num_workers: 10
8 changes: 6 additions & 2 deletions egomimic/hydra_configs/data/mecka.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ train_datasets:
transform_list:
_target_: egomimic.rldb.embodiment.human.Mecka.get_transform_list
filters:
episode_hash: "69199812208123403bbdb24f"
_target_: egomimic.rldb.filters.DatasetFilter
filter_lambdas:
- "lambda row: row['episode_hash'] == '69199812208123403bbdb24f'"
mode: total
valid_datasets:
mecka_bimanual:
Expand All @@ -24,7 +26,9 @@ valid_datasets:
transform_list:
_target_: egomimic.rldb.embodiment.human.Mecka.get_transform_list
filters:
episode_hash: "69199812208123403bbdb24f"
_target_: egomimic.rldb.filters.DatasetFilter
filter_lambdas:
- "lambda row: row['episode_hash'] == '69199812208123403bbdb24f'"
mode: total
train_dataloader_params:
mecka_bimanual:
Expand Down
8 changes: 6 additions & 2 deletions egomimic/hydra_configs/data/scale.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ train_datasets:
transform_list:
_target_: egomimic.rldb.embodiment.human.Scale.get_transform_list
filters:
episode_hash: "69199812208123403bbdb24f"
_target_: egomimic.rldb.filters.DatasetFilter
filter_lambdas:
- "lambda row: row['episode_hash'] == '69199812208123403bbdb24f'"
mode: total
valid_datasets:
scale_bimanual:
Expand All @@ -24,7 +26,9 @@ valid_datasets:
transform_list:
_target_: egomimic.rldb.embodiment.human.Scale.get_transform_list
filters:
episode_hash: "69199812208123403bbdb24f"
_target_: egomimic.rldb.filters.DatasetFilter
filter_lambdas:
- "lambda row: row['episode_hash'] == '69199812208123403bbdb24f'"
mode: total
train_dataloader_params:
scale_bimanual:
Expand Down
36 changes: 36 additions & 0 deletions egomimic/rldb/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations

import sys
from collections.abc import Mapping, Sequence
from typing import Any


class DatasetFilter:
def __init__(self, filter_lambdas: Sequence[str] | None = None) -> None:
self.filter_lambdas = list(filter_lambdas or [])
self.filters = []
for expr in self.filter_lambdas:
try:
predicate = eval(expr)
except Exception as exc:
print(f"Invalid filter: {expr}", file=sys.stderr)
raise ValueError(f"Invalid filter: {expr}") from exc
if not callable(predicate):
print(f"Invalid filter: {expr}", file=sys.stderr)
raise ValueError(f"Invalid filter: {expr}")
self.filters.append(predicate)

def __repr__(self) -> str:
return f"DatasetFilter(filter_lambdas={self.filter_lambdas!r})"

def matches(self, row: Mapping[str, Any]) -> bool:
row = dict(row)
if row.get("is_deleted", False):
return False
for expr, predicate in zip(self.filter_lambdas, self.filters, strict=True):
result = predicate(row)
if not isinstance(result, bool):
raise TypeError(f"Filter must return bool: {expr}")
if not result:
return False
return True
139 changes: 139 additions & 0 deletions egomimic/rldb/zarr/test_dataset_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import pandas as pd
import pytest
import zarr

from egomimic.rldb.filters import DatasetFilter
from egomimic.rldb.zarr import zarr_dataset_multi
from egomimic.scripts.data_download.sync_s3 import parse_dataset_filter_key


def _write_episode(root, name: str, **attrs) -> None:
    """Create an empty ``<name>.zarr`` group under *root* carrying *attrs* as metadata."""
    store_path = root / f"{name}.zarr"
    episode_group = zarr.open_group(str(store_path), mode="w")
    episode_group.attrs.update(attrs)


def test_dataset_filter_matches_rows_and_excludes_deleted_by_default() -> None:
    """A hash predicate matches its row, and deleted rows never match."""
    dataset_filter = DatasetFilter(
        filter_lambdas=["lambda row: row['episode_hash'] == 'episode-1'"]
    )
    target_row = {"episode_hash": "episode-1"}

    assert dataset_filter.matches(target_row)
    assert not dataset_filter.matches({**target_row, "is_deleted": True})
    assert not dataset_filter.matches({"episode_hash": "episode-2"})


def test_dataset_filter_empty_list_matches_all_non_deleted_rows() -> None:
    """With no predicates configured, only the is_deleted flag screens rows out."""
    empty_filter = DatasetFilter()
    row = {"episode_hash": "episode-1"}

    assert empty_filter.matches(row)
    assert not empty_filter.matches({**row, "is_deleted": True})


def test_dataset_filter_init_rejects_invalid_filter_and_prints_it(capsys) -> None:
    """A syntactically broken lambda raises at construction and logs to stderr."""
    bad_source = "lambda row:"

    with pytest.raises(ValueError, match="Invalid filter"):
        DatasetFilter(filter_lambdas=[bad_source])

    assert f"Invalid filter: {bad_source}" in capsys.readouterr().err


def test_dataset_filter_matches_requires_bool_result() -> None:
    """A predicate returning a truthy non-bool (here ``1``) is rejected at match time."""
    truthy_filter = DatasetFilter(filter_lambdas=["lambda row: 1"])

    with pytest.raises(TypeError, match="Filter must return bool"):
        truthy_filter.matches({"episode_hash": "episode-1"})


def test_s3_resolver_filters_dataframe_with_dataset_filter(monkeypatch) -> None:
    """S3 resolver keeps matching rows, drops deleted and empty-path episodes.

    The 'fallback' row has robot_name=None but an 'embodiment' column, and is
    still expected in the output — exercising the resolver's fallback lookup.
    """

    def _row(episode_hash, path, **overrides):
        base = {
            "episode_hash": episode_hash,
            "zarr_processed_path": path,
            "task": "fold_clothes",
            "robot_name": "aria_bimanual",
            "is_deleted": False,
        }
        base.update(overrides)
        return base

    episodes = pd.DataFrame(
        [
            _row("match", "s3://rldb/processed/match/"),
            _row(
                "fallback",
                "s3://rldb/processed/fallback/",
                robot_name=None,
                embodiment="aria_bimanual",
            ),
            _row("deleted", "s3://rldb/processed/deleted/", is_deleted=True),
            _row("empty-path", ""),
        ]
    )

    # Stub out the DB layer: the resolver should only see the fixture frame.
    monkeypatch.setattr(zarr_dataset_multi, "create_default_engine", lambda: object())
    monkeypatch.setattr(
        zarr_dataset_multi, "episode_table_to_df", lambda engine: episodes
    )

    dataset_filter = DatasetFilter(
        filter_lambdas=[
            "lambda row: row['robot_name'] == 'aria_bimanual'",
            "lambda row: row['task'] == 'fold_clothes'",
        ]
    )

    resolved = zarr_dataset_multi.S3EpisodeResolver._get_filtered_paths(
        filters=dataset_filter
    )

    assert resolved == [
        ("s3://rldb/processed/match/", "match"),
        ("s3://rldb/processed/fallback/", "fallback"),
    ]


def test_local_resolver_filters_local_metadata_with_dataset_filter(tmp_path) -> None:
    """Local resolver filters on zarr attrs, with embodiment/robot_type fallbacks.

    episode_a/episode_b carry the robot identity under alternative attr names
    and must still match; episode_c is deleted; episode_d is a different robot.
    """
    episode_specs = {
        "episode_a": {"embodiment": "aria_bimanual", "task": "fold_clothes"},
        "episode_b": {"robot_type": "aria_bimanual", "task": "fold_clothes"},
        "episode_c": {
            "robot_name": "aria_bimanual",
            "task": "fold_clothes",
            "is_deleted": True,
        },
        "episode_d": {"robot_name": "eva_bimanual", "task": "fold_clothes"},
    }
    for episode_name, attrs in episode_specs.items():
        _write_episode(tmp_path, episode_name, **attrs)

    dataset_filter = DatasetFilter(
        filter_lambdas=["lambda row: row['robot_name'] == 'aria_bimanual'"]
    )

    resolved = zarr_dataset_multi.LocalEpisodeResolver._get_local_filtered_paths(
        tmp_path,
        filters=dataset_filter,
    )

    assert [episode_hash for _, episode_hash in resolved] == ["episode_a", "episode_b"]


def test_sync_s3_parser_accepts_named_filter_key() -> None:
    """A known named key resolves to a DatasetFilter doing exact-match filtering."""
    parsed = parse_dataset_filter_key("aria-fold-clothes")

    assert isinstance(parsed, DatasetFilter)
    assert parsed.matches({"embodiment": "aria", "task": "fold_clothes"})
    # Exact match only: "aria_bimanual" must not satisfy the "aria" key.
    assert not parsed.matches({"embodiment": "aria_bimanual", "task": "fold_clothes"})


def test_sync_s3_parser_rejects_unknown_filter_key() -> None:
    """An unrecognised key raises, and the message lists the available options."""
    unknown_key = "does-not-exist"

    with pytest.raises(ValueError, match="Available filter keys"):
        parse_dataset_filter_key(unknown_key)
Loading