Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 69 additions & 56 deletions egomimic/rldb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from pprint import pprint
import random
import shutil
import psutil

from datetime import datetime, timezone
from enum import Enum
from multiprocessing.dummy import connection
Expand Down Expand Up @@ -262,6 +264,14 @@ def __init__(
f"slow_down_ac_keys must be str, sequence, or None; got {type(raw_keys)}"
)

annotation_path = Path(root) / "annotations"
if annotation_path.is_dir():
self.annotations = AnnotationLoader(root=root)
self.annotation_df = self.annotations.df
else:
self.annotations = None
self.annotation_df = None

if mode == "train":
super().__init__(
repo_id=repo_id,
Expand Down Expand Up @@ -335,62 +345,62 @@ def __getitem__(self, idx):
item[key] = self._slow_down_sequence(item[key])

ep_idx = int(item["episode_index"])
frame_idx = (
self.sampled_indices[idx] if self.sampled_indices is not None else idx
)

frame_item = self.hf_dataset[frame_idx]
frame_time = float(frame_item["timestamp"])

frame = self.sampled_indices[idx] if self.sampled_indices is not None else idx
frame_item = self.hf_dataset[frame]
frame_item["annotations"] = self._get_frame_annotation(
episode_idx=ep_idx,
frame_time=frame_time,
)

# Check if annotations directory exists before trying to load annotations
annotation_path = Path(self.root) / "annotations"
if not annotation_path.is_dir():
# No annotations available, set empty and return early
frame_item["annotations"] = ""
return frame_item
return frame_item

# Load annotations only if they exist
annotations = AnnotationLoader(root=self.root)
df = annotations.df

current_ts = float(frame_item["timestamp"])
fps = float(self.fps)
frame_duration = 1 / fps

# Use frame start time for annotation lookup to avoid missing boundary annotations
frame_time = current_ts
# print(f"Frame {frame}, ts={current_ts}, duration={frame_duration}, episode {ep_idx}")
def _get_frame_annotation(
    self,
    episode_idx: int,
    frame_time: float,
) -> str:
    """
    Return the annotation label active at ``frame_time`` for ``episode_idx``.

    Resolution order:
      1. ``""`` when no annotation table is loaded (``self.annotation_df``
         is ``None``) or the episode has no annotation rows.
      2. The first annotation whose ``[start_time, end_time]`` interval
         contains ``frame_time``.
      3. Otherwise fall back to the most recent annotation that started
         before ``frame_time`` (the last row when every annotation is in
         the past); ``""`` when ``frame_time`` precedes all annotations.
    """
    table = self.annotation_df
    if table is None:
        return ""

    # Rows belonging to this episode. NOTE(review): the astype(int) suggests
    # the "idx" column may be stored as strings — confirm against the loader.
    episode_rows = table.loc[table["idx"].astype(int) == episode_idx]
    if episode_rows.empty:
        return ""

    # Case 2: an interval covering the timestamp wins outright.
    covering = episode_rows[
        (episode_rows["start_time"] <= frame_time)
        & (episode_rows["end_time"] >= frame_time)
    ]
    if not covering.empty:
        return covering["Labels"].iloc[0]

    # Case 3: no covering interval — pick the annotation immediately
    # preceding frame_time. Assumes rows are ordered by start_time within
    # an episode — TODO confirm with the annotation file format.
    upcoming = episode_rows[episode_rows["start_time"] > frame_time]
    if upcoming.empty:
        # Every annotation already started: the last row is the latest one.
        return episode_rows["Labels"].iloc[-1]

    first_upcoming_pos = episode_rows.index.get_loc(upcoming.index[0])
    if first_upcoming_pos > 0:
        return episode_rows.iloc[first_upcoming_pos - 1]["Labels"]

    # frame_time lies before the episode's first annotation.
    return ""


def _slow_down_sequence(self, seq, rot_spec=None):
"""
Expand Down Expand Up @@ -674,13 +684,12 @@ def __init__(
local_files_only=True,
key_map=None,
valid_ratio=0.2,
temp_root="/coc/flash7/scratch/egoverseS3Dataset", # "/coc/flash7/scratch/rldb_temp"
temp_root="/coc/flash7/scratch/egoverseS3Dataset/S3_rldb_data", # "/coc/flash7/scratch/rldb_temp"
cache_root="/coc/flash7/scratch/.cache",
filters={},
debug=False,
**kwargs,
):
temp_root += "/S3_rldb_data"
filters["robot_name"] = embodiment
filters["is_deleted"] = False

Expand All @@ -700,12 +709,18 @@ def __init__(
temp_root = "/" + temp_root
temp_root = Path(temp_root)

if temp_root.is_dir():
logger.info(f"Using existing temp_root directory: {temp_root}")
if not temp_root.is_dir():
temp_root.mkdir()

logger.info(f"Summary of S3RLDBDataset: {temp_root}")
logger.info(f"Bucket Name: {bucket_name}")
logger.info(f"Filters: {filters}")
logger.info(f"Local Files Only: {local_files_only}")
logger.info(f"Percent: {percent}")
logger.info(f"Valid Ratio: {valid_ratio}")
logger.info(f"Debug: {debug}")
logger.info(f"kwargs: {kwargs}")

datasets = {}
skipped = []

Expand All @@ -720,7 +735,7 @@ def __init__(
valid_collection_names = set()
for _, hashes in filtered_paths:
valid_collection_names.add(hashes)

max_workers = int(os.environ.get("RLDB_LOAD_WORKERS", "10"))

datasets, skipped = self._load_rldb_datasets_parallel(
Expand Down Expand Up @@ -887,12 +902,12 @@ def _submit_arg(p: Path):

skipped.append(repo_id)

if reason == "not_in_filtered_paths":
logger.warning(f"Skipping {repo_id}: not in filtered S3 paths")
elif reason and reason.startswith("embodiment_mismatch"):
logger.warning(f"Skipping {repo_id}: {reason}")
else:
logger.error(f"Failed to load {repo_id} as RLDBDataset:\n{err}")
# if reason == "not_in_filtered_paths":
# logger.warning(f"Skipping {repo_id}: not in filtered S3 paths")
# elif reason and reason.startswith("embodiment_mismatch"):
# logger.warning(f"Skipping {repo_id}: {reason}")
# else:
# logger.error(f"Failed to load {repo_id} as RLDBDataset:\n{err}")

return datasets, skipped

Expand Down Expand Up @@ -1241,15 +1256,13 @@ def infer_norm_from_dataset(self, dataset):
for column in norm_columns:
if not self.is_key_with_embodiment(column, embodiment):
continue

column_name = self.keyname_to_lerobot_key(column, embodiment)
logger.info(f"[NormStats] Processing column={column_name}")

# Arrow → NumPy (fast path, preserves shape)
column_data = dataset.hf_dataset.with_format(
"numpy", columns=[column_name]
)[:][column_name]

if column_data.ndim not in (2, 3):
raise ValueError(
f"Column {column} has shape {column_data.shape}, "
Expand Down