diff --git a/egomimic/rldb/zarr/zarr_writer.py b/egomimic/rldb/zarr/zarr_writer.py
index a39cc75d..1ff435b6 100644
--- a/egomimic/rldb/zarr/zarr_writer.py
+++ b/egomimic/rldb/zarr/zarr_writer.py
@@ -328,7 +328,6 @@ def write(
         self,
         numeric_data: dict[str, np.ndarray] | None = None,
         image_data: dict[str, np.ndarray] | None = None,
-        pre_encoded_image_data: dict[str, tuple[np.ndarray, list[int]]] | None = None,
         metadata_override: dict[str, Any] | None = None,
     ) -> None:
         """
@@ -339,9 +338,6 @@ def write(
                 All arrays must have same length along axis 0.
             image_data: Dictionary of image arrays with shape (T, H, W, 3).
                 Images will be JPEG-compressed.
-            pre_encoded_image_data: Dictionary mapping key to (encoded_array, image_shape).
-                encoded_array is np.ndarray(dtype=object) of JPEG bytes.
-                image_shape is [H, W, 3]. Skips internal JPEG encoding.
             metadata_override: Optional metadata overrides to apply after building metadata.

         Raises:
@@ -350,26 +346,18 @@ def write(
         """
         numeric_data = numeric_data or {}
         image_data = image_data or {}
-        pre_encoded_image_data = pre_encoded_image_data or {}

-        if not numeric_data and not image_data and not pre_encoded_image_data:
-            raise ValueError(
-                "Must provide at least one of numeric_data, image_data, "
-                "or pre_encoded_image_data"
-            )
+        if not numeric_data and not image_data:
+            raise ValueError("Must provide at least one of numeric_data or image_data")

         # Infer total_frames from data
-        all_lengths: list[int] = []
-        for arr in numeric_data.values():
-            all_lengths.append(len(arr))
-        for arr in image_data.values():
+        all_lengths = []
+        for key, arr in {**numeric_data, **image_data}.items():
             all_lengths.append(len(arr))
-        for enc_arr, _shape in pre_encoded_image_data.values():
-            all_lengths.append(len(enc_arr))

         if len(set(all_lengths)) > 1:
             raise ValueError(
-                f"Inconsistent frame counts across arrays: {all_lengths}"
+                f"Inconsistent frame counts across arrays: {dict(zip(list(numeric_data) + list(image_data), all_lengths))}"
             )
self.total_frames = all_lengths[0] @@ -392,16 +380,10 @@ def write( for key, arr in numeric_data.items(): self._write_numeric_array(store, key, arr, padded_frames) - # Write image arrays (with internal JPEG encoding) + # Write image arrays for key, arr in image_data.items(): self._write_image_array(store, key, arr, padded_frames) - # Write pre-encoded image arrays (skip JPEG encoding) - for key, (enc_arr, img_shape) in pre_encoded_image_data.items(): - self._write_pre_encoded_image_array( - store, key, enc_arr, img_shape, padded_frames - ) - # Write language annotations if provided if self.annotations is not None: self._write_annotations(store, self.annotations) @@ -562,61 +544,6 @@ def _write_image_array( "names": ["height", "width", "channel"], } - def _write_pre_encoded_image_array( - self, - store: zarr.Group, - key: str, - encoded_arr: np.ndarray, - image_shape: list[int], - padded_frames: int, - ) -> None: - """ - Write already-JPEG-encoded image data to the Zarr store. - - Args: - store: Zarr group to write to. - key: Array key name. - encoded_arr: Object array of JPEG bytes with shape (T,). - image_shape: Original image dimensions [H, W, 3]. - padded_frames: Target frame count after padding. 
- """ - num_frames = len(encoded_arr) - - if padded_frames > num_frames: - padded = np.empty((padded_frames,), dtype=object) - padded[:num_frames] = encoded_arr - last_jpeg = encoded_arr[-1] - for i in range(num_frames, padded_frames): - padded[i] = last_jpeg - encoded_arr = padded - - chunk_shape = (1,) - - if self.enable_sharding: - shard_shape = encoded_arr.shape - store.create_array( - key, - shape=encoded_arr.shape, - chunks=chunk_shape, - shards=shard_shape, - dtype=VariableLengthBytes(), - ) - else: - store.create_array( - key, - shape=encoded_arr.shape, - chunks=chunk_shape, - dtype=VariableLengthBytes(), - ) - - store[key][:] = encoded_arr - - self._features[key] = { - "dtype": "jpeg", - "shape": image_shape, - "names": ["height", "width", "channel"], - } - def _write_annotations( self, store: zarr.Group, annotations: list[tuple[str, int, int]] ) -> None: @@ -703,7 +630,6 @@ def create_and_write( episode_path: str | Path, numeric_data: dict[str, np.ndarray] | None = None, image_data: dict[str, np.ndarray] | None = None, - pre_encoded_image_data: dict[str, tuple[np.ndarray, list[int]]] | None = None, embodiment: str = "", fps: int = 30, task: str = "", @@ -719,8 +645,6 @@ def create_and_write( episode_path: Path to episode .zarr directory. numeric_data: Dictionary of numeric arrays (state, actions, etc.). image_data: Dictionary of image arrays with shape (T, H, W, 3). - pre_encoded_image_data: Dict mapping key to (encoded_array, image_shape). - Skips internal JPEG encoding for these keys. embodiment: Robot type identifier. fps: Frames per second (default: 30). task: Task description. @@ -733,8 +657,9 @@ def create_and_write( Path to created episode. Raises: - ValueError: If no data is provided. + ValueError: If neither numeric_data nor image_data are provided. 
""" + # Create writer writer = ZarrWriter( episode_path=episode_path, embodiment=embodiment, @@ -745,10 +670,10 @@ def create_and_write( enable_sharding=enable_sharding, ) + # Write data writer.write( numeric_data=numeric_data, image_data=image_data, - pre_encoded_image_data=pre_encoded_image_data, metadata_override=metadata_override, ) diff --git a/external/scale/scripts/scale_api.py b/external/scale/scripts/scale_api.py deleted file mode 100644 index 39ee252f..00000000 --- a/external/scale/scripts/scale_api.py +++ /dev/null @@ -1,170 +0,0 @@ -""" -Scale API interactions, file downloading, and SFS data loading. -""" - -import json -import os -from concurrent.futures import ThreadPoolExecutor, as_completed -from pathlib import Path -from typing import Any, Dict, Optional -from urllib.parse import urlparse - -import requests -from requests.auth import HTTPBasicAuth -from scaleapi import ScaleClient -from scale_sensor_fusion_io.loaders import SFSLoader - -# --------------------------------------------------------------------------- -# Scale API configuration -# --------------------------------------------------------------------------- - -API_KEY = os.environ.get("SCALE_API_KEY", "") -if not API_KEY: - raise ValueError("SCALE_API_KEY environment variable must be set") - -client = ScaleClient(API_KEY) -auth = HTTPBasicAuth(API_KEY, "") - - -# --------------------------------------------------------------------------- -# Task metadata -# --------------------------------------------------------------------------- - - -def get_simple_response_dict_egocentric(task_id: str) -> Optional[Dict[str, Any]]: - """Get URLs for annotations, SFS, and video streams from a Scale task. - - Also returns task metadata like customerId for SQL registration. 
- """ - try: - task = client.get_task(task_id) - resp = task.response - - if hasattr(task, "as_dict"): - task_data = task.as_dict() - else: - task_data = task.__dict__ - - response_dict = { - "annotations_url": resp["annotations"]["url"], - "sfs_url": resp["full_recording"]["sfs_url"], - "customer_id": task_data.get("customerId", ""), - "project": task_data.get("project", ""), - "batch_id": task_data.get("batchId", ""), - } - - for video in resp["full_recording"]["video_urls"]: - if video["sensor_id"] == "left": - response_dict["left_rectified"] = video["rgb_url"] - else: - response_dict["right_rectified"] = video["rgb_url"] - - return response_dict - - except Exception as e: - print(f"Error retrieving task {task_id}: {e}") - return None - - -# --------------------------------------------------------------------------- -# File download -# --------------------------------------------------------------------------- - - -def download_file_in_chunks(url: str, output_path: str, chunk_size: int = 8192) -> str: - """Download a file in streaming chunks.""" - response = requests.get(url, stream=True) - response.raise_for_status() - - with open(output_path, "wb") as f: - for chunk in response.iter_content(chunk_size=chunk_size): - f.write(chunk) - - return output_path - - -def download_from_simple_response_dict( - task_output_path: str, - simple_response_dict: Dict[str, str], - verbose: bool = False, -) -> Dict[str, str]: - """Download all files from a response dictionary concurrently. 
Returns local paths.""" - local_path_dict = {} - to_download: list[tuple[str, str, str]] = [] - - url_keys = {"annotations_url", "sfs_url", "left_rectified", "right_rectified"} - - for key, url in simple_response_dict.items(): - if key not in url_keys: - continue - - parsed = urlparse(url) - file_extension = Path(parsed.path).suffix - key_cleaned = key.replace("_url", "") - local_file_path = os.path.join(task_output_path, key_cleaned + file_extension) - local_path_dict[key_cleaned] = local_file_path - - if os.path.exists(local_file_path): - continue - - if verbose: - print(f"Queued download: {key_cleaned}") - to_download.append((url, local_file_path, key_cleaned)) - - if to_download: - with ThreadPoolExecutor(max_workers=len(to_download)) as pool: - futures = { - pool.submit(download_file_in_chunks, url, path): name - for url, path, name in to_download - } - for future in as_completed(futures): - name = futures[future] - try: - future.result() - except Exception as e: - print(f"Error downloading {name}: {e}") - - return local_path_dict - - -# --------------------------------------------------------------------------- -# SFS / annotation file loading -# --------------------------------------------------------------------------- - - -def load_scene(file_path: str) -> Optional[Dict[str, Any]]: - """Load an SFS file.""" - if not os.path.exists(file_path): - return None - - try: - loader = SFSLoader(file_path) - return loader.load_unsafe() - except Exception: - return None - - -def load_annotation_file(file_path: str) -> Optional[Dict[str, Any]]: - """Load an annotation JSON file.""" - try: - with open(file_path, "r") as f: - data = f.read().rstrip("\x00") - return json.loads(data) - except Exception: - return None - - -def get_posepath(sfs_data: Dict[str, Any], sensor_id: str) -> Optional[Dict[str, Any]]: - """Get pose path for a sensor.""" - for sensor in sfs_data.get("sensors", []): - if sensor.get("id") == sensor_id: - return sensor.get("poses") - return None - - 
-def get_intrinsics(sfs_data: Dict[str, Any], sensor_id: str) -> Optional[Dict[str, float]]:
-    """Get camera intrinsics for a sensor."""
-    for sensor in sfs_data.get("sensors", []):
-        if sensor.get("id") == sensor_id:
-            return sensor.get("intrinsics")
-    return None
diff --git a/external/scale/scripts/sfsEgoverseUtils.py b/external/scale/scripts/sfsEgoverseUtils.py
new file mode 100644
index 00000000..e3018f26
--- /dev/null
+++ b/external/scale/scripts/sfsEgoverseUtils.py
@@ -0,0 +1,316 @@
+"""
+SFS to Egoverse Utilities
+
+Scale API interactions, file downloading, and SFS data loading.
+"""
+
+import json
+import os
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Dict, Any, Optional
+from urllib.parse import urlparse
+
+import boto3
+import numpy as np
+import requests
+import scipy.interpolate
+from requests.auth import HTTPBasicAuth
+from scaleapi import ScaleClient
+from scale_sensor_fusion_io.loaders import SFSLoader
+from sqlalchemy import MetaData, Table, create_engine, insert, inspect
+from sqlalchemy.exc import IntegrityError
+
+
+# Scale API Configuration
+API_KEY = os.environ.get("SCALE_API_KEY", "")
+if not API_KEY:
+    raise ValueError("SCALE_API_KEY environment variable must be set")
+
+client = ScaleClient(API_KEY)
+auth = HTTPBasicAuth(API_KEY, '')
+
+
+@dataclass
+class TableRow:
+    episode_hash: str
+    operator: str
+    lab: str
+    task: str
+    embodiment: str
+    robot_name: str
+    num_frames: int = -1  # Updateable
+    task_description: str = ""
+    scene: str = ""
+    objects: str = ""
+    processed_path: str = ""  # Updateable
+    processing_error: str = ""  # Updateable
+    mp4_path: str = ""  # Updateable
+    is_deleted: bool = False
+    is_eval: bool = False
+    eval_score: float = -1
+    eval_success: bool = True
+
+
+def interpolate_arr(v, seq_length):
+    """
+    v: (B, T, D)
+    seq_length: int
+    """
+    assert len(v.shape) == 3
+    if v.shape[1] == seq_length:
+        return v  # already at target length; was a bare `return` (None), inconsistent with the ndarray path below
+
+    interpolated = []
+    for i in range(v.shape[0]):
+        index = v[i]
+
interp = scipy.interpolate.interp1d( + np.linspace(0, 1, index.shape[0]), index, axis=0 + ) + interpolated.append(interp(np.linspace(0, 1, seq_length))) + + return np.array(interpolated) + + +def interpolate_arr_euler(v: np.ndarray, seq_length: int) -> np.ndarray: + """ + Interpolate 6DoF poses (translation + Euler angles in radians), + optionally with a 7th gripper dimension, along the time axis. + + v: (B, T, 6) or (B, T, 7) + [x, y, z, yaw, pitch, roll, (optional) gripper] + """ + assert ( + v.ndim == 3 and v.shape[2] in (6, 7) + ), "Input v must be of shape (B, T, 6) or (B, T, 7)" + B, T, D = v.shape + + new_time = np.linspace(0, 1, seq_length) + old_time = np.linspace(0, 1, T) + + outputs = [] + + for i in range(B): + seq = v[i] # (T, D) + + if np.any(seq >= 1e8): + outputs.append(np.full((seq_length, D), 1e9)) + continue + + trans_seq = seq[:, :3] # x, y, z + rot_seq = seq[:, 3:6] # yaw, pitch, roll + + # Avoid discontinuities in angle interpolation + rot_seq_unwrapped = np.unwrap(rot_seq, axis=0) + + trans_interp_func = scipy.interpolate.interp1d( + old_time, trans_seq, axis=0, kind="linear" + ) + rot_interp_func = scipy.interpolate.interp1d( + old_time, rot_seq_unwrapped, axis=0, kind="linear" + ) + + trans_interp = trans_interp_func(new_time) # (seq_length, 3) + rot_interp = rot_interp_func(new_time) # (seq_length, 3) + + # Wrap back to [-pi, pi) + rot_interp = (rot_interp + np.pi) % (2 * np.pi) - np.pi + + if D == 6: + out_seq = np.concatenate([trans_interp, rot_interp], axis=-1) + else: + grip_seq = seq[:, 6:7] # (T, 1) + grip_interp_func = scipy.interpolate.interp1d( + old_time, grip_seq, axis=0, kind="linear" + ) + grip_interp = grip_interp_func(new_time) # (seq_length, 1) + out_seq = np.concatenate( + [trans_interp, rot_interp, grip_interp], axis=-1 + ) + + outputs.append(out_seq) + + return np.stack(outputs, axis=0) # (B, seq_length, D) + + + +def create_default_engine(): + # Try to get credentials from Secrets Manager if SECRETS_ARN is set + 
+    SECRETS_ARN = os.environ.get("SECRETS_ARN")
+    if SECRETS_ARN:
+        secrets = boto3.client("secretsmanager")
+        sec = secrets.get_secret_value(SecretId=SECRETS_ARN)["SecretString"]
+        cfg = json.loads(sec)
+        HOST = cfg.get("host", cfg.get("HOST"))
+        DBNAME = cfg.get("dbname", cfg.get("DBNAME", "appdb"))
+        USER = cfg.get("username", cfg.get("user", cfg.get("USER")))
+        PASSWORD = cfg.get("password", cfg.get("PASSWORD"))
+        PORT = cfg.get("port", 5432)
+    else:
+        # Fallback defaults for local testing; override via APPDB_* env vars
+        HOST = os.environ.get("APPDB_HOST", "lowuse-pg-east2.claua8sacyu5.us-east-2.rds.amazonaws.com")
+        DBNAME = os.environ.get("APPDB_NAME", "appdb")
+        USER = os.environ.get("APPDB_USER", "appuser")
+        PASSWORD = os.environ.get("APPDB_PASSWORD", "APPUSER_STRONG_PW")  # FIXME: rotate this credential; avoid hardcoded default in source
+        PORT = int(os.environ.get("APPDB_PORT", "5432"))
+
+    # --- 1) connect via SQLAlchemy ---
+    engine = create_engine(
+        f"postgresql+psycopg://{USER}:{PASSWORD}@{HOST}:{PORT}/{DBNAME}?sslmode=require",
+        pool_pre_ping=True,
+    )
+
+    # --- 2) list tables in the schema 'app' ---
+    insp = inspect(engine)
+    print("Tables in schema 'app':", insp.get_table_names(schema="app"))
+
+    return engine
+
+
+def add_episode(engine, episode) -> bool:
+    """
+    Insert one row into app.episodes.
+    Raises RuntimeError (wrapping sqlalchemy IntegrityError) on unique/PK constraint violations.
+    """
+    episodes_tbl = _episodes_table(engine)
+    row = asdict(episode)
+
+    try:
+        with engine.begin() as conn:
+            conn.execute(insert(episodes_tbl).values(**row))
+        return True
+    except IntegrityError as e:
+        # Duplicate (or other constraint) → surface a clear error
+        raise RuntimeError(f"Insert failed (likely duplicate episode_hash): {e}") from e
+
+
+def get_simple_response_dict_egocentric(task_id: str) -> Optional[Dict[str, Any]]:
+    """Get URLs for annotations, SFS, and video streams from a Scale task.
+
+    Also returns task metadata like customerId for SQL registration.
+ """ + try: + task = client.get_task(task_id) + resp = task.response + + # Get task dict for metadata + if hasattr(task, 'as_dict'): + task_data = task.as_dict() + else: + task_data = task.__dict__ + + response_dict = { + "annotations_url": resp["annotations"]["url"], + "sfs_url": resp["full_recording"]["sfs_url"], + # Task metadata for SQL + "customer_id": task_data.get("customerId", ""), + "project": task_data.get("project", ""), + "batch_id": task_data.get("batchId", ""), + } + + for video in resp["full_recording"]["video_urls"]: + if video["sensor_id"] == "left": + response_dict["left_rectified"] = video["rgb_url"] + else: + response_dict["right_rectified"] = video["rgb_url"] + + return response_dict + + except Exception as e: + print(f"Error retrieving task {task_id}: {e}") + return None + + +def _episodes_table(engine): + md = MetaData() + return Table("episodes", md, autoload_with=engine, schema="app") + + +def download_file_in_chunks(url: str, output_path: str, chunk_size: int = 8192) -> str: + """Download a file in chunks.""" + response = requests.get(url, stream=True) + response.raise_for_status() + + with open(output_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=chunk_size): + f.write(chunk) + + return output_path + + +def download_from_simple_response_dict( + task_output_path: str, + simple_response_dict: Dict[str, str], + verbose: bool = False +) -> Dict[str, str]: + """Download all files from a response dictionary. Returns local paths. + + Only processes keys ending with '_url' or 'rectified' (actual download URLs). + Skips metadata fields like customer_id, project, batch_id. 
+ """ + local_path_dict = {} + + # Keys that are actual URLs to download + url_keys = {'annotations_url', 'sfs_url', 'left_rectified', 'right_rectified'} + + for key, url in simple_response_dict.items(): + # Skip non-URL metadata fields + if key not in url_keys: + continue + + parsed = urlparse(url) + file_extension = Path(parsed.path).suffix + key_cleaned = key.replace('_url', '') + local_file_path = os.path.join(task_output_path, key_cleaned + file_extension) + local_path_dict[key_cleaned] = local_file_path + + if os.path.exists(local_file_path): + continue + + if verbose: + print(f"Downloading: {key_cleaned}") + try: + download_file_in_chunks(url, local_file_path) + except Exception as e: + print(f"Error downloading {key}: {e}") + + return local_path_dict + + +def load_scene(file_path: str) -> Optional[Dict[str, Any]]: + """Load an SFS file.""" + if not os.path.exists(file_path): + return None + + try: + loader = SFSLoader(file_path) + return loader.load_unsafe() + except Exception: + return None + + +def load_annotation_file(file_path: str) -> Optional[Dict[str, Any]]: + """Load an annotation JSON file.""" + try: + with open(file_path, 'r') as f: + data = f.read().rstrip('\x00') + return json.loads(data) + except Exception: + return None + + +def get_posepath(sfs_data: Dict[str, Any], sensor_id: str) -> Optional[Dict[str, Any]]: + """Get pose path for a sensor.""" + for sensor in sfs_data.get("sensors", []): + if sensor.get("id") == sensor_id: + return sensor.get("poses") + return None + + +def get_intrinsics(sfs_data: Dict[str, Any], sensor_id: str) -> Optional[Dict[str, float]]: + """Get camera intrinsics for a sensor.""" + for sensor in sfs_data.get("sensors", []): + if sensor.get("id") == sensor_id: + return sensor.get("intrinsics") + return None diff --git a/external/scale/scripts/sfs_data.py b/external/scale/scripts/sfs_data.py deleted file mode 100644 index b1d88c7b..00000000 --- a/external/scale/scripts/sfs_data.py +++ /dev/null @@ -1,321 +0,0 @@ -""" 
-SFS data structures, extraction, and hand pose geometry. - -Provides: - - Data classes: HandKeypoints, CameraPose, FrameData - - SFSDataExtractor: parses SFS + annotation files into per-frame metadata - - Hand pose computation: palm 6DoF, wrist 6DoF, batch euler-to-quat -""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Any - -import numpy as np -from scipy.spatial.transform import Rotation as R - -from scale_api import ( - get_intrinsics, - get_posepath, - load_annotation_file, - load_scene, -) - -# --------------------------------------------------------------------------- -# Constants -# --------------------------------------------------------------------------- - -MANO_LABELS = [ - "hand_wrist", - "hand_thumb1", "hand_thumb2", "hand_thumb3", "hand_thumb4", - "hand_index1", "hand_index2", "hand_index3", "hand_index4", - "hand_middle1", "hand_middle2", "hand_middle3", "hand_middle4", - "hand_ring1", "hand_ring2", "hand_ring3", "hand_ring4", - "hand_pinky1", "hand_pinky2", "hand_pinky3", "hand_pinky4", -] -PALM_INDICES = [0, 5, 9, 13, 17] -NUM_KEYPOINTS = 21 - -INVALID_VALUE = 1e9 - -# --------------------------------------------------------------------------- -# Data classes -# --------------------------------------------------------------------------- - - -@dataclass -class HandKeypoints: - left: np.ndarray | None = None - right: np.ndarray | None = None - - -@dataclass -class CameraPose: - position: np.ndarray - quaternion: np.ndarray - rotation_matrix: np.ndarray - - @classmethod - def from_pose_array(cls, pose: list[float]) -> CameraPose: - position = np.array(pose[:3], dtype=np.float64) - quaternion = np.array(pose[3:7], dtype=np.float64) - rotation = R.from_quat(quaternion).as_matrix() - return cls(position=position, quaternion=quaternion, rotation_matrix=rotation) - - def get_transform_matrix(self) -> np.ndarray: - t = np.eye(4, dtype=np.float64) - t[:3, :3] = self.rotation_matrix - t[:3, 3] = 
self.position - return t - - -@dataclass -class FrameData: - frame_index: int - timestamp_us: int - camera_pose: CameraPose - hand_keypoints: HandKeypoints - text_annotations: list[dict[str, Any]] = field(default_factory=list) - subgoal: dict[str, Any] | None = None - collector_issue: dict[str, Any] | None = None - hand_tracking_error: dict[str, Any] | None = None - - -# --------------------------------------------------------------------------- -# SFS data extraction -# --------------------------------------------------------------------------- - - -class SFSDataExtractor: - """Extracts per-frame metadata from SFS + annotation files.""" - - def __init__(self, sfs_path: str, annotation_path: str, video_path: str): - self.video_path = video_path - self.sfs_data = load_scene(sfs_path) - self.annotation_data = load_annotation_file(annotation_path) - - if self.sfs_data is None or self.annotation_data is None: - raise ValueError("Failed to load SFS or annotation data") - - self.camera_sensor_id = "left_rectified" - self.posepath = get_posepath(self.sfs_data, self.camera_sensor_id) - if self.posepath is None: - raise ValueError(f"Missing pose data for {self.camera_sensor_id}") - - self.timestamps = self.posepath.get("timestamps", []) - self.pose_values = self.posepath.get("values", []) - - self._build_keypoint_lookup() - self._build_annotation_lookup() - - def _build_keypoint_lookup(self) -> None: - self.keypoint_paths: dict[str, dict[int, dict[int, Any]]] = {"left": {}, "right": {}} - for annotation in self.annotation_data.get("annotations", []): - if annotation.get("type") != "points": - continue - labels = annotation.get("labels", []) - paths = annotation.get("paths", []) - for i, label in enumerate(labels): - if i >= len(paths): - continue - hand_type = "left" if label.startswith("left_") else "right" if label.startswith("right_") else None - if hand_type is None: - continue - prefix_len = 5 if hand_type == "left" else 6 - keypoint_name = label[prefix_len:] - kp_idx 
= next((idx for idx, v in enumerate(MANO_LABELS) if v == keypoint_name), None) - if kp_idx is None: - continue - path = paths[i] - values = path.get("values", []) - for ts_idx, ts in enumerate(path.get("timestamps", [])): - self.keypoint_paths[hand_type].setdefault(ts, {}) - if ts_idx < len(values): - self.keypoint_paths[hand_type][ts][kp_idx] = values[ts_idx] - - def _build_annotation_lookup(self) -> None: - self.text_annotations: list[dict] = [] - self.subgoal_annotations: list[dict] = [] - self.collector_issues: list[dict] = [] - self.hand_tracking_errors: list[dict] = [] - self.demonstration_metadata: dict[str, Any] = {} - - for attr in self.annotation_data.get("attributes", []): - values = attr.get("values", []) - if values: - self.demonstration_metadata[attr.get("name", "")] = values[0] - - for annotation in self.annotation_data.get("annotations", []): - if annotation.get("type") != "text_annotation": - continue - label = annotation.get("label", "") - for clip in annotation.get("clips", []): - start_ts = clip.get("timestamp", 0) - end_ts = start_ts + clip.get("duration", 0) - text = clip.get("text", "") - attr_dict = {} - for attr in clip.get("attributes", []): - vals = attr.get("values", []) - if vals: - attr_dict[attr.get("name", "")] = vals[0] - - if label == "Sub-goal": - self.subgoal_annotations.append( - {"start_ts": start_ts, "end_ts": end_ts, "text": text} - ) - elif label == "Collector Issue": - self.collector_issues.append( - {"start_ts": start_ts, "end_ts": end_ts, - "issue_type": attr_dict.get("Collector Quality Issue", "")} - ) - elif label == "Hand Tracking Error": - error_type = attr_dict.get("Hand Tracking Error", text) - hand = attr_dict.get("Hand", "Both") - self.hand_tracking_errors.append( - {"start_ts": start_ts, "end_ts": end_ts, - "error_type": error_type, "hand": hand} - ) - # Promote clip-level "Hand Used" to demonstration metadata - if "Hand Used" in attr_dict and "Hand Used" not in self.demonstration_metadata: - 
self.demonstration_metadata["Hand Used"] = attr_dict["Hand Used"] - - self.text_annotations.append( - {"label": label, "text": text, "start_ts": start_ts, - "end_ts": end_ts, "attributes": attr_dict} - ) - - def get_hand_keypoints_at_timestamp(self, timestamp: int) -> HandKeypoints: - result = HandKeypoints() - for hand_type in ("left", "right"): - if timestamp not in self.keypoint_paths[hand_type]: - continue - kp_dict = self.keypoint_paths[hand_type][timestamp] - if len(kp_dict) < NUM_KEYPOINTS // 2: - continue - keypoints = np.full((NUM_KEYPOINTS, 3), INVALID_VALUE, dtype=np.float32) - for kp_idx, xyz in kp_dict.items(): - keypoints[kp_idx] = xyz - if hand_type == "left": - result.left = keypoints - else: - result.right = keypoints - return result - - def get_subgoal_at_timestamp(self, timestamp: int) -> dict[str, Any] | None: - for item in self.subgoal_annotations: - if item["start_ts"] <= timestamp <= item["end_ts"]: - return item - return None - - def get_collector_issue_at_timestamp(self, timestamp: int) -> dict[str, Any] | None: - for item in self.collector_issues: - if item["start_ts"] <= timestamp <= item["end_ts"]: - return item - return None - - def get_hand_tracking_error_at_timestamp(self, timestamp: int) -> dict[str, Any] | None: - for item in self.hand_tracking_errors: - if item["start_ts"] <= timestamp <= item["end_ts"]: - return item - return None - - def get_text_annotations_at_timestamp(self, timestamp: int) -> list[dict[str, Any]]: - return [ - ann for ann in self.text_annotations - if ann["start_ts"] <= timestamp <= ann["end_ts"] - ] - - def extract_all_frames_metadata(self) -> list[FrameData]: - frames = [] - for i, ts in enumerate(self.timestamps): - pose = self.pose_values[i] - frames.append( - FrameData( - frame_index=i, - timestamp_us=ts, - camera_pose=CameraPose.from_pose_array(pose), - hand_keypoints=self.get_hand_keypoints_at_timestamp(ts), - text_annotations=self.get_text_annotations_at_timestamp(ts), - 
subgoal=self.get_subgoal_at_timestamp(ts), - collector_issue=self.get_collector_issue_at_timestamp(ts), - hand_tracking_error=self.get_hand_tracking_error_at_timestamp(ts), - ) - ) - return frames - - -# --------------------------------------------------------------------------- -# Hand pose geometry -# --------------------------------------------------------------------------- - - -def _batch_euler_to_quat(euler_zyx: np.ndarray) -> np.ndarray: - """(N, 3) euler ZYX -> (N, 4) quaternion wxyz.""" - q_xyzw = R.from_euler("ZYX", euler_zyx, degrees=False).as_quat() - return q_xyzw[..., [3, 0, 1, 2]].astype(np.float32) - - -def batch_pose6_to_pose7(pose6: np.ndarray) -> np.ndarray: - """(N, 6) [xyz ypr] -> (N, 7) [xyz quat_wxyz]. Invalid sentinels -> zeros.""" - N = pose6.shape[0] - out = np.zeros((N, 7), dtype=np.float32) - valid = ~np.any(pose6 >= INVALID_VALUE - 1, axis=1) - if valid.any(): - out[valid, :3] = pose6[valid, :3] - out[valid, 3:] = _batch_euler_to_quat(pose6[valid, 3:6]) - return out - - -def _compute_palm_centroid(keypoints: np.ndarray) -> np.ndarray: - palm_kps = keypoints[PALM_INDICES] - valid_mask = ~np.any(palm_kps >= INVALID_VALUE - 1, axis=1) - if not np.any(valid_mask): - return np.full(3, INVALID_VALUE, dtype=np.float32) - return np.mean(palm_kps[valid_mask], axis=0).astype(np.float32) - - -def _compute_hand_orientation(keypoints: np.ndarray, flip_x: bool = False) -> np.ndarray: - """Hand frame: x=right, y=down (palm normal toward ground), z=forward (toward fingers). - - flip_x=True for the right hand so that x is rightward for both hands. - Shared by both palm and wrist pose computation. 
- """ - wrist, index1, middle1, pinky1 = keypoints[0], keypoints[5], keypoints[9], keypoints[17] - if any(np.any(kp >= INVALID_VALUE - 1) for kp in (wrist, index1, middle1, pinky1)): - return np.zeros(3, dtype=np.float32) - z_axis = middle1 - wrist - z_axis /= np.linalg.norm(z_axis) + 1e-8 - across = (pinky1 - index1) if flip_x else (index1 - pinky1) - across -= np.dot(across, z_axis) * z_axis - x_axis = across / (np.linalg.norm(across) + 1e-8) - y_axis = np.cross(z_axis, x_axis) - y_axis /= np.linalg.norm(y_axis) + 1e-8 - rot = np.column_stack([x_axis, y_axis, z_axis]) - try: - return R.from_matrix(rot).as_euler("ZYX", degrees=False).astype(np.float32) - except Exception: - return np.zeros(3, dtype=np.float32) - - -def compute_palm_6dof(keypoints: np.ndarray, flip_x: bool = False) -> np.ndarray: - centroid = _compute_palm_centroid(keypoints) - if np.any(centroid >= INVALID_VALUE - 1): - return np.full(6, INVALID_VALUE, dtype=np.float32) - ypr = _compute_hand_orientation(keypoints, flip_x=flip_x) - return np.concatenate([centroid, ypr]).astype(np.float32) - - -def _compute_wrist_position(keypoints: np.ndarray) -> np.ndarray: - wrist = keypoints[0] - if np.any(wrist >= INVALID_VALUE - 1): - return np.full(3, INVALID_VALUE, dtype=np.float32) - return wrist.astype(np.float32) - - -def compute_wrist_6dof(keypoints: np.ndarray, flip_x: bool = False) -> np.ndarray: - wrist_xyz = _compute_wrist_position(keypoints) - if np.any(wrist_xyz >= INVALID_VALUE - 1): - return np.full(6, INVALID_VALUE, dtype=np.float32) - wrist_ypr = _compute_hand_orientation(keypoints, flip_x=flip_x) - return np.concatenate([wrist_xyz, wrist_ypr]).astype(np.float32) diff --git a/external/scale/scripts/sfs_to_egoverse_zarr.py b/external/scale/scripts/sfs_to_egoverse_zarr.py index a37bfed5..057819f1 100644 --- a/external/scale/scripts/sfs_to_egoverse_zarr.py +++ b/external/scale/scripts/sfs_to_egoverse_zarr.py @@ -12,18 +12,6 @@ obs_head_pose (T, 7) xyz + quat(w, x, y, z) images.front_1 (T, H, W, 3) 
JPEG-compressed by ZarrWriter -Processing pipeline: - 1. Nullify: Hand Tracking Error frames have their affected-hand keypoints - set to None, turning them into gaps. - 2. Interpolation: Short gaps (configurable, default <=15 frames / 0.5s) - in hand keypoints are filled via Akima spline interpolation with - velocity-clamped sanity checking. - 3. Filtering / zero-fill: - - Tracking-error frames still missing after interpolation → dropped - - Missing keypoints without tracking error (single-hand) → zero-filled, kept - - Inactive Time collector issues → dropped - Only contiguous runs of valid frames are kept as sub-episodes. - Usage: python sfs_to_egoverse_zarr.py --task-ids TASK1 TASK2 --output-dir ./zarr_out """ @@ -33,447 +21,351 @@ import argparse import os import shutil -import subprocess import time import traceback -from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed +from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path from typing import Any import cv2 import numpy as np -import simplejpeg -from decord import VideoReader, cpu as decord_cpu -from scipy.interpolate import Akima1DInterpolator +from scipy.interpolate import interp1d from scipy.spatial.transform import Rotation as R from egomimic.rldb.zarr.zarr_writer import ZarrWriter -from scale_api import ( +from sfsEgoverseUtils import ( download_from_simple_response_dict, get_intrinsics, + get_posepath, get_simple_response_dict_egocentric, + load_annotation_file, load_scene, ) -from sfs_data import ( - INVALID_VALUE, - FrameData, - SFSDataExtractor, - batch_pose6_to_pose7, - compute_palm_6dof, - compute_wrist_6dof, -) -SUB_EPISODE_LENGTH = 300 -MIN_EPISODE_FRAMES = 10 -IMAGE_SIZE = (640, 480) # (W, H) for cv2.resize -HAND_USED_TO_EMBODIMENT = { - "Right": "scale_right_arm", - "Left": "scale_left_arm", - "Both": "scale_bimanual", -} +MANO_LABELS = [ + "hand_wrist", + "hand_thumb1", "hand_thumb2", "hand_thumb3", "hand_thumb4", 
+ "hand_index1", "hand_index2", "hand_index3", "hand_index4", + "hand_middle1", "hand_middle2", "hand_middle3", "hand_middle4", + "hand_ring1", "hand_ring2", "hand_ring3", "hand_ring4", + "hand_pinky1", "hand_pinky2", "hand_pinky3", "hand_pinky4", +] +PALM_INDICES = [0, 5, 9, 13, 17] +NUM_KEYPOINTS = 21 - -# --------------------------------------------------------------------------- -# Video / image helpers -# --------------------------------------------------------------------------- - - -def _get_video_frame_count(video_path: str) -> int: - """Get frame count without decoding the video.""" - vr = VideoReader(video_path, ctx=decord_cpu()) - return len(vr) - - -def _decode_selected_frames( - video_path: str, - indices: list[int], - chunk_size: int = 500, - resize: tuple[int, int] | None = IMAGE_SIZE, -) -> dict[int, np.ndarray]: - """Batch-decode only the requested frame indices via decord. - - Decodes in chunks and eagerly resizes to *resize* (W, H) to keep - memory usage bounded. Returns a dict mapping frame index to RGB uint8. - """ - if not indices: - return {} - indices_sorted = sorted(set(indices)) - vr = VideoReader(video_path, ctx=decord_cpu()) - max_idx = len(vr) - 1 - valid = [i for i in indices_sorted if i <= max_idx] - if not valid: - return {} - result: dict[int, np.ndarray] = {} - for start in range(0, len(valid), chunk_size): - chunk_indices = valid[start : start + chunk_size] - batch = vr.get_batch(chunk_indices).asnumpy() - for i, t in enumerate(chunk_indices): - frame = batch[i] - if resize is not None: - frame = cv2.resize(frame, resize, interpolation=cv2.INTER_LINEAR) - result[t] = frame - del batch - return result - - -def _resize_and_encode(frame: np.ndarray) -> tuple[tuple[int, ...], bytes]: - """Resize frame to IMAGE_SIZE and JPEG-encode it. 
GIL-releasing.""" - resized = cv2.resize(frame, IMAGE_SIZE, interpolation=cv2.INTER_LINEAR) - jpeg = simplejpeg.encode_jpeg( - resized, quality=ZarrWriter.JPEG_QUALITY, colorspace="RGB" - ) - return resized.shape, jpeg +INVALID_VALUE = 1e9 +ACTION_WINDOW = 30 +SUB_EPISODE_LENGTH = 300 +IMAGE_SIZE = (640, 480) # (W, H) for cv2.resize -# --------------------------------------------------------------------------- -# Preview MP4 -# --------------------------------------------------------------------------- +def _batch_euler_to_quat(euler_zyx: np.ndarray) -> np.ndarray: + """(N, 3) euler ZYX -> (N, 4) quaternion wxyz.""" + q_xyzw = R.from_euler("ZYX", euler_zyx, degrees=False).as_quat() # scipy: xyzw + return q_xyzw[..., [3, 0, 1, 2]].astype(np.float32) # reorder -> wxyz -def _save_preview_mp4( - image_frames: list[np.ndarray], - output_path: Path, - fps: int = 30, -) -> None: - """Save a half-resolution H.264 preview video via ffmpeg.""" - if not image_frames: - return - H, W = image_frames[0].shape[:2] - out_w, out_h = (W // 2) & ~1, (H // 2) & ~1 - if out_w <= 0 or out_h <= 0: - return - - output_path.parent.mkdir(parents=True, exist_ok=True) - ffmpeg = shutil.which("ffmpeg") - if ffmpeg is None: - print(f" [preview] ffmpeg not found, skipping {output_path.name}") - return - - cmd = [ - ffmpeg, "-y", - "-f", "rawvideo", "-vcodec", "rawvideo", "-pix_fmt", "rgb24", - "-s", f"{out_w}x{out_h}", "-r", str(fps), - "-i", "-", - "-an", "-c:v", "libx264", "-pix_fmt", "yuv420p", - "-profile:v", "baseline", "-level", "3.0", - "-movflags", "+faststart", "-preset", "veryfast", "-crf", "23", - str(output_path), - ] - proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.PIPE) - for frame in image_frames: - resized = cv2.resize(frame, (out_w, out_h), interpolation=cv2.INTER_AREA) - proc.stdin.write(resized.tobytes()) - proc.stdin.close() - if proc.wait() != 0: - stderr = proc.stderr.read().decode(errors="replace") - print(f" [preview] ffmpeg failed for 
{output_path.name}: {stderr[:200]}") - else: - print(f" [preview] saved {output_path.name}") +def _batch_pose6_to_pose7(pose6: np.ndarray) -> np.ndarray: + """(N, 6) [xyz ypr] -> (N, 7) [xyz quat_wxyz]. Invalid sentinels → zeros.""" + N = pose6.shape[0] + out = np.zeros((N, 7), dtype=np.float32) + valid = ~np.any(pose6 >= INVALID_VALUE - 1, axis=1) + if valid.any(): + out[valid, :3] = pose6[valid, :3] + out[valid, 3:] = _batch_euler_to_quat(pose6[valid, 3:6]) + return out # --------------------------------------------------------------------------- -# Hand-keypoint interpolation (Akima spline, gap-length-aware) +# Data structures & extraction (unchanged from original) # --------------------------------------------------------------------------- -MAX_INTERP_GAP_FRAMES = 15 # default: only interpolate gaps <= 0.5s @ 30fps -MAX_INTERP_VELOCITY = 2.0 # m/frame; reject implausible interpolated values - - -def _find_gaps(missing_mask: np.ndarray) -> list[tuple[int, int]]: - """Return (start, end) inclusive index pairs for contiguous True runs.""" - gaps = [] - start = None - for i, m in enumerate(missing_mask): - if m: - if start is None: - start = i - else: - if start is not None: - gaps.append((start, i - 1)) - start = None - if start is not None: - gaps.append((start, len(missing_mask) - 1)) - return gaps - - -def _akima_interpolate_keypoints( - keypoints_seq: np.ndarray, - valid_mask: np.ndarray, - gap_start: int, - gap_end: int, -) -> np.ndarray | None: - """Interpolate a (gap_len, K) block of keypoints using Akima splines. - - keypoints_seq: (N, K) array — full sequence of keypoint values - valid_mask: (N,) bool — True where data is real - gap_start, gap_end: inclusive indices of the gap - - Returns (gap_len, K) interpolated values or None if insufficient context. 
- """ - gap_len = gap_end - gap_start + 1 - K = keypoints_seq.shape[1] - - # Need at least 2 valid points on each side of the gap for Akima - ctx_lo = max(0, gap_start - 10) - ctx_hi = min(len(valid_mask), gap_end + 11) - - before = [i for i in range(ctx_lo, gap_start) if valid_mask[i]] - after = [i for i in range(gap_end + 1, ctx_hi) if valid_mask[i]] +@dataclass +class HandKeypoints: + left: np.ndarray | None = None + right: np.ndarray | None = None + + +@dataclass +class CameraPose: + position: np.ndarray + quaternion: np.ndarray + rotation_matrix: np.ndarray + + @classmethod + def from_pose_array(cls, pose: list[float]) -> CameraPose: + position = np.array(pose[:3], dtype=np.float64) + quaternion = np.array(pose[3:7], dtype=np.float64) + rotation = R.from_quat(quaternion).as_matrix() + return cls(position=position, quaternion=quaternion, rotation_matrix=rotation) + + def get_transform_matrix(self) -> np.ndarray: + t = np.eye(4, dtype=np.float64) + t[:3, :3] = self.rotation_matrix + t[:3, 3] = self.position + return t + + +@dataclass +class FrameData: + frame_index: int + timestamp_us: int + camera_pose: CameraPose + hand_keypoints: HandKeypoints + text_annotations: list[dict[str, Any]] = field(default_factory=list) + subgoal: dict[str, Any] | None = None + collector_issue: dict[str, Any] | None = None + + +class SFSDataExtractor: + """Extracts per-frame metadata from SFS + annotation files.""" + + def __init__(self, sfs_path: str, annotation_path: str, video_path: str): + self.video_path = video_path + self.sfs_data = load_scene(sfs_path) + self.annotation_data = load_annotation_file(annotation_path) + + if self.sfs_data is None or self.annotation_data is None: + raise ValueError("Failed to load SFS or annotation data") + + self.camera_sensor_id = "left_rectified" + self.intrinsics = get_intrinsics(self.sfs_data, self.camera_sensor_id) + self.posepath = get_posepath(self.sfs_data, self.camera_sensor_id) + if self.intrinsics is None or self.posepath is None: + 
raise ValueError(f"Missing camera data for {self.camera_sensor_id}") + + self.timestamps = self.posepath.get("timestamps", []) + self.pose_values = self.posepath.get("values", []) + + self._build_keypoint_lookup() + self._build_annotation_lookup() + + # -- keypoint lookup (unchanged) -- + def _build_keypoint_lookup(self) -> None: + self.keypoint_paths: dict[str, dict[int, dict[int, Any]]] = {"left": {}, "right": {}} + for annotation in self.annotation_data.get("annotations", []): + if annotation.get("type") != "points": + continue + labels = annotation.get("labels", []) + paths = annotation.get("paths", []) + for i, label in enumerate(labels): + if i >= len(paths): + continue + hand_type = "left" if label.startswith("left_") else "right" if label.startswith("right_") else None + if hand_type is None: + continue + prefix_len = 5 if hand_type == "left" else 6 + keypoint_name = label[prefix_len:] + kp_idx = next((idx for idx, v in enumerate(MANO_LABELS) if v == keypoint_name), None) + if kp_idx is None: + continue + path = paths[i] + values = path.get("values", []) + for ts_idx, ts in enumerate(path.get("timestamps", [])): + self.keypoint_paths[hand_type].setdefault(ts, {}) + if ts_idx < len(values): + self.keypoint_paths[hand_type][ts][kp_idx] = values[ts_idx] + + def _build_annotation_lookup(self) -> None: + self.text_annotations: list[dict] = [] + self.subgoal_annotations: list[dict] = [] + self.collector_issues: list[dict] = [] + self.demonstration_metadata: dict[str, Any] = {} + + for attr in self.annotation_data.get("attributes", []): + values = attr.get("values", []) + if values: + self.demonstration_metadata[attr.get("name", "")] = values[0] + + for annotation in self.annotation_data.get("annotations", []): + if annotation.get("type") != "text_annotation": + continue + label = annotation.get("label", "") + for clip in annotation.get("clips", []): + start_ts = clip.get("timestamp", 0) + end_ts = start_ts + clip.get("duration", 0) + text = clip.get("text", "") + 
attr_dict = {} + for attr in clip.get("attributes", []): + vals = attr.get("values", []) + if vals: + attr_dict[attr.get("name", "")] = vals[0] + + if label == "Sub-goal": + self.subgoal_annotations.append( + {"start_ts": start_ts, "end_ts": end_ts, "text": text} + ) + elif label == "Collector Issue": + self.collector_issues.append( + {"start_ts": start_ts, "end_ts": end_ts, + "issue_type": attr_dict.get("Collector Quality Issue", "")} + ) + self.text_annotations.append( + {"label": label, "text": text, "start_ts": start_ts, + "end_ts": end_ts, "attributes": attr_dict} + ) - anchor_indices = before + after - if len(anchor_indices) < 3: + def get_hand_keypoints_at_timestamp(self, timestamp: int) -> HandKeypoints: + result = HandKeypoints() + for hand_type in ("left", "right"): + if timestamp not in self.keypoint_paths[hand_type]: + continue + kp_dict = self.keypoint_paths[hand_type][timestamp] + if len(kp_dict) < NUM_KEYPOINTS // 2: + continue + keypoints = np.full((NUM_KEYPOINTS, 3), INVALID_VALUE, dtype=np.float32) + for kp_idx, xyz in kp_dict.items(): + keypoints[kp_idx] = xyz + if hand_type == "left": + result.left = keypoints + else: + result.right = keypoints + return result + + def get_subgoal_at_timestamp(self, timestamp: int) -> dict[str, Any] | None: + for item in self.subgoal_annotations: + if item["start_ts"] <= timestamp <= item["end_ts"]: + return item return None - x_anchor = np.array(anchor_indices, dtype=np.float64) - y_anchor = keypoints_seq[anchor_indices] # (M, K) + def get_collector_issue_at_timestamp(self, timestamp: int) -> dict[str, Any] | None: + for item in self.collector_issues: + if item["start_ts"] <= timestamp <= item["end_ts"]: + return item + return None - x_gap = np.arange(gap_start, gap_end + 1, dtype=np.float64) - result = np.empty((gap_len, K), dtype=np.float32) + def get_text_annotations_at_timestamp(self, timestamp: int) -> list[dict[str, Any]]: + return [ + ann for ann in self.text_annotations + if ann["start_ts"] <= timestamp 
<= ann["end_ts"] + ] + + def extract_all_frames_metadata(self) -> list[FrameData]: + frames = [] + for i, ts in enumerate(self.timestamps): + pose = self.pose_values[i] + frames.append( + FrameData( + frame_index=i, + timestamp_us=ts, + camera_pose=CameraPose.from_pose_array(pose), + hand_keypoints=self.get_hand_keypoints_at_timestamp(ts), + text_annotations=self.get_text_annotations_at_timestamp(ts), + subgoal=self.get_subgoal_at_timestamp(ts), + collector_issue=self.get_collector_issue_at_timestamp(ts), + ) + ) + return frames + + def load_all_images(self) -> list[np.ndarray | None]: + """Read every frame of the video sequentially (no seeking). Index i == video frame i.""" + cap = cv2.VideoCapture(self.video_path) + if not cap.isOpened(): + return [] + images: list[np.ndarray | None] = [] + while True: + ret, frame = cap.read() + if not ret: + break + images.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + cap.release() + return images - for k in range(K): - try: - interp = Akima1DInterpolator(x_anchor, y_anchor[:, k]) - result[:, k] = interp(x_gap) - except Exception: - return None - return result +# --------------------------------------------------------------------------- +# Palm pose computation (vectorised where possible) +# --------------------------------------------------------------------------- +def _compute_palm_centroid(keypoints: np.ndarray) -> np.ndarray: + palm_kps = keypoints[PALM_INDICES] + valid_mask = ~np.any(palm_kps >= INVALID_VALUE - 1, axis=1) + if not np.any(valid_mask): + return np.full(3, INVALID_VALUE, dtype=np.float32) + return np.mean(palm_kps[valid_mask], axis=0).astype(np.float32) -def _velocity_check( - interpolated: np.ndarray, - keypoints_seq: np.ndarray, - gap_start: int, - gap_end: int, - max_vel: float, -) -> bool: - """Check that interpolated keypoints don't imply impossible velocities. - Compares each frame to its predecessor (including the anchor frame - immediately before the gap). Returns True if all velocities are ok. 
- """ - prev_idx = gap_start - 1 - if prev_idx < 0: - return True - prev_kp = keypoints_seq[prev_idx] # (K,) - - # Reshape to (N, 21, 3) for per-keypoint velocity - n_kp = interpolated.shape[1] // 3 - for t in range(interpolated.shape[0]): - cur = interpolated[t].reshape(n_kp, 3) - prv = prev_kp.reshape(n_kp, 3) if t == 0 else interpolated[t - 1].reshape(n_kp, 3) - max_displacement = float(np.max(np.linalg.norm(cur - prv, axis=1))) - if max_displacement > max_vel: - return False - prev_kp = interpolated[t] - return True - - -def _nullify_tracking_errors(frames: list[FrameData]) -> int: - """Null out keypoints on frames with Hand Tracking Error annotations. - - This turns error-flagged frames into gaps so the interpolation logic - can fill short ones and the zero-fill path handles long ones. - Modifies ``frames`` in place. Returns count of nullified frames. +def _compute_palm_orientation(keypoints: np.ndarray, flip_x: bool = False) -> np.ndarray: + """Hand frame: x=right, y=down (palm normal toward ground), z=forward (toward fingers). + flip_x=True for the right hand so that x is rightward for both hands. """ - nullified = 0 - for frame in frames: - if frame.hand_tracking_error is None: - continue - hand = frame.hand_tracking_error.get("hand", "Both") - if hand in ("Left", "Both") and frame.hand_keypoints.left is not None: - frame.hand_keypoints.left = None - if hand in ("Right", "Both") and frame.hand_keypoints.right is not None: - frame.hand_keypoints.right = None - nullified += 1 - return nullified - - -def interpolate_hand_gaps( - frames: list[FrameData], - max_gap_frames: int = MAX_INTERP_GAP_FRAMES, - max_velocity: float = MAX_INTERP_VELOCITY, -) -> dict[str, Any]: - """Fill short gaps in hand keypoints via Akima spline interpolation. - - Frames with Hand Tracking Error annotations have their bad-hand - keypoints nullified first, turning them into gaps eligible for - interpolation. Modifies ``frames`` in place. Returns stats dict. 
- """ - n = len(frames) - tracking_error_nullified = _nullify_tracking_errors(frames) - stats: dict[str, Any] = { - "tracking_error_nullified": tracking_error_nullified, - "left_gaps_found": 0, "left_gaps_filled": 0, "left_frames_filled": 0, - "right_gaps_found": 0, "right_gaps_filled": 0, "right_frames_filled": 0, - "velocity_rejected": 0, - } - - for hand in ("left", "right"): - missing = np.array([ - getattr(frames[i].hand_keypoints, hand) is None - for i in range(n) - ]) - gaps = _find_gaps(missing) - stats[f"{hand}_gaps_found"] = len(gaps) - - if not gaps: - continue - - # Build dense keypoint array for the hand (N, 63) - kp_seq = np.full((n, 63), INVALID_VALUE, dtype=np.float32) - valid = np.zeros(n, dtype=bool) - for i in range(n): - kp = getattr(frames[i].hand_keypoints, hand) - if kp is not None: - kp_seq[i] = kp.flatten() - valid[i] = True - - for gap_start, gap_end in gaps: - gap_len = gap_end - gap_start + 1 - if gap_len > max_gap_frames: - continue - - filled = _akima_interpolate_keypoints(kp_seq, valid, gap_start, gap_end) - if filled is None: - continue - - if not _velocity_check(filled, kp_seq, gap_start, gap_end, max_velocity): - stats["velocity_rejected"] += 1 - continue - - # Write back into frames - for offset, idx in enumerate(range(gap_start, gap_end + 1)): - kp_21x3 = filled[offset].reshape(21, 3).astype(np.float32) - if hand == "left": - frames[idx].hand_keypoints.left = kp_21x3 - else: - frames[idx].hand_keypoints.right = kp_21x3 - kp_seq[idx] = filled[offset] - valid[idx] = True + wrist, index1, middle1, pinky1 = keypoints[0], keypoints[5], keypoints[9], keypoints[17] + if any(np.any(kp >= INVALID_VALUE - 1) for kp in (wrist, index1, middle1, pinky1)): + return np.zeros(3, dtype=np.float32) + # z: forward — from wrist toward middle finger + z_axis = middle1 - wrist + z_axis /= np.linalg.norm(z_axis) + 1e-8 + # x: right — across palm, orthogonalized against z + # left hand: index1 - pinky1 is rightward + # right hand: pinky1 - index1 is 
rightward (flip_x=True) + across = (pinky1 - index1) if flip_x else (index1 - pinky1) + across -= np.dot(across, z_axis) * z_axis + x_axis = across / (np.linalg.norm(across) + 1e-8) + # y: down (palm normal toward ground) = cross(z, x) + y_axis = np.cross(z_axis, x_axis) + y_axis /= np.linalg.norm(y_axis) + 1e-8 + rot = np.column_stack([x_axis, y_axis, z_axis]) + try: + return R.from_matrix(rot).as_euler("ZYX", degrees=False).astype(np.float32) + except Exception: + return np.zeros(3, dtype=np.float32) - stats[f"{hand}_gaps_filled"] += 1 - stats[f"{hand}_frames_filled"] += gap_len - return stats +def _compute_palm_6dof(keypoints: np.ndarray, flip_x: bool = False) -> np.ndarray: + centroid = _compute_palm_centroid(keypoints) + if np.any(centroid >= INVALID_VALUE - 1): + return np.full(6, INVALID_VALUE, dtype=np.float32) + ypr = _compute_palm_orientation(keypoints, flip_x=flip_x) + return np.concatenate([centroid, ypr]).astype(np.float32) -# --------------------------------------------------------------------------- -# Frame validity -# --------------------------------------------------------------------------- +def _compute_wrist_position(keypoints: np.ndarray) -> np.ndarray: + wrist = keypoints[0] + if np.any(wrist >= INVALID_VALUE - 1): + return np.full(3, INVALID_VALUE, dtype=np.float32) + return wrist.astype(np.float32) -def _build_validity_mask( - frames: list[FrameData], - video_frame_count: int, - hand_used: str = "Both", -) -> np.ndarray: - """Boolean mask: True = frame is usable for training. - - Called AFTER interpolation. Drop logic: - - Inactive Time → always drop - - Beyond video length → always drop - - Hand Tracking Error with the *active* hand still missing after - interpolation → drop. For single-hand tasks (hand_used="Left" - or "Right"), only errors on the active hand matter. 
- - Missing keypoints on the inactive hand → zero-fill, keep +def _compute_wrist_orientation(keypoints: np.ndarray, flip_x: bool = False) -> np.ndarray: + """Hand frame: x=right, y=down (palm normal toward ground), z=forward (toward fingers). + flip_x=True for the right hand so that x is rightward for both hands. """ - active_hands: set[str] = set() - if hand_used in ("Left", "Both"): - active_hands.add("Left") - if hand_used in ("Right", "Both"): - active_hands.add("Right") - - n = len(frames) - mask = np.ones(n, dtype=bool) - drop_reasons: dict[str, int] = { - "tracking_error_unfilled": 0, - "inactive_time": 0, - "beyond_video": 0, - } - missing_hands_info = {"left_missing": 0, "right_missing": 0, "both_missing": 0} - for i, frame in enumerate(frames): - if ( - frame.collector_issue is not None - and frame.collector_issue.get("issue_type") == "Inactive Time" - ): - mask[i] = False - drop_reasons["inactive_time"] += 1 - continue - if i >= video_frame_count: - mask[i] = False - drop_reasons["beyond_video"] += 1 - continue + wrist, index1, middle1, pinky1 = keypoints[0], keypoints[5], keypoints[9], keypoints[17] + if any(np.any(kp >= INVALID_VALUE - 1) for kp in (wrist, index1, middle1, pinky1)): + return np.zeros(3, dtype=np.float32) + # z: forward — from wrist toward middle finger + z_axis = middle1 - wrist + z_axis /= np.linalg.norm(z_axis) + 1e-8 + # x: right — across palm, orthogonalized against z + # left hand: index1 - pinky1 is rightward + # right hand: pinky1 - index1 is rightward (flip_x=True) + across = (pinky1 - index1) if flip_x else (index1 - pinky1) + across -= np.dot(across, z_axis) * z_axis + x_axis = across / (np.linalg.norm(across) + 1e-8) + # y: down (palm normal toward ground) = cross(z, x) + y_axis = np.cross(z_axis, x_axis) + y_axis /= np.linalg.norm(y_axis) + 1e-8 + rot = np.column_stack([x_axis, y_axis, z_axis]) + try: + return R.from_matrix(rot).as_euler("ZYX", degrees=False).astype(np.float32) + except Exception: + return np.zeros(3, 
dtype=np.float32) - # Tracking-error frames whose *active*-hand keypoints weren't - # recovered by interpolation → drop. - if frame.hand_tracking_error is not None: - err_hand = frame.hand_tracking_error.get("hand", "Both") - err_hands_set: set[str] = set() - if err_hand in ("Left", "Both"): - err_hands_set.add("Left") - if err_hand in ("Right", "Both"): - err_hands_set.add("Right") - - still_bad = False - for h in err_hands_set & active_hands: - kp = frame.hand_keypoints.left if h == "Left" else frame.hand_keypoints.right - if kp is None: - still_bad = True - break - if still_bad: - mask[i] = False - drop_reasons["tracking_error_unfilled"] += 1 - continue - l_miss = frame.hand_keypoints.left is None - r_miss = frame.hand_keypoints.right is None - if l_miss and r_miss: - missing_hands_info["both_missing"] += 1 - elif l_miss: - missing_hands_info["left_missing"] += 1 - elif r_miss: - missing_hands_info["right_missing"] += 1 - - valid_count = int(mask.sum()) - total = len(frames) - dropped = total - valid_count - print(f" Validity: {valid_count}/{total} frames kept ({dropped} dropped)") - for reason, count in drop_reasons.items(): - if count > 0: - print(f" {reason}: {count}") - zero_filled = sum(missing_hands_info.values()) - if zero_filled > 0: - print(f" Zero-filled (kept): left={missing_hands_info['left_missing']}, " - f"right={missing_hands_info['right_missing']}, " - f"both={missing_hands_info['both_missing']}") - return mask - - -def _contiguous_runs(mask: np.ndarray, min_length: int) -> list[list[int]]: - """Extract contiguous runs of True indices from a boolean mask.""" - runs: list[list[int]] = [] - current: list[int] = [] - for i, val in enumerate(mask): - if val: - current.append(i) - else: - if len(current) >= min_length: - runs.append(current) - current = [] - if len(current) >= min_length: - runs.append(current) - return runs +def _compute_wrist_6dof(keypoints: np.ndarray, flip_x: bool = False) -> np.ndarray: + wrist_xyz = 
_compute_wrist_position(keypoints) + if np.any(wrist_xyz >= INVALID_VALUE - 1): + return np.full(6, INVALID_VALUE, dtype=np.float32) + wrist_ypr = _compute_wrist_orientation(keypoints, flip_x=flip_x) + return np.concatenate([wrist_xyz, wrist_ypr]).astype(np.float32) # --------------------------------------------------------------------------- # Language annotations # --------------------------------------------------------------------------- - def _build_language_annotations(sub_frames: list[FrameData]) -> list[tuple[str, int, int]]: rows: list[tuple[str, int, int]] = [] current_text: str | None = None @@ -509,24 +401,15 @@ def _task_description(frames: list[FrameData], demo_meta: dict[str, Any]) -> str # Core conversion # --------------------------------------------------------------------------- - def convert_task_to_zarr( task_id: str, output_dir: str, download_dir: str, robot_type: str = "scale_bimanual", fps: int = 30, - img_workers: int | None = None, - max_interp_gap: int = MAX_INTERP_GAP_FRAMES, - max_interp_velocity: float = MAX_INTERP_VELOCITY, -) -> dict[str, Any]: - """Convert one Scale task to one or more Zarr episodes. - - Returns a dict with keys: episodes, folder, task_desc, total_frames, output_dir. - """ +) -> int: + """Convert one Scale task to one or more Zarr episodes. 
Returns count.""" t_start = time.perf_counter() - if img_workers is None: - img_workers = min(os.cpu_count() or 4, 8) print(f"[{task_id}] Fetching task metadata...") task_download_path = os.path.join(download_dir, task_id) @@ -537,6 +420,7 @@ def convert_task_to_zarr( raise ValueError(f"Task {task_id} not found or Scale API failed") print(f"[{task_id}] Downloading files...") + t_dl = time.perf_counter() local_paths = download_from_simple_response_dict(task_download_path, response) sfs_path = local_paths.get("sfs") annotations_path = local_paths.get("annotations") @@ -549,6 +433,7 @@ def _nonempty(p: str | None) -> bool: if not (_nonempty(sfs_path) and _nonempty(annotations_path)): raise ValueError(f"Downloaded SFS/annotation files are empty for task {task_id}") + print(f"[{task_id}] Downloaded in {time.perf_counter() - t_dl:.1f}s") print(f"[{task_id}] Loading SFS metadata...") try: @@ -565,77 +450,21 @@ def _nonempty(p: str | None) -> bool: extractor = SFSDataExtractor(sfs_path, annotations_path, video_path) frames = extractor.extract_all_frames_metadata() n_frames = len(frames) - if n_frames < MIN_EPISODE_FRAMES: + if n_frames <= ACTION_WINDOW: raise ValueError(f"Task {task_id} has too few frames ({n_frames})") - task_desc = _task_description(frames, extractor.demonstration_metadata) - - # Detect handedness from annotation "Hand Used" attribute - hand_used = str(extractor.demonstration_metadata.get("Hand Used", "Both")).strip() - if hand_used not in ("Left", "Right", "Both"): - hand_used = "Both" - resolved_embodiment = HAND_USED_TO_EMBODIMENT.get(hand_used, robot_type) - if hand_used != "Both": - print(f"[{task_id}] Single-hand task: Hand Used={hand_used} -> {resolved_embodiment}") - - # Extract and scale intrinsics to output image resolution - raw_intrinsics = get_intrinsics( - load_scene(sfs_path), "left_rectified" - ) - camera_intrinsics: dict[str, Any] | None = None - if raw_intrinsics: - orig_w = raw_intrinsics.get("width", 1920) - orig_h = 
raw_intrinsics.get("height", 1200) - sx = IMAGE_SIZE[0] / orig_w - sy = IMAGE_SIZE[1] / orig_h - camera_intrinsics = { - "fx": raw_intrinsics["fx"] * sx, - "fy": raw_intrinsics["fy"] * sy, - "cx": raw_intrinsics["cx"] * sx, - "cy": raw_intrinsics["cy"] * sy, - "width": IMAGE_SIZE[0], - "height": IMAGE_SIZE[1], - } + print(f"[{task_id}] Loading all video frames sequentially...") + t_vid = time.perf_counter() + all_images = extractor.load_all_images() + print(f"[{task_id}] Loaded {len(all_images)} video frames in {time.perf_counter() - t_vid:.1f}s (SFS frames={n_frames})") + if len(all_images) != n_frames: + print(f"[{task_id}] WARNING: video frame count ({len(all_images)}) != SFS frame count ({n_frames}) — index drift possible") - # ------------------------------------------------------------------ - # Probe video frame count - # ------------------------------------------------------------------ - video_frame_count = _get_video_frame_count(video_path) - print(f"[{task_id}] Video: {video_frame_count} frames SFS: {n_frames} frames") - if video_frame_count != n_frames: - print(f"[{task_id}] WARNING: video/SFS frame count mismatch") - - # ------------------------------------------------------------------ - # Interpolate short hand-tracking gaps (before filtering) - # ------------------------------------------------------------------ - if max_interp_gap > 0: - interp_stats = interpolate_hand_gaps( - frames, - max_gap_frames=max_interp_gap, - max_velocity=max_interp_velocity, - ) - nullified = interp_stats["tracking_error_nullified"] - filled_l = interp_stats["left_frames_filled"] - filled_r = interp_stats["right_frames_filled"] - rej = interp_stats["velocity_rejected"] - print(f"[{task_id}] Interpolation: nullified {nullified} tracking-error frames, " - f"filled {filled_l} left + {filled_r} right frames " - f"(gaps: {interp_stats['left_gaps_filled']}L/{interp_stats['right_gaps_filled']}R filled, " - f"{rej} velocity-rejected)") - - # 
------------------------------------------------------------------ - # Build per-frame validity mask and find contiguous runs - # ------------------------------------------------------------------ - validity = _build_validity_mask(frames, video_frame_count, hand_used=hand_used) - runs = _contiguous_runs(validity, min_length=MIN_EPISODE_FRAMES) - if not runs: - raise ValueError(f"Task {task_id} has no valid contiguous runs after filtering") - - print(f"[{task_id}] {len(runs)} contiguous run(s), " - f"total valid frames: {sum(len(r) for r in runs)}") + task_desc = _task_description(frames, extractor.demonstration_metadata) + valid_frame_count = n_frames - ACTION_WINDOW # ------------------------------------------------------------------ - # Precompute all per-frame data into dense arrays (no video needed) + # Precompute all per-frame data into dense arrays (once) # ------------------------------------------------------------------ left_world_6 = np.full((n_frames, 6), INVALID_VALUE, dtype=np.float32) right_world_6 = np.full((n_frames, 6), INVALID_VALUE, dtype=np.float32) @@ -647,64 +476,95 @@ def _nonempty(p: str | None) -> bool: for i, frame in enumerate(frames): if frame.hand_keypoints.left is not None: - left_world_6[i] = compute_palm_6dof(frame.hand_keypoints.left) - left_wrist_6[i] = compute_wrist_6dof(frame.hand_keypoints.left) + left_world_6[i] = _compute_palm_6dof(frame.hand_keypoints.left) + left_wrist_6[i] = _compute_wrist_6dof(frame.hand_keypoints.left) left_kps[i] = frame.hand_keypoints.left.flatten().astype(np.float32) if frame.hand_keypoints.right is not None: - right_world_6[i] = compute_palm_6dof(frame.hand_keypoints.right, flip_x=True) - right_wrist_6[i] = compute_wrist_6dof(frame.hand_keypoints.right, flip_x=True) + right_world_6[i] = _compute_palm_6dof(frame.hand_keypoints.right, flip_x=True) + right_wrist_6[i] = _compute_wrist_6dof(frame.hand_keypoints.right, flip_x=True) right_kps[i] = frame.hand_keypoints.right.flatten().astype(np.float32) 
head_pose_6[i, :3] = frame.camera_pose.position.astype(np.float32) head_pose_6[i, 3:] = R.from_matrix(frame.camera_pose.rotation_matrix).as_euler( "ZYX", degrees=False ).astype(np.float32) - left_world = batch_pose6_to_pose7(left_world_6) - right_world = batch_pose6_to_pose7(right_world_6) - left_wrist = batch_pose6_to_pose7(left_wrist_6) - right_wrist = batch_pose6_to_pose7(right_wrist_6) - head_pose_world = batch_pose6_to_pose7(head_pose_6) + # Batch-convert all (N, 6) [xyz + euler ZYX] -> (N, 7) [xyz + quat xyzw] + left_world = _batch_pose6_to_pose7(left_world_6) + right_world = _batch_pose6_to_pose7(right_world_6) + left_wrist = _batch_pose6_to_pose7(left_wrist_6) + right_wrist = _batch_pose6_to_pose7(right_wrist_6) + head_pose_world = _batch_pose6_to_pose7(head_pose_6) # ------------------------------------------------------------------ - # Split contiguous runs into sub-episodes and write + # Filter valid frame indices (same criteria as old script) # ------------------------------------------------------------------ - sub_episode_plans: list[list[int]] = [] - for run in runs: - for ep_start in range(0, len(run), SUB_EPISODE_LENGTH): - sub = run[ep_start : ep_start + SUB_EPISODE_LENGTH] - if len(sub) >= MIN_EPISODE_FRAMES: - sub_episode_plans.append(sub) + valid_indices: list[int] = [] + for t in range(valid_frame_count): + if ( + frames[t].collector_issue is not None + and frames[t].collector_issue.get("issue_type") == "Inactive Time" + ): + continue + window = slice(t, t + ACTION_WINDOW) + n_invalid = ( + np.sum(np.any(left_world[window] >= INVALID_VALUE - 1, axis=1)) + + np.sum(np.any(right_world[window] >= INVALID_VALUE - 1, axis=1)) + ) + if n_invalid > ACTION_WINDOW: # >50% of 2*ACTION_WINDOW + continue + valid_indices.append(t) + + if not valid_indices: + raise ValueError(f"Task {task_id} has no valid frames after filtering") + + print(f"[{task_id}] {len(valid_indices)} valid frames out of {valid_frame_count}") + # 
------------------------------------------------------------------ + # Write sub-episodes + # ------------------------------------------------------------------ folder = datetime.now(timezone.utc).strftime("%Y-%m-%d-%H-%M-%S-%f") task_output_dir = Path(output_dir) / folder task_output_dir.mkdir(parents=True, exist_ok=True) written = 0 + for ep_start in range(0, len(valid_indices), SUB_EPISODE_LENGTH): + sub = valid_indices[ep_start : ep_start + SUB_EPISODE_LENGTH] + if len(sub) < 10: + continue - for sub in sub_episode_plans: - decoded = _decode_selected_frames(video_path, sub) - kept = [t for t in sub if t in decoded] - if len(kept) < MIN_EPISODE_FRAMES: - del decoded + # First pass: figure out which frames have images + kept: list[int] = [] + none_count = 0 + for t in sub: + img = all_images[t] if t < len(all_images) else None + if img is not None: + kept.append(t) + else: + none_count += 1 + print(f"[ep{written}] sub={len(sub)} kept={len(kept)} dropped(no image)={none_count} frames=[{sub[0]}..{sub[-1]}]") + if len(kept) < 10: continue T = len(kept) - kept_arr = np.array(kept) - # All kept frames should have valid hand tracking; - # replace any remaining per-keypoint INVALID_VALUE sentinels with 0.0 - left_curr_7 = np.where( - left_world[kept_arr] >= INVALID_VALUE - 1, 0.0, left_world[kept_arr] - ).astype(np.float32) - right_curr_7 = np.where( - right_world[kept_arr] >= INVALID_VALUE - 1, 0.0, right_world[kept_arr] - ).astype(np.float32) + # ---- Per-frame current state (vectorised) ---- + kept_arr = np.array(kept) + left_curr_7 = left_world[kept_arr] # (T, 7) + right_curr_7 = right_world[kept_arr] + left_curr_7 = np.where(left_curr_7 >= INVALID_VALUE - 1, 0.0, left_curr_7).astype( + np.float32 + ) + right_curr_7 = np.where(right_curr_7 >= INVALID_VALUE - 1, 0.0, right_curr_7).astype( + np.float32 + ) left_wrist_curr_7 = np.where( left_wrist[kept_arr] >= INVALID_VALUE - 1, 0.0, left_wrist[kept_arr] ).astype(np.float32) right_wrist_curr_7 = np.where( 
right_wrist[kept_arr] >= INVALID_VALUE - 1, 0.0, right_wrist[kept_arr] ).astype(np.float32) + + # Head pose & keypoints actions_head = head_pose_world[kept_arr] left_keypoints = np.where( left_kps[kept_arr] >= INVALID_VALUE - 1, 0.0, left_kps[kept_arr] @@ -713,21 +573,15 @@ def _nonempty(p: str | None) -> bool: right_kps[kept_arr] >= INVALID_VALUE - 1, 0.0, right_kps[kept_arr] ).astype(np.float32) - ordered_frames = [decoded[t] for t in kept] - del decoded - - # Preview MP4 (before JPEG encoding to avoid re-decoding) - preview_path = task_output_dir / f"{task_id}_episode_{written:06d}.mp4" - _save_preview_mp4(ordered_frames, preview_path, fps=fps) - - n_workers = min(img_workers, T) - with ThreadPoolExecutor(max_workers=n_workers) as pool: - encode_results = list(pool.map(_resize_and_encode, ordered_frames)) - del ordered_frames - image_shape = list(encode_results[0][0]) - pre_encoded = np.array([r[1] for r in encode_results], dtype=object) - del encode_results + # ---- Build image array ---- + images = np.stack( + [cv2.resize(all_images[t], IMAGE_SIZE, interpolation=cv2.INTER_LINEAR) + for t in kept], + axis=0, + ).astype(np.uint8) + print(f"[ep{written}] images.shape={images.shape} kept_arr.shape={kept_arr.shape} match={images.shape[0] == len(kept_arr)}") + # ---- Numeric data ---- numeric_data = { "left.obs_ee_pose": left_curr_7, "right.obs_ee_pose": right_curr_7, @@ -737,49 +591,40 @@ def _nonempty(p: str | None) -> bool: "right.obs_wrist_pose": right_wrist_curr_7, "obs_head_pose": actions_head, } + image_data = { + "images.front_1": images, + } used_frames = [frames[t] for t in kept] lang_ann = _build_language_annotations(used_frames) episode_path = task_output_dir / f"{task_id}_episode_{written:06d}.zarr" - meta_override = {"hand_used": hand_used} - if camera_intrinsics: - meta_override["camera_intrinsics"] = camera_intrinsics ZarrWriter.create_and_write( episode_path=episode_path, numeric_data=numeric_data, - pre_encoded_image_data={ - "images.front_1": 
(pre_encoded, image_shape), - }, - embodiment=resolved_embodiment, + image_data=image_data, + embodiment=robot_type, fps=fps, task=task_desc, annotations=lang_ann if lang_ann else None, - enable_sharding=True, - metadata_override=meta_override or None, + enable_sharding=False, ) written += 1 print(f"[{task_id}] Wrote episode {written} ({T} frames) -> {episode_path.name}") + # Clean download cache if os.path.exists(task_download_path): - shutil.rmtree(task_download_path, ignore_errors=True) + shutil.rmtree(task_download_path) elapsed = time.perf_counter() - t_start print(f"[{task_id}] Done: {written} episode(s) in {elapsed:.1f}s -> {task_output_dir}") - return { - "episodes": written, - "folder": folder, - "task_desc": task_desc, - "total_frames": n_frames, - "output_dir": str(task_output_dir), - } + return written # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- - def main() -> int: parser = argparse.ArgumentParser( description="Convert Scale SFS tasks to EgoVerse Zarr episodes" @@ -789,73 +634,28 @@ def main() -> int: parser.add_argument("--download-dir", default="scale_data", help="Temp download cache") parser.add_argument("--robot-type", default="scale_bimanual", help="Embodiment tag") parser.add_argument("--fps", type=int, default=30) - parser.add_argument( - "--workers", type=int, default=1, - help="Parallel task workers (default: 1 = sequential)", - ) - parser.add_argument( - "--max-interp-gap", type=int, default=MAX_INTERP_GAP_FRAMES, - help=f"Max gap length (frames) to interpolate (default: {MAX_INTERP_GAP_FRAMES})", - ) - parser.add_argument( - "--max-interp-velocity", type=float, default=MAX_INTERP_VELOCITY, - help=f"Max per-frame velocity (m) for interpolation sanity check (default: {MAX_INTERP_VELOCITY})", - ) args = parser.parse_args() Path(args.output_dir).mkdir(parents=True, exist_ok=True) Path(args.download_dir).mkdir(parents=True, 
exist_ok=True) - img_workers = max(1, (os.cpu_count() or 4) // max(args.workers, 1)) - total_episodes = 0 failed: list[str] = [] - - if args.workers > 1: - print(f"Running with {args.workers} parallel workers " - f"({img_workers} image threads per worker)") - with ProcessPoolExecutor(max_workers=args.workers) as pool: - futures = { - pool.submit( - convert_task_to_zarr, - task_id=tid, - output_dir=args.output_dir, - download_dir=args.download_dir, - robot_type=args.robot_type, - fps=args.fps, - img_workers=img_workers, - max_interp_gap=args.max_interp_gap, - max_interp_velocity=args.max_interp_velocity, - ): tid - for tid in args.task_ids - } - for future in as_completed(futures): - tid = futures[future] - try: - total_episodes += future.result()["episodes"] - except Exception as exc: - print(f"[{tid}] ERROR: {exc}") - traceback.print_exc() - failed.append(tid) - else: - for idx, task_id in enumerate(args.task_ids, start=1): - print(f"\n[{idx}/{len(args.task_ids)}] {task_id}") - try: - result = convert_task_to_zarr( - task_id=task_id, - output_dir=args.output_dir, - download_dir=args.download_dir, - robot_type=args.robot_type, - fps=args.fps, - img_workers=img_workers, - max_interp_gap=args.max_interp_gap, - max_interp_velocity=args.max_interp_velocity, - ) - total_episodes += result["episodes"] - except Exception as exc: - print(f"[{task_id}] ERROR: {exc}") - traceback.print_exc() - failed.append(task_id) + for idx, task_id in enumerate(args.task_ids, start=1): + print(f"\n[{idx}/{len(args.task_ids)}] {task_id}") + try: + n = convert_task_to_zarr( + task_id=task_id, + output_dir=args.output_dir, + download_dir=args.download_dir, + robot_type=args.robot_type, + fps=args.fps, + ) + total_episodes += n + except Exception as exc: + print(f"[{task_id}] ERROR: {exc}") + traceback.print_exc() + failed.append(task_id) print(f"\n{'=' * 60}") print(f"Conversion complete: {len(args.task_ids)} tasks, " diff --git a/external/scale/scripts/sfs_zarr_pipeline.py 
b/external/scale/scripts/sfs_zarr_pipeline.py deleted file mode 100644 index e51ef995..00000000 --- a/external/scale/scripts/sfs_zarr_pipeline.py +++ /dev/null @@ -1,528 +0,0 @@ -#!/usr/bin/env python3 -""" -End-to-end Scale SFS -> EgoVerse Zarr pipeline. - -Downloads tasks from the Scale API, converts to Zarr, uploads to S3, -registers in SQL, and cleans up local files. - -Usage: - # Full pipeline from CSV - python sfs_zarr_pipeline.py --csv delivery.csv --workers 6 \\ - --upload-s3 --register-sql --delete-local - - # Single task - python sfs_zarr_pipeline.py --task-ids TASK1 --upload-s3 --register-sql - - # Convert only (no S3/SQL) - python sfs_zarr_pipeline.py --task-ids TASK1 TASK2 - -Environment: - SCALE_API_KEY Required for downloading tasks from Scale - R2_ACCESS_KEY_ID Cloudflare R2 access key (from ~/.egoverse_env) - R2_SECRET_ACCESS_KEY Cloudflare R2 secret key - R2_ENDPOINT_URL Cloudflare R2 endpoint URL -""" - -from __future__ import annotations - -import argparse -import io -import os -import shutil -import subprocess -import sys -import time -import traceback -from concurrent.futures import ThreadPoolExecutor, as_completed -from datetime import datetime, timezone -from pathlib import Path -from typing import Any - -from sfs_to_egoverse_zarr import convert_task_to_zarr - - -# --------------------------------------------------------------------------- -# SQL helpers (lazy imports to avoid torch dependency at module level) -# --------------------------------------------------------------------------- - -_sql_engine = None - - -def _get_sql_engine(): - global _sql_engine - if _sql_engine is None: - from egomimic.utils.aws.aws_sql import create_default_engine - - old_stdout = sys.stdout - sys.stdout = io.StringIO() - try: - _sql_engine = create_default_engine() - finally: - sys.stdout = old_stdout - return _sql_engine - - -def is_task_already_processed(task_id: str) -> str | None: - """Returns zarr_processed_path if already done, else None.""" - from 
egomimic.utils.aws.aws_sql import episode_hash_to_table_row - - engine = _get_sql_engine() - row = episode_hash_to_table_row(engine, task_id) - if row and row.zarr_processed_path: - return row.zarr_processed_path - return None - - -def register_in_sql( - task_id: str, - task_desc: str, - total_frames: int, - s3_path: str, -) -> bool: - """Register a converted zarr dataset in the SQL episode table.""" - from egomimic.utils.aws.aws_sql import TableRow, add_episode, update_episode - - engine = _get_sql_engine() - row = TableRow( - episode_hash=task_id, - operator="scale", - lab="scale", - task=task_desc, - embodiment="scale", - robot_name="scale_bimanual", - num_frames=total_frames, - task_description=task_desc, - scene="unknown", - objects="", - zarr_processed_path=s3_path, - is_deleted=False, - is_eval=False, - eval_score=-1, - eval_success=True, - ) - - try: - add_episode(engine, row) - return True - except Exception: - try: - update_episode(engine, row) - return True - except Exception as exc: - print(f"[{task_id}] SQL registration failed: {exc}") - return False - - -# --------------------------------------------------------------------------- -# S3 upload via aws cli (uses optimised C transfer engine) -# --------------------------------------------------------------------------- - - -def upload_to_s3( - local_dir: str, - bucket: str, - s3_prefix: str, - aws_access_key_id: str = "", - aws_secret_access_key: str = "", - endpoint_url: str = "", - delete_after: bool = False, -) -> str: - """Upload a local directory tree to S3-compatible storage via `aws s3 sync`. - - Returns the full s3:// URI of the uploaded prefix. - Supports Cloudflare R2 via endpoint_url. 
- """ - local_path = Path(local_dir) - if not local_path.exists(): - raise FileNotFoundError(f"Directory not found: {local_dir}") - - s3_uri = f"s3://{bucket}/{s3_prefix}/" - env = {**os.environ} - if aws_access_key_id and aws_secret_access_key: - env["AWS_ACCESS_KEY_ID"] = aws_access_key_id - env["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key - if endpoint_url: - env["AWS_DEFAULT_REGION"] = "auto" - - cmd = ["aws", "s3", "sync", str(local_path), s3_uri, "--only-show-errors"] - if endpoint_url: - cmd.extend(["--endpoint-url", endpoint_url]) - - result = subprocess.run(cmd, capture_output=True, text=True, env=env) - if result.returncode != 0: - raise RuntimeError(f"aws s3 sync failed: {result.stderr.strip()}") - - if delete_after: - shutil.rmtree(local_dir, ignore_errors=True) - - return s3_uri - - -# --------------------------------------------------------------------------- -# Single-task pipeline -# --------------------------------------------------------------------------- - - -def run_task( - task_id: str, - output_dir: str, - download_dir: str, - robot_type: str, - fps: int, - img_workers: int, - *, - is_flagship: bool = False, - upload_s3: bool = False, - s3_bucket: str = "rldb", - aws_key: str = "", - aws_secret: str = "", - endpoint_url: str = "", - do_register_sql: bool = False, - delete_local: bool = False, - max_interp_gap: int = 15, - max_interp_velocity: float = 2.0, -) -> dict[str, Any]: - """Run the full pipeline for one task: convert -> upload -> register -> cleanup.""" - t0 = time.perf_counter() - - result = convert_task_to_zarr( - task_id=task_id, - output_dir=output_dir, - download_dir=download_dir, - robot_type=robot_type, - fps=fps, - img_workers=img_workers, - max_interp_gap=max_interp_gap, - max_interp_velocity=max_interp_velocity, - ) - - episodes = result["episodes"] - folder = result["folder"] - task_desc = result["task_desc"] - total_frames = result["total_frames"] - local_dir = result["output_dir"] - - s3_full_path = "" - category = 
"flagship" if is_flagship else "freeform" - - if upload_s3 and folder: - s3_prefix = f"processed_v3/scale/{category}/{folder}" - print(f"[{task_id}] Uploading to s3://{s3_bucket}/{s3_prefix}/ ...") - s3_full_path = upload_to_s3( - local_dir, - s3_bucket, - s3_prefix, - aws_access_key_id=aws_key, - aws_secret_access_key=aws_secret, - endpoint_url=endpoint_url, - delete_after=delete_local, - ) - print(f"[{task_id}] Uploaded -> {s3_full_path}") - elif delete_local and folder and os.path.exists(local_dir): - shutil.rmtree(local_dir, ignore_errors=True) - - sql_ok = False - if do_register_sql and folder and s3_full_path: - sql_ok = register_in_sql(task_id, task_desc, total_frames, s3_full_path) - if sql_ok: - print(f"[{task_id}] Registered in SQL ({category})") - - elapsed = time.perf_counter() - t0 - return { - "task_id": task_id, - "episodes": episodes, - "folder": folder, - "total_frames": total_frames, - "s3_path": s3_full_path, - "sql_registered": sql_ok, - "elapsed": elapsed, - "category": category, - } - - -# --------------------------------------------------------------------------- -# CSV helpers -# --------------------------------------------------------------------------- - - -def load_tasks_from_csv(csv_path: str) -> list[tuple[str, bool]]: - """Load (task_id, is_flagship) pairs from CSV.""" - import pandas as pd - - df = pd.read_csv(csv_path) - flagship_col = df["Is Flagship"].astype(str).str.strip().str.upper() == "TRUE" - return [ - (str(row["TASK_ID"]).strip(), bool(flagship_col.iloc[i])) - for i, (_, row) in enumerate(df.iterrows()) - ] - - -def log_task_result( - progress_file: str, - task_id: str, - folder: str, - s3_path: str, - episodes: int, - status: str, -) -> None: - with open(progress_file, "a") as f: - ts = datetime.now(timezone.utc).isoformat() - f.write(f"{task_id},{folder},{s3_path},{episodes},{status},{ts}\n") - - -# --------------------------------------------------------------------------- -# CLI -# 
--------------------------------------------------------------------------- - - -def main() -> int: - parser = argparse.ArgumentParser( - description="End-to-end Scale SFS -> EgoVerse Zarr pipeline", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=__doc__, - ) - - input_group = parser.add_mutually_exclusive_group(required=True) - input_group.add_argument("--task-ids", nargs="+", help="Scale task IDs") - input_group.add_argument( - "--csv", help="CSV file with TASK_ID and Is Flagship columns" - ) - - parser.add_argument( - "--output-dir", default="egoverse_zarr_dataset", help="Output root" - ) - parser.add_argument( - "--download-dir", default="scale_data", help="Temp download cache" - ) - parser.add_argument( - "--robot-type", default="scale_bimanual", help="Embodiment tag" - ) - parser.add_argument("--fps", type=int, default=30) - parser.add_argument( - "--workers", - type=int, - default=6, - help="Parallel task workers (default: 6)", - ) - parser.add_argument( - "--progress-file", - default="zarr_pipeline_progress.csv", - help="Log file for processed tasks", - ) - parser.add_argument( - "--limit", - type=int, - default=0, - help="Max tasks to process (0 = no limit)", - ) - - s3_group = parser.add_argument_group("S3/R2 upload") - s3_group.add_argument("--upload-s3", action="store_true", help="Upload to S3/R2") - s3_group.add_argument("--s3-bucket", default="rldb", help="S3 bucket") - s3_group.add_argument("--endpoint-url", help="S3-compatible endpoint URL (for R2)") - s3_group.add_argument( - "--delete-local", - action="store_true", - help="Delete local files after upload", - ) - - sql_group = parser.add_argument_group("SQL registration") - sql_group.add_argument( - "--register-sql", - action="store_true", - help="Register episodes in SQL database", - ) - - interp_group = parser.add_argument_group("Interpolation") - interp_group.add_argument( - "--max-interp-gap", type=int, default=15, - help="Max gap length (frames) to interpolate (default: 15 = 
0.5s@30fps)", - ) - interp_group.add_argument( - "--max-interp-velocity", type=float, default=2.0, - help="Max per-frame displacement (m) for interpolation sanity check (default: 2.0)", - ) - - args = parser.parse_args() - - # ------------------------------------------------------------------ - # Resolve task list with flagship/freeform metadata - # ------------------------------------------------------------------ - task_entries: list[tuple[str, bool]] = [] # (task_id, is_flagship) - if args.csv: - task_entries = load_tasks_from_csv(args.csv) - else: - task_entries = [(tid, False) for tid in args.task_ids] - - if args.limit > 0: - task_entries = task_entries[: args.limit] - - if not task_entries: - print("No tasks to process.") - return 0 - - aws_key = os.environ.get("R2_ACCESS_KEY_ID", os.environ.get("AWS_ACCESS_KEY_ID", "")) - aws_secret = os.environ.get("R2_SECRET_ACCESS_KEY", os.environ.get("AWS_SECRET_ACCESS_KEY", "")) - endpoint_url = args.endpoint_url or os.environ.get("R2_ENDPOINT_URL", "") - - Path(args.output_dir).mkdir(parents=True, exist_ok=True) - Path(args.download_dir).mkdir(parents=True, exist_ok=True) - - img_workers = max(1, (os.cpu_count() or 4) // max(args.workers, 1)) - - # ------------------------------------------------------------------ - # SQL idempotency check: filter out already-processed tasks - # ------------------------------------------------------------------ - if args.register_sql: - print(f"Checking SQL for {len(task_entries)} tasks...") - _get_sql_engine() - original_count = len(task_entries) - filtered: list[tuple[str, bool]] = [] - skipped = 0 - for idx, (task_id, is_flagship) in enumerate(task_entries): - if (idx + 1) % 100 == 0 or idx == 0: - print(f" SQL check {idx + 1}/{original_count}...") - existing = is_task_already_processed(task_id) - if existing: - skipped += 1 - print(f"[SKIP] {task_id} already processed -> {existing}") - else: - filtered.append((task_id, is_flagship)) - task_entries = filtered - print(f"SQL check 
done: {skipped} skipped, {len(task_entries)} to process") - if skipped: - print(f"Skipped {skipped}/{original_count} already-processed tasks") - - if not task_entries: - print("All tasks already processed.") - return 0 - - n_flagship = sum(1 for _, f in task_entries if f) - n_freeform = len(task_entries) - n_flagship - - # Header - print() - print("=" * 60) - print(" Scale SFS -> EgoVerse Zarr Pipeline") - print("=" * 60) - print(f" Tasks: {len(task_entries)} ({n_flagship} flagship, {n_freeform} freeform)") - print(f" Workers: {args.workers}") - print(f" Img threads/worker: {img_workers}") - print(f" Output: {args.output_dir}//") - if args.upload_s3: - dest = "R2" if endpoint_url else "S3" - print(f" {dest}: s3://{args.s3_bucket}/processed_v3/scale/{{flagship,freeform}}//") - if args.register_sql: - print(" SQL: enabled (zarr_processed_path)") - print(f" Progress: {args.progress_file}") - print("=" * 60) - print() - - total_episodes = 0 - failed: list[str] = [] - results: list[dict[str, Any]] = [] - start_time = time.perf_counter() - - def _process_one(task_id: str, is_flagship: bool, idx: int) -> dict[str, Any] | None: - print(f"[{idx}/{len(task_entries)}] {task_id} ({'flagship' if is_flagship else 'freeform'})") - try: - res = run_task( - task_id=task_id, - output_dir=args.output_dir, - download_dir=args.download_dir, - robot_type=args.robot_type, - fps=args.fps, - img_workers=img_workers, - is_flagship=is_flagship, - upload_s3=args.upload_s3, - s3_bucket=args.s3_bucket, - aws_key=aws_key, - aws_secret=aws_secret, - endpoint_url=endpoint_url, - do_register_sql=args.register_sql, - delete_local=args.delete_local, - max_interp_gap=args.max_interp_gap, - max_interp_velocity=args.max_interp_velocity, - ) - log_task_result( - args.progress_file, - task_id, - res["folder"], - res["s3_path"], - res["episodes"], - "ok", - ) - print( - f" {task_id}: {res['episodes']} eps, " - f"{res['total_frames']} frames ({res['elapsed']:.1f}s) " - f"-> {res['category']}" - ) - return 
res - except Exception as exc: - print(f"[{task_id}] ERROR: {exc}") - traceback.print_exc() - log_task_result(args.progress_file, task_id, "", "", 0, f"failed: {str(exc)[:80]}") - return None - - if args.workers > 1: - print( - f"Running with {args.workers} parallel workers " - f"({img_workers} image threads per worker)\n" - ) - with ThreadPoolExecutor(max_workers=args.workers) as pool: - futures = { - pool.submit(_process_one, tid, is_fl, idx): tid - for idx, (tid, is_fl) in enumerate(task_entries, start=1) - } - for future in as_completed(futures): - tid = futures[future] - try: - res = future.result() - if res: - total_episodes += res["episodes"] - results.append(res) - else: - failed.append(tid) - except Exception as exc: - print(f" {tid}: UNEXPECTED ERROR — {exc}") - traceback.print_exc() - failed.append(tid) - else: - for idx, (task_id, is_flagship) in enumerate(task_entries, start=1): - res = _process_one(task_id, is_flagship, idx) - if res: - total_episodes += res["episodes"] - results.append(res) - else: - failed.append(task_id) - - total_time = time.perf_counter() - start_time - - # Summary - print() - print("=" * 60) - print(" SUMMARY") - print("=" * 60) - print( - f" Tasks: {len(task_entries)} total, " - f"{len(task_entries) - len(failed)} ok, {len(failed)} failed" - ) - print(f" Episodes: {total_episodes}") - print(f" Time: {total_time:.1f}s ({total_time / 60:.1f}m)") - if args.upload_s3: - n_uploaded = sum(1 for r in results if r.get("s3_path")) - print(f" S3: {n_uploaded} uploaded") - if args.register_sql: - n_registered = sum(1 for r in results if r.get("sql_registered")) - print(f" SQL: {n_registered} registered") - print(f" Progress: {args.progress_file}") - if failed: - print(f" Failed: {failed[:10]}{'...' if len(failed) > 10 else ''}") - print("=" * 60) - print() - - return 1 if failed else 0 - - -if __name__ == "__main__": - raise SystemExit(main())