From ad0e765af51e0d7f5f9b5406bd5b33de6d3a31da Mon Sep 17 00:00:00 2001
From: Simar Kareer
Date: Sun, 8 Mar 2026 17:50:44 -0400
Subject: [PATCH] advanced filtering

---
 README.md                                     |   6 +-
 egomimic/hydra_configs/data/aria.yaml         |   8 +-
 egomimic/hydra_configs/data/eva.yaml          |   8 +-
 .../hydra_configs/data/eva_human_cotrain.yaml |  18 ++-
 egomimic/hydra_configs/data/mecka.yaml        |   8 +-
 egomimic/hydra_configs/data/scale.yaml        |   8 +-
 egomimic/rldb/filters.py                      |  36 +++++
 egomimic/rldb/zarr/test_dataset_filter.py     | 139 ++++++++++++++++++
 egomimic/rldb/zarr/zarr_dataset_multi.py      | 134 +++++++++++------
 egomimic/scripts/data_download/sync_s3.py     |  40 ++++-
 egomimic/scripts/mecka_process/test_data.py   |  26 ++--
 11 files changed, 357 insertions(+), 74 deletions(-)
 create mode 100644 egomimic/rldb/filters.py
 create mode 100644 egomimic/rldb/zarr/test_dataset_filter.py

diff --git a/README.md b/README.md
index fede9b7c..157980fb 100644
--- a/README.md
+++ b/README.md
@@ -87,8 +87,8 @@ While our training pipeline automatically downloads data, you can manually downl
 For example, to download all our flagship Aria fold clothes data...
 ```
 python egomimic/scripts/data_download/sync_s3.py \
-    --local-dir \
-    --filters '{"embodiment":"aria", "task": "fold_clothes"}'
+    --local-dir \
+    --filters aria-fold-clothes
 ```
 
 ### Training
@@ -99,4 +99,4 @@ python egomimic/trainHydra.py --config-name=train_zarr
 For full instructions on training see [``training.md``](./training.md)
 
 ### Converting your own data
-See [``embodiment_tutorial.ipynb``](./egomimic/scripts/tutorials/embodiment_tutorial.ipynb) as reference to write a conversion script for your own data.
\ No newline at end of file
+See [``embodiment_tutorial.ipynb``](./egomimic/scripts/tutorials/embodiment_tutorial.ipynb) as a reference to write a conversion script for your own data.
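A quick note on the semantics the configs below rely on: each `filters` entry is now a `DatasetFilter` whose `filter_lambdas` are lambda source strings, `eval`'d once at construction; all predicates are AND-ed together, each must return a `bool`, and rows flagged `is_deleted` are rejected before any predicate runs. A minimal sketch of that behavior, based on `egomimic/rldb/filters.py` and its tests in this patch (illustration only, not part of the diff):

```python
from egomimic.rldb.filters import DatasetFilter

# Predicates are given as lambda source strings, exactly as in the Hydra
# YAML configs below, and are eval'd once when the filter is constructed.
filters = DatasetFilter(
    filter_lambdas=[
        "lambda row: row['robot_name'] == 'aria_bimanual'",
        "lambda row: row['task'] == 'fold_clothes'",
    ]
)

assert filters.matches({"robot_name": "aria_bimanual", "task": "fold_clothes"})
assert not filters.matches({"robot_name": "eva_bimanual", "task": "fold_clothes"})
# Deleted rows are rejected unconditionally, even when every predicate passes.
assert not filters.matches(
    {"robot_name": "aria_bimanual", "task": "fold_clothes", "is_deleted": True}
)
```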
diff --git a/egomimic/hydra_configs/data/aria.yaml b/egomimic/hydra_configs/data/aria.yaml index e60a92a4..8533f736 100644 --- a/egomimic/hydra_configs/data/aria.yaml +++ b/egomimic/hydra_configs/data/aria.yaml @@ -11,7 +11,9 @@ train_datasets: transform_list: _target_: egomimic.rldb.embodiment.human.Aria.get_transform_list filters: - episode_hash: "2025-09-20-17-47-54-000000" + _target_: egomimic.rldb.filters.DatasetFilter + filter_lambdas: + - "lambda row: row['episode_hash'] == '2025-09-20-17-47-54-000000'" mode: total valid_datasets: aria_bimanual: @@ -24,7 +26,9 @@ valid_datasets: transform_list: _target_: egomimic.rldb.embodiment.human.Aria.get_transform_list filters: - episode_hash: "2025-09-20-17-47-54-000000" + _target_: egomimic.rldb.filters.DatasetFilter + filter_lambdas: + - "lambda row: row['episode_hash'] == '2025-09-20-17-47-54-000000'" mode: total train_dataloader_params: aria_bimanual: diff --git a/egomimic/hydra_configs/data/eva.yaml b/egomimic/hydra_configs/data/eva.yaml index 9bf10a3f..6eb332f7 100644 --- a/egomimic/hydra_configs/data/eva.yaml +++ b/egomimic/hydra_configs/data/eva.yaml @@ -10,7 +10,9 @@ train_datasets: transform_list: _target_: egomimic.rldb.embodiment.eva.Eva.get_transform_list filters: - episode_hash: "2025-12-26-18-07-46-296000" + _target_: egomimic.rldb.filters.DatasetFilter + filter_lambdas: + - "lambda row: row['episode_hash'] == '2025-12-26-18-07-46-296000'" mode: total valid_datasets: @@ -24,7 +26,9 @@ valid_datasets: transform_list: _target_: egomimic.rldb.embodiment.eva.Eva.get_transform_list filters: - episode_hash: "2025-12-26-18-07-46-296000" + _target_: egomimic.rldb.filters.DatasetFilter + filter_lambdas: + - "lambda row: row['episode_hash'] == '2025-12-26-18-07-46-296000'" mode: total train_dataloader_params: diff --git a/egomimic/hydra_configs/data/eva_human_cotrain.yaml b/egomimic/hydra_configs/data/eva_human_cotrain.yaml index cabf760d..4bc2a1d7 100644 --- a/egomimic/hydra_configs/data/eva_human_cotrain.yaml +++ b/egomimic/hydra_configs/data/eva_human_cotrain.yaml @@ -10,7 +10,9 @@ train_datasets: transform_list: _target_: egomimic.rldb.embodiment.eva.Eva.get_transform_list filters: - episode_hash: "2025-12-26-18-07-46-296000" + _target_: egomimic.rldb.filters.DatasetFilter + filter_lambdas: + - "lambda row: row['episode_hash'] == '2025-12-26-18-07-46-296000'" mode: total aria_bimanual: _target_: egomimic.rldb.zarr.zarr_dataset_multi.MultiDataset._from_resolver @@ -22,7 +24,9 @@ train_datasets: transform_list: _target_: egomimic.rldb.embodiment.human.Aria.get_transform_list filters: - episode_hash: "2025-09-20-17-47-54-000000" + _target_: egomimic.rldb.filters.DatasetFilter + filter_lambdas: + - "lambda row: row['episode_hash'] == '2025-09-20-17-47-54-000000'" mode: total valid_datasets: eva_bimanual: @@ -35,7 +39,9 @@ valid_datasets: transform_list: _target_: egomimic.rldb.embodiment.eva.Eva.get_transform_list filters: - episode_hash: "2025-12-26-18-07-46-296000" + _target_: egomimic.rldb.filters.DatasetFilter + filter_lambdas: + - "lambda row: row['episode_hash'] == '2025-12-26-18-07-46-296000'" mode: total aria_bimanual: _target_: egomimic.rldb.zarr.zarr_dataset_multi.MultiDataset._from_resolver @@ -47,7 +53,9 @@ valid_datasets: transform_list: _target_: egomimic.rldb.embodiment.human.Aria.get_transform_list filters: - episode_hash: "2025-09-20-17-47-54-000000" + _target_: egomimic.rldb.filters.DatasetFilter + filter_lambdas: + - "lambda row: row['episode_hash'] == '2025-09-20-17-47-54-000000'" mode: total train_dataloader_params: 
eva_bimanual: @@ -62,4 +70,4 @@ valid_dataloader_params: num_workers: 10 aria_bimanual: batch_size: 32 - num_workers: 10 \ No newline at end of file + num_workers: 10 diff --git a/egomimic/hydra_configs/data/mecka.yaml b/egomimic/hydra_configs/data/mecka.yaml index 647d9b7c..d21ff30e 100644 --- a/egomimic/hydra_configs/data/mecka.yaml +++ b/egomimic/hydra_configs/data/mecka.yaml @@ -11,7 +11,9 @@ train_datasets: transform_list: _target_: egomimic.rldb.embodiment.human.Mecka.get_transform_list filters: - episode_hash: "69199812208123403bbdb24f" + _target_: egomimic.rldb.filters.DatasetFilter + filter_lambdas: + - "lambda row: row['episode_hash'] == '69199812208123403bbdb24f'" mode: total valid_datasets: mecka_bimanual: @@ -24,7 +26,9 @@ valid_datasets: transform_list: _target_: egomimic.rldb.embodiment.human.Mecka.get_transform_list filters: - episode_hash: "69199812208123403bbdb24f" + _target_: egomimic.rldb.filters.DatasetFilter + filter_lambdas: + - "lambda row: row['episode_hash'] == '69199812208123403bbdb24f'" mode: total train_dataloader_params: mecka_bimanual: diff --git a/egomimic/hydra_configs/data/scale.yaml b/egomimic/hydra_configs/data/scale.yaml index 44022e7b..c2752558 100644 --- a/egomimic/hydra_configs/data/scale.yaml +++ b/egomimic/hydra_configs/data/scale.yaml @@ -11,7 +11,9 @@ train_datasets: transform_list: _target_: egomimic.rldb.embodiment.human.Scale.get_transform_list filters: - episode_hash: "69199812208123403bbdb24f" + _target_: egomimic.rldb.filters.DatasetFilter + filter_lambdas: + - "lambda row: row['episode_hash'] == '69199812208123403bbdb24f'" mode: total valid_datasets: scale_bimanual: @@ -24,7 +26,9 @@ valid_datasets: transform_list: _target_: egomimic.rldb.embodiment.human.Scale.get_transform_list filters: - episode_hash: "69199812208123403bbdb24f" + _target_: egomimic.rldb.filters.DatasetFilter + filter_lambdas: + - "lambda row: row['episode_hash'] == '69199812208123403bbdb24f'" mode: total train_dataloader_params: scale_bimanual: diff --git a/egomimic/rldb/filters.py b/egomimic/rldb/filters.py new file mode 100644 index 00000000..70e73833 --- /dev/null +++ b/egomimic/rldb/filters.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import sys +from collections.abc import Mapping, Sequence +from typing import Any + + +class DatasetFilter: + def __init__(self, filter_lambdas: Sequence[str] | None = None) -> None: + self.filter_lambdas = list(filter_lambdas or []) + self.filters = [] + for expr in self.filter_lambdas: + try: + predicate = eval(expr) + except Exception as exc: + print(f"Invalid filter: {expr}", file=sys.stderr) + raise ValueError(f"Invalid filter: {expr}") from exc + if not callable(predicate): + print(f"Invalid filter: {expr}", file=sys.stderr) + raise ValueError(f"Invalid filter: {expr}") + self.filters.append(predicate) + + def __repr__(self) -> str: + return f"DatasetFilter(filter_lambdas={self.filter_lambdas!r})" + + def matches(self, row: Mapping[str, Any]) -> bool: + row = dict(row) + if row.get("is_deleted", False): + return False + for expr, predicate in zip(self.filter_lambdas, self.filters, strict=True): + result = predicate(row) + if not isinstance(result, bool): + raise TypeError(f"Filter must return bool: {expr}") + if not result: + return False + return True diff --git a/egomimic/rldb/zarr/test_dataset_filter.py b/egomimic/rldb/zarr/test_dataset_filter.py new file mode 100644 index 00000000..cb6fc7ae --- /dev/null +++ b/egomimic/rldb/zarr/test_dataset_filter.py @@ -0,0 +1,139 @@ +import pandas as pd +import pytest +import 
zarr + +from egomimic.rldb.filters import DatasetFilter +from egomimic.rldb.zarr import zarr_dataset_multi +from egomimic.scripts.data_download.sync_s3 import parse_dataset_filter_key + + +def _write_episode(root, name: str, **attrs) -> None: + group = zarr.open_group(str(root / f"{name}.zarr"), mode="w") + group.attrs.update(attrs) + + +def test_dataset_filter_matches_rows_and_excludes_deleted_by_default() -> None: + filters = DatasetFilter( + filter_lambdas=["lambda row: row['episode_hash'] == 'episode-1'"] + ) + + assert filters.matches({"episode_hash": "episode-1"}) + assert not filters.matches({"episode_hash": "episode-1", "is_deleted": True}) + assert not filters.matches({"episode_hash": "episode-2"}) + + +def test_dataset_filter_empty_list_matches_all_non_deleted_rows() -> None: + filters = DatasetFilter() + + assert filters.matches({"episode_hash": "episode-1"}) + assert not filters.matches({"episode_hash": "episode-1", "is_deleted": True}) + + +def test_dataset_filter_init_rejects_invalid_filter_and_prints_it(capsys) -> None: + with pytest.raises(ValueError, match="Invalid filter"): + DatasetFilter(filter_lambdas=["lambda row:"]) + + captured = capsys.readouterr() + assert "Invalid filter: lambda row:" in captured.err + + +def test_dataset_filter_matches_requires_bool_result() -> None: + filters = DatasetFilter(filter_lambdas=["lambda row: 1"]) + + with pytest.raises(TypeError, match="Filter must return bool"): + filters.matches({"episode_hash": "episode-1"}) + + +def test_s3_resolver_filters_dataframe_with_dataset_filter(monkeypatch) -> None: + df = pd.DataFrame( + [ + { + "episode_hash": "match", + "zarr_processed_path": "s3://rldb/processed/match/", + "task": "fold_clothes", + "robot_name": "aria_bimanual", + "is_deleted": False, + }, + { + "episode_hash": "fallback", + "zarr_processed_path": "s3://rldb/processed/fallback/", + "task": "fold_clothes", + "robot_name": None, + "embodiment": "aria_bimanual", + "is_deleted": False, + }, + { + "episode_hash": "deleted", + "zarr_processed_path": "s3://rldb/processed/deleted/", + "task": "fold_clothes", + "robot_name": "aria_bimanual", + "is_deleted": True, + }, + { + "episode_hash": "empty-path", + "zarr_processed_path": "", + "task": "fold_clothes", + "robot_name": "aria_bimanual", + "is_deleted": False, + }, + ] + ) + monkeypatch.setattr(zarr_dataset_multi, "create_default_engine", lambda: object()) + monkeypatch.setattr(zarr_dataset_multi, "episode_table_to_df", lambda engine: df) + + filters = DatasetFilter( + filter_lambdas=[ + "lambda row: row['robot_name'] == 'aria_bimanual'", + "lambda row: row['task'] == 'fold_clothes'", + ] + ) + + paths = zarr_dataset_multi.S3EpisodeResolver._get_filtered_paths(filters=filters) + + assert paths == [ + ("s3://rldb/processed/match/", "match"), + ("s3://rldb/processed/fallback/", "fallback"), + ] + + +def test_local_resolver_filters_local_metadata_with_dataset_filter(tmp_path) -> None: + _write_episode( + tmp_path, "episode_a", embodiment="aria_bimanual", task="fold_clothes" + ) + _write_episode( + tmp_path, "episode_b", robot_type="aria_bimanual", task="fold_clothes" + ) + _write_episode( + tmp_path, + "episode_c", + robot_name="aria_bimanual", + task="fold_clothes", + is_deleted=True, + ) + _write_episode( + tmp_path, "episode_d", robot_name="eva_bimanual", task="fold_clothes" + ) + + filters = DatasetFilter( + filter_lambdas=["lambda row: row['robot_name'] == 'aria_bimanual'"] + ) + + paths = zarr_dataset_multi.LocalEpisodeResolver._get_local_filtered_paths( + tmp_path, + filters=filters, + 
) + + assert [episode_hash for _, episode_hash in paths] == ["episode_a", "episode_b"] + + +def test_sync_s3_parser_accepts_named_filter_key() -> None: + filters = parse_dataset_filter_key("aria-fold-clothes") + + assert isinstance(filters, DatasetFilter) + assert filters.matches({"embodiment": "aria", "task": "fold_clothes"}) + assert not filters.matches({"embodiment": "aria_bimanual", "task": "fold_clothes"}) + + +def test_sync_s3_parser_rejects_unknown_filter_key() -> None: + with pytest.raises(ValueError, match="Available filter keys"): + parse_dataset_filter_key("does-not-exist") diff --git a/egomimic/rldb/zarr/zarr_dataset_multi.py b/egomimic/rldb/zarr/zarr_dataset_multi.py index bb7b08b7..a3d01b8c 100644 --- a/egomimic/rldb/zarr/zarr_dataset_multi.py +++ b/egomimic/rldb/zarr/zarr_dataset_multi.py @@ -27,7 +27,7 @@ import subprocess import tempfile from pathlib import Path -from typing import Iterable +from typing import Any, Iterable, Mapping import numpy as np import pandas as pd @@ -36,6 +36,7 @@ import zarr # from action_chunk_transforms import Transform +from egomimic.rldb.filters import DatasetFilter from egomimic.utils.aws.aws_data_utils import load_env from egomimic.utils.aws.aws_sql import ( create_default_engine, @@ -78,6 +79,62 @@ def split_dataset_names(dataset_names, valid_ratio=0.2, seed=SEED): return train, valid +def _ensure_dataset_filter(filters: DatasetFilter | None) -> DatasetFilter: + if filters is None: + return DatasetFilter() + if isinstance(filters, DatasetFilter): + return filters + raise TypeError( + "filters must be a DatasetFilter or None in the zarr resolver path. " + "Plain dict filters are no longer supported." + ) + + +def _is_missing_filter_value(value: object) -> bool: + if value is None: + return True + if isinstance(value, str): + return value == "" + + try: + missing = pd.isna(value) + except Exception: + return False + + return isinstance(missing, (bool, np.bool_)) and bool(missing) + + +def _first_present(*values: object) -> object | None: + for value in values: + if not _is_missing_filter_value(value): + return value + return None + + +def _normalize_filter_row( + row: Mapping[str, Any], + *, + episode_hash: str | None = None, +) -> dict[str, Any]: + normalized = dict(row) + normalized["episode_hash"] = ( + episode_hash if episode_hash is not None else normalized.get("episode_hash") + ) + + if _is_missing_filter_value(normalized.get("is_deleted")): + normalized["is_deleted"] = False + + robot_name = _first_present( + normalized.get("robot_name"), + normalized.get("robot_type"), + normalized.get("embodiment"), + ) + if robot_name is not None: + normalized["robot_name"] = robot_name + + return normalized + + class EpisodeResolver: """ Base class for episode resolution utilities. @@ -158,19 +215,18 @@ def __init__( def resolve( self, - filters: dict | None = None, + filters: DatasetFilter | None = None, ) -> dict[str, "ZarrDataset"]: """ Outputs a dict of ZarrDatasets with relevant filters. Syncs S3 paths to local_root before indexing. 
""" - filters = dict(filters) if filters is not None else {} - filters["is_deleted"] = False + filters = _ensure_dataset_filter(filters) if self.folder_path.is_dir(): logger.info(f"Using existing directory: {self.folder_path}") if not self.folder_path.is_dir(): - self.folder_path.mkdir() + self.folder_path.mkdir(parents=True, exist_ok=True) logger.info(f"Filters: {filters}") @@ -197,7 +253,7 @@ def resolve( @staticmethod def _get_filtered_paths( - filters: dict | None = None, debug: bool = False + filters: DatasetFilter | None = None, debug: bool = False ) -> list[tuple[str, str]]: """ Filters episodes from the SQL episode table according to the criteria specified in `filters` @@ -205,21 +261,25 @@ def _get_filtered_paths( have a non-null zarr_processed_path. Args: - filters (dict): Dictionary of filter key-value pairs to apply on the episode table. + filters (DatasetFilter | None): Filter object applied row-by-row to the + episode table. Returns: list[tuple[str, str]]: List of tuples, each containing (zarr_processed_path, episode_hash) for episodes passing the filter criteria. """ - filters = dict(filters) if filters is not None else {} + filters = _ensure_dataset_filter(filters) engine = create_default_engine() df = episode_table_to_df(engine) - series = pd.Series(filters) + if df.empty: + logger.info("Episode table is empty.") + return [] - output = df.loc[ - (df[list(filters)] == series).all(axis=1), - ["zarr_processed_path", "episode_hash"], - ] + mask = df.apply( + lambda row: filters.matches(_normalize_filter_row(row.to_dict())), + axis=1, + ) + output = df.loc[mask, ["zarr_processed_path", "episode_hash"]] before_len = len(output) if debug: @@ -328,7 +388,7 @@ def sync_from_filters( cls, *, bucket_name: str, - filters: dict, + filters: DatasetFilter | None = None, local_dir: Path, numworkers: int = 10, debug: bool = False, @@ -345,6 +405,7 @@ def sync_from_filters( Returns: List[(processed_path, episode_hash)] """ + filters = _ensure_dataset_filter(filters) # 1) Resolve episodes from DB filtered_paths = cls._get_filtered_paths(filters, debug=debug) @@ -384,35 +445,23 @@ def __init__( self.debug = debug @staticmethod - def _local_filters_match(metadata: dict, episode_hash: str, filters: dict) -> bool: - for key, value in filters.items(): - if key == "episode_hash": - if episode_hash != value: - return False - continue - - if key == "robot_name": - meta_value = ( - metadata.get("robot_name") - or metadata.get("robot_type") - or metadata.get("embodiment") - ) - elif key == "is_deleted": - meta_value = metadata.get("is_deleted", False) - else: - meta_value = metadata.get(key) - - if meta_value is None: - return False - if meta_value != value: - return False - - return True + def _local_filters_match( + metadata: dict, + episode_hash: str, + filters: DatasetFilter, + ) -> bool: + return filters.matches( + _normalize_filter_row(metadata, episode_hash=episode_hash) + ) @classmethod def _get_local_filtered_paths( - cls, search_path: Path, filters: dict, debug: bool = False + cls, + search_path: Path, + filters: DatasetFilter | None = None, + debug: bool = False, ): + filters = _ensure_dataset_filter(filters) if not search_path.is_dir(): logger.warning("Local path does not exist: %s", search_path) return [] @@ -444,7 +493,7 @@ def _get_local_filtered_paths( def resolve( self, sync_from_s3=False, - filters: dict | None = None, + filters: DatasetFilter | None = None, ) -> dict[str, "ZarrDataset"]: """ Outputs a dict of ZarrDatasets with relevant filters from local data. 
@@ -454,8 +503,7 @@ def resolve( "LocalEpisodeResolver does not sync from S3; ignoring sync_from_s3=True." ) - filters = dict(filters) if filters is not None else {} - filters.setdefault("is_deleted", False) + filters = _ensure_dataset_filter(filters) filtered_paths = self._get_local_filtered_paths( self.folder_path, filters, debug=self.debug @@ -560,7 +608,7 @@ def _from_resolver(cls, resolver: EpisodeResolver, **kwargs): # TODO add key_map and transform pass to children sync_from_s3 = kwargs.pop("sync_from_s3", False) - filters = kwargs.pop("filters", {}) or {} + filters = kwargs.pop("filters", None) if isinstance(resolver, LocalEpisodeResolver): resolved = resolver.resolve( diff --git a/egomimic/scripts/data_download/sync_s3.py b/egomimic/scripts/data_download/sync_s3.py index 11caf710..e89c20b0 100644 --- a/egomimic/scripts/data_download/sync_s3.py +++ b/egomimic/scripts/data_download/sync_s3.py @@ -1,27 +1,54 @@ """ Sync EgoVerse data from S3/R2 to a local directory. + +Example: + python egomimic/scripts/data_download/sync_s3.py --local-dir /tmp/egoverse \ + --filters aria-fold-clothes """ import argparse -import json import logging import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parent)) +from egomimic.rldb.filters import DatasetFilter from egomimic.rldb.zarr.zarr_dataset_multi import S3EpisodeResolver from egomimic.utils.aws.aws_data_utils import load_env logging.basicConfig(level=logging.INFO, format="%(message)s") +DEFAULT_FILTERS = { + "aria-fold-clothes": DatasetFilter( + filter_lambdas=[ + "lambda row: row.get('embodiment') == 'aria'", + "lambda row: row.get('task') == 'fold_clothes'", + ] + ), +} + + +def parse_dataset_filter_key(filter_key: str) -> DatasetFilter: + try: + return DEFAULT_FILTERS[filter_key] + except KeyError as exc: + raise ValueError( + f"Unknown filter key {filter_key!r}. " + f"Available filter keys: {sorted(DEFAULT_FILTERS)}" + ) from exc + + def main(): parser = argparse.ArgumentParser( description="Sync EgoVerse data from S3/R2 to a local directory." ) parser.add_argument( - "--local-dir", type=str, required=True, help="Local directory to sync into." + "--local-dir", + type=str, + required=True, + help="Local directory to sync into.", ) parser.add_argument( "--workers", type=int, default=128, help="s5cmd parallel workers." @@ -30,13 +57,14 @@ def main(): "--filters", type=str, required=True, - help='JSON dict of SQL filters, e.g. \'{"lab": "mecka"}\' or \'{"episode_hash": "h1"}\'.', + help=( + "Named DatasetFilter preset key. 
" + f"Available keys: {', '.join(sorted(DEFAULT_FILTERS))}" + ), ) args = parser.parse_args() - filters = json.loads(args.filters) - if not isinstance(filters, dict): - raise ValueError("--filters must be a JSON object (dict).") + filters = parse_dataset_filter_key(args.filters) load_env() S3EpisodeResolver.sync_from_filters( diff --git a/egomimic/scripts/mecka_process/test_data.py b/egomimic/scripts/mecka_process/test_data.py index 4671b9e7..46c0f64e 100644 --- a/egomimic/scripts/mecka_process/test_data.py +++ b/egomimic/scripts/mecka_process/test_data.py @@ -24,9 +24,9 @@ import zarr from tqdm import tqdm +from egomimic.rldb.filters import DatasetFilter from egomimic.rldb.zarr import LocalEpisodeResolver - # ── per-episode worker ──────────────────────────────────────────────────────── @@ -46,20 +46,23 @@ def _scan_episode(ep_path: Path, episode_hash: str) -> dict: # Skip non-numeric or 1-D-only (annotations, jpeg stores) if arr.ndim < 2 or not np.issubdtype(arr.dtype, np.number): continue - data: np.ndarray = arr[:] # read whole array once + data: np.ndarray = arr[:] # read whole array once T = data.shape[0] - flat = data.reshape(T, -1) # (T, features) + flat = data.reshape(T, -1) # (T, features) zero_mask = (flat == 0).all(axis=1) bad = np.where(zero_mask)[0].tolist() if bad: zero_rows[key] = bad - return {"episode_hash": eh, "total_frames": total_frames, - "zero_rows": zero_rows, "error": None} + return { + "episode_hash": eh, + "total_frames": total_frames, + "zero_rows": zero_rows, + "error": None, + } except Exception as e: - return {"episode_hash": eh, "total_frames": 0, - "zero_rows": {}, "error": str(e)} + return {"episode_hash": eh, "total_frames": 0, "zero_rows": {}, "error": str(e)} # ── main ────────────────────────────────────────────────────────────────────── @@ -101,7 +104,10 @@ def main() -> int: # Use the resolver to enumerate only valid, readable zarr stores. print("Resolving valid episodes...") - raw = LocalEpisodeResolver._get_local_filtered_paths(dataset_root, filters={}) + raw = LocalEpisodeResolver._get_local_filtered_paths( + dataset_root, + filters=DatasetFilter(), + ) # raw is a list of (path_str, episode_hash) if not raw: print("No valid zarr episodes found.") @@ -182,7 +188,9 @@ def main() -> int: keys_str = ", ".join(zero_rows.keys()) print(f" {eh}") print(f" keys : {keys_str}") - print(f" frames : {len(all_bad)} {all_bad[:20]}{'...' if len(all_bad) > 20 else ''}") + print( + f" frames : {len(all_bad)} {all_bad[:20]}{'...' if len(all_bad) > 20 else ''}" + ) print("=" * 60) return 1 if (results_with_zeros or scan_errors) else 0