diff --git a/egomimic/rldb/zarr/zarr_dataset_multi.py b/egomimic/rldb/zarr/zarr_dataset_multi.py index bb7b08b7..73fe98dd 100644 --- a/egomimic/rldb/zarr/zarr_dataset_multi.py +++ b/egomimic/rldb/zarr/zarr_dataset_multi.py @@ -214,12 +214,15 @@ def _get_filtered_paths( filters = dict(filters) if filters is not None else {} engine = create_default_engine() df = episode_table_to_df(engine) - series = pd.Series(filters) - output = df.loc[ - (df[list(filters)] == series).all(axis=1), - ["zarr_processed_path", "episode_hash"], - ] + mask = pd.Series([True] * len(df), index=df.index) + for key, value in filters.items(): + if isinstance(value, str) or not hasattr(value, "__iter__"): + mask &= df[key] == value + else: + mask &= df[key].isin(list(value)) + + output = df.loc[mask, ["zarr_processed_path", "episode_hash"]] before_len = len(output) if debug: @@ -404,8 +407,12 @@ def _local_filters_match(metadata: dict, episode_hash: str, filters: dict) -> bo if meta_value is None: return False - if meta_value != value: - return False + if isinstance(value, str) or not hasattr(value, "__iter__"): + if meta_value != value: + return False + else: + if meta_value not in list(value): + return False return True