diff --git a/egomimic/rldb/utils.py b/egomimic/rldb/utils.py index 0d0b9518..d09cf17b 100644 --- a/egomimic/rldb/utils.py +++ b/egomimic/rldb/utils.py @@ -4,6 +4,8 @@ from pprint import pprint import random import shutil +import psutil + from datetime import datetime, timezone from enum import Enum from multiprocessing.dummy import connection @@ -262,6 +264,14 @@ def __init__( f"slow_down_ac_keys must be str, sequence, or None; got {type(raw_keys)}" ) + annotation_path = Path(root) / "annotations" + if annotation_path.is_dir(): + self.annotations = AnnotationLoader(root=root) + self.annotation_df = self.annotations.df + else: + self.annotations = None + self.annotation_df = None + if mode == "train": super().__init__( repo_id=repo_id, @@ -335,62 +345,62 @@ def __getitem__(self, idx): item[key] = self._slow_down_sequence(item[key]) ep_idx = int(item["episode_index"]) + frame_idx = ( + self.sampled_indices[idx] if self.sampled_indices is not None else idx + ) + + frame_item = self.hf_dataset[frame_idx] + frame_time = float(frame_item["timestamp"]) - frame = self.sampled_indices[idx] if self.sampled_indices is not None else idx - frame_item = self.hf_dataset[frame] + frame_item["annotations"] = self._get_frame_annotation( + episode_idx=ep_idx, + frame_time=frame_time, + ) - # Check if annotations directory exists before trying to load annotations - annotation_path = Path(self.root) / "annotations" - if not annotation_path.is_dir(): - # No annotations available, set empty and return early - frame_item["annotations"] = "" - return frame_item + return frame_item - # Load annotations only if they exist - annotations = AnnotationLoader(root=self.root) - df = annotations.df - current_ts = float(frame_item["timestamp"]) - fps = float(self.fps) - frame_duration = 1 / fps - # Use frame start time for annotation lookup to avoid missing boundary annotations - frame_time = current_ts - # print(f"Frame {frame}, ts={current_ts}, duration={frame_duration}, episode {ep_idx}") + def _get_frame_annotation( + self, + episode_idx: int, + frame_time: float, + ) -> str: + """ + Return the annotation string for a given episode index and timestamp. + Returns empty string if annotations are unavailable or no match is found. + """ + if self.annotation_df is None: + return "" - df_episode = df.loc[df["idx"].astype(int) == ep_idx] + df_episode = self.annotation_df.loc[ + self.annotation_df["idx"].astype(int) == episode_idx + ] if df_episode.empty: - logger.debug("No annotations for episode %s", ep_idx) - frame_item["annotations"] = "" - return frame_item - - # print(df_episode.head()) + return "" - frame_annotations = df_episode[ + # Active annotation + active = df_episode[ (df_episode["start_time"] <= frame_time) & (df_episode["end_time"] >= frame_time) ] - if frame_annotations.empty: - next_ann = df_episode[df_episode["start_time"] > frame_time] - if next_ann.empty: - annotation = df_episode.tail(1)["Labels"].iloc[0] - frame_item["annotations"] = annotation - return frame_item - else: - next_pos = df_episode.index.get_loc(next_ann.index[0]) - prev_pos = next_pos - 1 - if prev_pos >= 0: - annotation = df_episode.iloc[prev_pos]["Labels"] - else: - annotation = "" - frame_item["annotations"] = annotation - return frame_item - else: - annotation = frame_annotations["Labels"].iloc[0] - frame_item["annotations"] = annotation - return frame_item + if not active.empty: + return active["Labels"].iloc[0] + + # Fallback: previous annotation + future = df_episode[df_episode["start_time"] > frame_time] + if future.empty: + return df_episode.tail(1)["Labels"].iloc[0] + + next_pos = df_episode.index.get_loc(future.index[0]) + prev_pos = next_pos - 1 + if prev_pos >= 0: + return df_episode.iloc[prev_pos]["Labels"] + + return "" + def _slow_down_sequence(self, seq, rot_spec=None): """ @@ -674,13 +684,12 @@ def __init__( local_files_only=True, key_map=None, valid_ratio=0.2, - temp_root="/coc/flash7/scratch/egoverseS3Dataset", # "/coc/flash7/scratch/rldb_temp" + temp_root="/coc/flash7/scratch/egoverseS3Dataset/S3_rldb_data", # "/coc/flash7/scratch/rldb_temp" cache_root="/coc/flash7/scratch/.cache", filters={}, debug=False, **kwargs, ): - temp_root += "/S3_rldb_data" filters["robot_name"] = embodiment filters["is_deleted"] = False @@ -700,12 +709,18 @@ def __init__( temp_root = "/" + temp_root temp_root = Path(temp_root) - if temp_root.is_dir(): - logger.info(f"Using existing temp_root directory: {temp_root}") if not temp_root.is_dir(): temp_root.mkdir() + logger.info(f"Summary of S3RLDBDataset: {temp_root}") + logger.info(f"Bucket Name: {bucket_name}") logger.info(f"Filters: {filters}") + logger.info(f"Local Files Only: {local_files_only}") + logger.info(f"Percent: {percent}") + logger.info(f"Valid Ratio: {valid_ratio}") + logger.info(f"Debug: {debug}") + logger.info(f"kwargs: {kwargs}") + datasets = {} skipped = [] @@ -720,7 +735,7 @@ def __init__( valid_collection_names = set() for _, hashes in filtered_paths: valid_collection_names.add(hashes) - + max_workers = int(os.environ.get("RLDB_LOAD_WORKERS", "10")) datasets, skipped = self._load_rldb_datasets_parallel( @@ -887,12 +902,12 @@ def _submit_arg(p: Path): skipped.append(repo_id) - if reason == "not_in_filtered_paths": - logger.warning(f"Skipping {repo_id}: not in filtered S3 paths") - elif reason and reason.startswith("embodiment_mismatch"): - logger.warning(f"Skipping {repo_id}: {reason}") - else: - logger.error(f"Failed to load {repo_id} as RLDBDataset:\n{err}") + # if reason == "not_in_filtered_paths": + # logger.warning(f"Skipping {repo_id}: not in filtered S3 paths") + # elif reason and reason.startswith("embodiment_mismatch"): + # logger.warning(f"Skipping {repo_id}: {reason}") + # else: + # logger.error(f"Failed to load {repo_id} as RLDBDataset:\n{err}") return datasets, skipped @@ -1241,7 +1256,6 @@ def infer_norm_from_dataset(self, dataset): for column in norm_columns: if not self.is_key_with_embodiment(column, embodiment): continue - column_name = self.keyname_to_lerobot_key(column, embodiment) logger.info(f"[NormStats] Processing column={column_name}") @@ -1249,7 +1263,6 @@ def infer_norm_from_dataset(self, dataset): column_data = dataset.hf_dataset.with_format( "numpy", columns=[column_name] )[:][column_name] - if column_data.ndim not in (2, 3): raise ValueError( f"Column {column} has shape {column_data.shape}, "