diff --git a/egomimic/algo/hpt.py b/egomimic/algo/hpt.py index 252a9b28..85b081da 100644 --- a/egomimic/algo/hpt.py +++ b/egomimic/algo/hpt.py @@ -970,13 +970,13 @@ def process_batch_for_training(self, batch): embodiment: torch.Size([]) """ processed_batch = {} - for embodiment_name, _batch in batch.items(): embodiment_id = get_embodiment_id(embodiment_name) processed_batch[embodiment_id] = {} for key, value in _batch.items(): + key_name = self.data_schematic.zarr_key_to_keyname(key, embodiment_id) if key is not None: - processed_batch[embodiment_id][key] = value + processed_batch[embodiment_id][key_name] = value ac_key = self.ac_keys[embodiment_id] if len(processed_batch[embodiment_id][ac_key].shape) != 3: @@ -987,11 +987,12 @@ def process_batch_for_training(self, batch): processed_batch[embodiment_id]["pad_mask"] = torch.ones( B, S, 1, device=device ) + processed_batch[embodiment_id] = self.data_schematic.normalize_data( processed_batch[embodiment_id], embodiment_id ) processed_batch[embodiment_id]["embodiment"] = torch.tensor( - [embodiment_id], device=self.device, dtype=torch.int64 + [embodiment_id], device=self.device, dtype=torch.int64 ) return processed_batch @@ -1010,7 +1011,10 @@ def forward_training(self, batch): predictions = OrderedDict() hpt_batches = {} self.training_step += 1 - for embodiment_id, _batch in batch.items(): # TODO why don't we use batch with embodiment_name to keep things consistent + for ( + embodiment_id, + _batch, + ) in batch.items(): embodiment_name = get_embodiment(embodiment_id).lower() cam_keys = self.camera_keys[embodiment_id] proprio_keys = self.proprio_keys[embodiment_id] @@ -1251,7 +1255,7 @@ def visualize_preds(self, predictions, batch): Returns: ims (np.ndarray): (B, H, W, 3) - images with actions drawn on top """ - + embodiment_id = batch["embodiment"][0].item() embodiment_name = get_embodiment(embodiment_id).lower() ac_key = self.ac_keys[embodiment_id] diff --git a/egomimic/hydra_configs/data/test_multi_zarr.yaml b/egomimic/hydra_configs/data/aria_bc_zarr.yaml similarity index 92% rename from egomimic/hydra_configs/data/test_multi_zarr.yaml rename to egomimic/hydra_configs/data/aria_bc_zarr.yaml index b98bc752..2f2bf186 100644 --- a/egomimic/hydra_configs/data/test_multi_zarr.yaml +++ b/egomimic/hydra_configs/data/aria_bc_zarr.yaml @@ -6,7 +6,7 @@ train_datasets: _target_: egomimic.rldb.zarr.zarr_dataset_multi.LocalEpisodeResolver folder_path: /nethome/paphiwetsa3/flash/datasets/proc_zarr key_map: - front_img_1: #batch key + observations.images.front_img_1: #batch key key_type: camera_keys # key type zarr_key: front_img_1 actions_cartesian: @@ -24,7 +24,7 @@ valid_datasets: _target_: egomimic.rldb.zarr.zarr_dataset_multi.LocalEpisodeResolver folder_path: /nethome/paphiwetsa3/flash/datasets/proc_zarr key_map: - front_img_1: #batch key + observations.images.front_img_1: #batch key key_type: camera_keys # key type zarr_key: front_img_1 actions_cartesian: diff --git a/egomimic/hydra_configs/data/eva_bc_zarr.yaml b/egomimic/hydra_configs/data/eva_bc_zarr.yaml new file mode 100644 index 00000000..3f858d4a --- /dev/null +++ b/egomimic/hydra_configs/data/eva_bc_zarr.yaml @@ -0,0 +1,107 @@ +_target_: egomimic.pl_utils.pl_data_utils.MultiDataModuleWrapper +train_datasets: + eva_bimanual: + _target_: egomimic.rldb.zarr.zarr_dataset_multi.MultiDataset + datasets: + single_episode: + _target_: egomimic.rldb.zarr.zarr_dataset_multi.ZarrDataset + Episode_path: /coc/flash7/scratch/egoverseDebugDatasets/eva/1767495035712.zarr + key_map: + observations.images.front_img_1: + key_type: camera_keys + zarr_key: images.front_1 + observations.images.right_wrist_img: + key_type: camera_keys + zarr_key: images.right_wrist + observations.images.left_wrist_img: + key_type: camera_keys + zarr_key: images.left_wrist + right.obs_ee_pose: + key_type: proprio_keys + zarr_key: right.obs_ee_pose + right.obs_gripper: + key_type: proprio_keys + zarr_key: right.gripper + left.obs_ee_pose: + key_type: proprio_keys + zarr_key: left.obs_ee_pose + left.obs_gripper: + key_type: proprio_keys + zarr_key: left.gripper + right.gripper: + key_type: action_keys + zarr_key: right.gripper + horizon: 45 + left.gripper: + key_type: action_keys + zarr_key: left.gripper + horizon: 45 + right.cmd_ee_pose: + key_type: action_keys + zarr_key: right.cmd_ee_pose + horizon: 45 + left.cmd_ee_pose: + key_type: action_keys + zarr_key: left.cmd_ee_pose + horizon: 45 + transform_list: + _target_: egomimic.rldb.zarr.action_chunk_transforms.build_eva_bimanual_transform_list + mode: total + +valid_datasets: + eva_bimanual: + _target_: egomimic.rldb.zarr.zarr_dataset_multi.MultiDataset + datasets: + single_episode: + _target_: egomimic.rldb.zarr.zarr_dataset_multi.ZarrDataset + Episode_path: /coc/flash7/scratch/egoverseDebugDatasets/eva/1767495035712.zarr + key_map: + observations.images.front_img_1 : + key_type: camera_keys + zarr_key: images.front_1 + observations.images.right_wrist_img: + key_type: camera_keys + zarr_key: images.right_wrist + observations.images.left_wrist_img: + key_type: camera_keys + zarr_key: images.left_wrist + right.obs_ee_pose: + key_type: proprio_keys + zarr_key: right.obs_ee_pose + right.obs_gripper: + key_type: proprio_keys + zarr_key: right.gripper + left.obs_ee_pose: + key_type: proprio_keys + zarr_key: left.obs_ee_pose + left.obs_gripper: + key_type: proprio_keys + zarr_key: left.gripper + right.gripper: + key_type: action_keys + zarr_key: right.gripper + horizon: 45 + left.gripper: + key_type: action_keys + zarr_key: left.gripper + horizon: 45 + right.cmd_ee_pose: + key_type: action_keys + zarr_key: right.cmd_ee_pose + horizon: 45 + left.cmd_ee_pose: + key_type: action_keys + zarr_key: left.cmd_ee_pose + horizon: 45 + transform_list: + _target_: egomimic.rldb.zarr.action_chunk_transforms.build_eva_bimanual_transform_list + mode: total + +train_dataloader_params: + eva_bimanual: + batch_size: 32 + num_workers: 10 +valid_dataloader_params: + eva_bimanual: + batch_size: 32 + num_workers: 10 diff --git a/egomimic/hydra_configs/data/zarr_test.yaml b/egomimic/hydra_configs/data/zarr_test.yaml deleted file mode 100644 index 92eb23c1..00000000 --- a/egomimic/hydra_configs/data/zarr_test.yaml +++ /dev/null @@ -1,60 +0,0 @@ -_target_: egomimic.pl_utils.pl_data_utils.MultiDataModuleWrapper - -train_datasets: - scale_bimanual: - _target_: egomimic.rldb.zarr.ZarrDataset - Episode_path: external/scale/scripts/datasets/2026-02-19-03-21-23-570038/697c1e6c0cac8cd3c4873844_episode_000000.zarr - key_map: - front_img_1: - key_type: camera_keys - zarr_key: observations.images.front_img_1 - ee_pose: - key_type: proprio_keys - zarr_key: observations.state.ee_pose - horizon: 100 - actions_cartesian: - key_type: action_keys - zarr_key: actions_ee_se3_world - horizon: 100 - actions_keypoints: - key_type: action_keys - zarr_key: actions_keypoint_world - horizon: 100 - actions_head_cartesian: - key_type: action_keys - zarr_key: actions_head_se3_world - horizon: 100 -valid_datasets: - scale_bimanual: - _target_: egomimic.rldb.zarr.ZarrDataset - Episode_path: external/scale/scripts/datasets/2026-02-19-03-21-23-570038/697c1e6c0cac8cd3c4873844_episode_000000.zarr - key_map: - front_img_1: - key_type: camera_keys - zarr_key: observations.images.front_img_1 - ee_pose: - key_type: proprio_keys - zarr_key: observations.state.ee_pose - horizon: 100 - actions_cartesian: - key_type: action_keys - zarr_key: actions_ee_se3_world - horizon: 100 - actions_keypoints: - key_type: action_keys - zarr_key: actions_keypoint_world - horizon: 100 - actions_head_cartesian: - key_type: action_keys - zarr_key: actions_head_se3_world - horizon: 100 - -train_dataloader_params: - dataset1: - batch_size: 32 - num_workers: 10 - -valid_dataloader_params: - dataset1: - batch_size: 32 - num_workers: 10 \ No newline at end of file diff --git a/egomimic/hydra_configs/model/hpt_bc_flow_eva.yaml b/egomimic/hydra_configs/model/hpt_bc_flow_eva.yaml index 1cdf8020..1e12d7e7 100644 --- a/egomimic/hydra_configs/model/hpt_bc_flow_eva.yaml +++ b/egomimic/hydra_configs/model/hpt_bc_flow_eva.yaml @@ -81,7 +81,7 @@ robomimic_model: crossattn_dim_head: 64 crossattn_modality_dropout: 0.1 modality_embed_dim: 256 - state_joint_positions: + state_ee_pose: _target_: egomimic.models.hpt_nets.MLPPolicyStem input_dim: 14 output_dim: 256 diff --git a/egomimic/hydra_configs/train_zarr.yaml b/egomimic/hydra_configs/train_zarr.yaml index a3487959..06c530a1 100644 --- a/egomimic/hydra_configs/train_zarr.yaml +++ b/egomimic/hydra_configs/train_zarr.yaml @@ -1,10 +1,10 @@ defaults: - - model: hpt_bc_flow_aria + - model: hpt_bc_flow_eva - paths: default - - trainer: ddp + - trainer: debug - debug: null - - logger: wandb - - data: test_multi_zarr + - logger: debug + - data: eva_bc_zarr.yaml - callbacks: checkpoints - override hydra/launcher: submitit - _self_ @@ -31,3 +31,88 @@ hydra: launch_params: gpus_per_node: 1 nodes: 1 + + +data_schematic: # Dynamically fill in these shapes from the dataset + _target_: egomimic.rldb.zarr.utils.DataSchematic + norm_mode: quantile + schematic_dict: + eva_bimanual: + front_img_1: #batch key + key_type: camera_keys # key type + zarr_key: observations.images.front_img_1 # dataset key + right_wrist_img: + key_type: camera_keys + zarr_key: observations.images.right_wrist_img + left_wrist_img: + key_type: camera_keys + zarr_key: observations.images.left_wrist_img + ee_pose: + key_type: proprio_keys + zarr_key: observations.state.ee_pose + joint_positions: + key_type: proprio_keys + zarr_key: observations.state.joint_positions + actions_joints: + key_type: action_keys + zarr_key: actions_joints + actions_cartesian: + key_type: action_keys + zarr_key: actions_cartesian + embodiment: + key_type: metadata_keys + zarr_key: metadata.embodiment + aria_bimanual: + front_img_1: + key_type: camera_keys + zarr_key: observations.images.front_img_1 + ee_pose: + key_type: proprio_keys + zarr_key: observations.state.ee_pose + actions_cartesian: + key_type: action_keys + zarr_key: actions_cartesian + embodiment: + key_type: metadata_keys + zarr_key: metadata.embodiment + mecka_bimanual: + front_img_1: + key_type: camera_keys + zarr_key: observations.images.front_img_1 + ee_pose: + key_type: proprio_keys + zarr_key: observations.state.ee_pose_cam + actions_cartesian: + key_type: action_keys + zarr_key: actions_ee_cartesian_cam + actions_keypoints: + key_type: action_keys + zarr_key: actions_ee_keypoints_world + actions_head_cartesian: + key_type: action_keys + zarr_key: actions_head_cartesian_world + embodiment: + key_type: metadata_keys + zarr_key: metadata.embodiment + scale_bimanual: + front_img_1: + key_type: camera_keys + zarr_key: observations.images.front_img_1 + ee_pose: + key_type: proprio_keys + zarr_key: observations.state.ee_pose + actions_cartesian: + key_type: action_keys + zarr_key: actions_cartesian + embodiment: + key_type: metadata_keys + zarr_key: metadata.embodiment + viz_img_key: + eva_bimanual: + front_img_1 + aria_bimanual: + front_img_1 + mecka_bimanual: + front_img_1 + scale_bimanual: + front_img_1 \ No newline at end of file diff --git a/egomimic/rldb/zarr/utils.py b/egomimic/rldb/zarr/utils.py index da9eaa97..a27a2ad4 100644 --- a/egomimic/rldb/zarr/utils.py +++ b/egomimic/rldb/zarr/utils.py @@ -7,6 +7,27 @@ logger = logging.getLogger(__name__) +import random +import math +import os + +from egomimic.rldb.zarr.zarr_dataset_multi import ZarrDataset, MultiDataset + + +def set_global_seed(seed: int = 42): + + random.seed(seed) # Python RNG + np.random.seed(seed) # NumPy RNG + torch.manual_seed(seed) # PyTorch CPU + torch.cuda.manual_seed(seed) # PyTorch GPU + torch.cuda.manual_seed_all(seed) + + os.environ["PYTHONHASHSEED"] = str(seed) + + +set_global_seed(42) + + class DataSchematic(object): def __init__(self, schematic_dict, viz_img_key, norm_mode="zscore"): """ @@ -77,8 +98,7 @@ def zarr_key_to_keyname(self, zarr_key, embodiment): str: Key name, e.g., "front_img_1". """ df_filtered = self.df[ - (self.df["zarr_key"] == zarr_key) - & (self.df["embodiment"] == embodiment) + (self.df["zarr_key"] == zarr_key) & (self.df["embodiment"] == embodiment) ] if df_filtered.empty: @@ -127,74 +147,127 @@ def infer_shapes_from_batch(self, batch): shape = (1,) else: shape = None - if key in self.df["key_name"].values: - self.df.loc[self.df["key_name"] == key, "shape"] = str(shape) + if key in self.df["zarr_key"].values: + self.df.loc[self.df["zarr_key"] == key, "shape"] = str(shape) self.shapes_infered = True - def infer_norm_from_dataset_zarr(self, dataset, dataset_name): - """ - dataset: huggingface dataset or zarr dataset - returns: dictionary of means and stds for proprio and action keys - """ - norm_columns = [] - - embodiment = dataset_name # TODO may need to clean this up to make the code nicer + def infer_norm_from_dataset( + self, + dataset, + dataset_name, + sample_frac: float = 0.10, + seed: int = 42, + max_samples: int | None = None, + log_every: int = 200, + include_all_key_map_keys: bool = False, + extra_aux_keys=(), + ): + embodiment = dataset_name if isinstance(embodiment, str): embodiment = get_embodiment_id(embodiment) - norm_columns.extend(self.keys_of_type("proprio_keys")) - norm_columns.extend(self.keys_of_type("action_keys")) + norm_keys = [] + norm_keys.extend(self.keys_of_type("proprio_keys")) + norm_keys.extend(self.keys_of_type("action_keys")) + norm_keys = [k for k in norm_keys if self.is_key_with_embodiment(k, embodiment)] + + if not norm_keys: + logger.warning( + f"[NormStats] No proprio/action keys for embodiment={embodiment}" + ) + return + + aux_keys = self.dataset_raw_norm_keys( + dataset, + key_types=("proprio_keys", "action_keys"), + extra_keys=extra_aux_keys, + include_all_key_map_keys=include_all_key_map_keys, + ) + + if not aux_keys: + logger.warning(f"[NormStats] No aux keys found for dataset={type(dataset)}") + return + def get_item_keys_from_any(ds, idx, keys): + if hasattr(ds, "get_item_keys"): + return ds.get_item_keys(idx, keys=keys) + + if hasattr(ds, "datasets") and hasattr(ds, "index_map"): + dataset_name_, local_idx = ds.index_map[idx] + child = ds.datasets[dataset_name_] + if not hasattr(child, "get_item_keys"): + raise RuntimeError( + f"Child dataset {type(child)} has no get_item_keys" + ) + return child.get_item_keys(local_idx, keys=keys) + + raise RuntimeError(f"Unsupported dataset type: {type(ds)}") + + N = len(dataset) + if N <= 0: + raise ValueError("Dataset is empty") + + n_samples = int(math.ceil(sample_frac * N)) + n_samples = max(1, min(n_samples, N)) + if max_samples is not None: + n_samples = min(n_samples, max_samples) + + rng = random.Random(seed) + sample_indices = rng.sample(range(N), k=n_samples) + + logger.info(f"[NormStats] embodiment={embodiment} norm_keys={norm_keys}") + logger.info(f"[NormStats] aux_keys={aux_keys}") logger.info( - f"[NormStats] Starting norm inference for embodiment={embodiment}, " - f"{len(norm_columns)} columns" + f"[NormStats] sampling {n_samples}/{N} (~{100 * sample_frac:.1f}%) indices" ) - def get_zarr_data(ds, col): - if hasattr(ds, "episode_reader"): - # ZarrDataset - if col in ds.episode_reader._store: - return ds.episode_reader._store[col][:] - return None - elif hasattr(ds, "datasets"): - # MultiDataset wrapper - data_list = [] - for d in ds.datasets.values(): - res = get_zarr_data(d, col) - if res is not None: - data_list.append(res) - if data_list: - return np.concatenate(data_list, axis=0) - return None + collected = {k: [] for k in norm_keys} + expected_shapes = {} - for column in norm_columns: - if not self.is_key_with_embodiment(column, embodiment): - continue - column_name = self.keyname_to_zarr_key(column, embodiment) # zarr key for retrieval - logger.info(f"[NormStats] Processing column={column_name}") + for i, idx in enumerate(sample_indices, 1): + item = get_item_keys_from_any(dataset, idx, keys=aux_keys) + for k in norm_keys: + item_key = self.keyname_to_zarr_key(k, embodiment) + if item_key not in item: + continue - column_data = get_zarr_data(dataset, column_name) + x = item[item_key] + if isinstance(x, torch.Tensor): + x = x.detach().cpu().numpy() + x = np.asarray(x) - if column_data is None: - logger.warning(f"Skipping {column_name}, data not found given dataset type") + if k not in expected_shapes: + expected_shapes[k] = x.shape + else: + if x.shape != expected_shapes[k]: + raise ValueError( + f"[NormStats] Shape mismatch for key '{k}': " + f"expected {expected_shapes[k]}, got {x.shape}. " + "Ensure padding/horizon is consistent." + ) + + collected[k].append(x) + + if log_every and (i % log_every == 0): + logger.info(f"[NormStats] processed {i}/{n_samples} samples") + + for k in norm_keys: + if not collected[k]: + logger.warning(f"[NormStats] No data collected for key={k}") continue - if column_data.ndim not in (2, 3): - raise ValueError( - f"Column {column} has shape {column_data.shape}, " - "expected 2 or 3 dims" - ) + X = np.stack(collected[k], axis=0) - mean = np.mean(column_data, axis=0) - std = np.std(column_data, axis=0) - minv = np.min(column_data, axis=0) - maxv = np.max(column_data, axis=0) - median = np.median(column_data, axis=0) - q1 = np.percentile(column_data, 1, axis=0) - q99 = np.percentile(column_data, 99, axis=0) + mean = np.mean(X, axis=0) + std = np.std(X, axis=0) + minv = np.min(X, axis=0) + maxv = np.max(X, axis=0) + median = np.median(X, axis=0) + q1 = np.percentile(X, 1, axis=0) + q99 = np.percentile(X, 99, axis=0) - self.norm_stats[embodiment][column] = { + self.norm_stats[embodiment][k] = { "mean": torch.from_numpy(mean).float(), "std": torch.from_numpy(std).float(), "min": torch.from_numpy(minv).float(), @@ -204,6 +277,10 @@ def get_zarr_data(ds, col): "quantile_99": torch.from_numpy(q99).float(), } + logger.info( + f"[NormStats] key={k} samples={X.shape[0]} stat_shape={mean.shape}" + ) + logger.info("[NormStats] Finished norm inference") def viz_img_key(self): @@ -393,4 +470,38 @@ def unnormalize_data(self, data, embodiment): else: denorm_data[key] = tensor - return denorm_data \ No newline at end of file + return denorm_data + + @staticmethod + def _iter_leaf_datasets(ds): + + if isinstance(ds, ZarrDataset): + yield ds + elif isinstance(ds, MultiDataset): + for child in ds.datasets.values(): + yield from DataSchematic._iter_leaf_datasets(child) + else: + yield ds + + @staticmethod + def _key_map_for_any(ds) -> dict: + km = getattr(ds, "key_map", None) + return km + + @staticmethod + def dataset_raw_norm_keys( + ds, + key_types=("proprio_keys", "action_keys"), + extra_keys=(), + include_all_key_map_keys=False, + ) -> list[str]: + out = set(extra_keys) + for leaf in DataSchematic._iter_leaf_datasets(ds): + km = DataSchematic._key_map_for_any(leaf) + if include_all_key_map_keys: + out |= set(km.keys()) + else: + for k, info in km.items(): + if info.get("key_type") in set(key_types): + out.add(k) + return sorted(out) diff --git a/egomimic/rldb/zarr/zarr_dataset_multi.py b/egomimic/rldb/zarr/zarr_dataset_multi.py index fab19ef5..26bcef10 100644 --- a/egomimic/rldb/zarr/zarr_dataset_multi.py +++ b/egomimic/rldb/zarr/zarr_dataset_multi.py @@ -32,6 +32,7 @@ import subprocess import tempfile from datasets import concatenate_datasets +from typing import Iterable from enum import Enum import simplejpeg # from action_chunk_transforms import Transform @@ -46,6 +47,7 @@ SEED = 42 + def split_dataset_names(dataset_names, valid_ratio=0.2, seed=SEED): """ Split a list of dataset names into train/valid sets. @@ -75,11 +77,14 @@ def split_dataset_names(dataset_names, valid_ratio=0.2, seed=SEED): valid = set(names[:n_valid]) train = set(names[n_valid:]) return train, valid + + class EpisodeResolver: """ Base class for episode resolution utilities. Provides shared static/class helpers; subclasses implement resolve(). """ + def __init__( self, folder_path: Path, @@ -91,7 +96,6 @@ def __init__( self.transform_list = transform_list def _load_zarr_datasets(self, search_path: Path, valid_folder_names: set[str]): - """ Loads multiple Zarr datasets from the specified folder path, filtering only those whose hashes are present in the valid_folder_names set. @@ -114,18 +118,20 @@ def _load_zarr_datasets(self, search_path: Path, valid_folder_names: set[str]): if name.endswith(".zarr"): name = name[: -len(".zarr")] if name not in valid_folder_names: - logger.info(f"{p} is not in the list of filtered paths") + logger.info(f"{p} is not in the list of filtered paths") skipped.append(p.name) - continue + continue try: - ds_obj = ZarrDataset(p, key_map=self.key_map, transform_list=self.transform_list) + ds_obj = ZarrDataset( + p, key_map=self.key_map, transform_list=self.transform_list + ) datasets[name] = ds_obj except Exception as e: logger.error(f"Failed to load dataset at {p}: {e}") skipped.append(p.name) - + return datasets - + @classmethod def _episode_already_present(cls, local_dir: Path, episode_hash: str) -> bool: direct = local_dir / episode_hash @@ -133,11 +139,11 @@ def _episode_already_present(cls, local_dir: Path, episode_hash: str) -> bool: return True - class S3EpisodeResolver(EpisodeResolver): """ Resolves episodes via SQL table and optionally syncs from S3. """ + def __init__( self, folder_path: Path, @@ -147,7 +153,7 @@ def __init__( transform_list: list | None = None, ): self.bucket_name = bucket_name - self.main_prefix = main_prefix + self.main_prefix = main_prefix super().__init__(folder_path, key_map=key_map, transform_list=transform_list) def resolve( @@ -187,7 +193,7 @@ def resolve( ) return datasets - + @staticmethod def _get_filtered_paths(filters: dict | None = None) -> list[tuple[str, str]]: """ @@ -221,9 +227,10 @@ def _get_filtered_paths(filters: dict | None = None) -> list[tuple[str, str]]: logger.info(f"Paths: {paths}") return paths - @classmethod - def _sync_s3_to_local(cls, bucket_name: str, s3_paths: list[tuple[str, str]], local_dir: Path): + def _sync_s3_to_local( + cls, bucket_name: str, s3_paths: list[tuple[str, str]], local_dir: Path + ): if not s3_paths: return @@ -319,10 +326,12 @@ def sync_from_filters( return filtered_paths + class LocalEpisodeResolver(EpisodeResolver): """ Resolves episodes from local Zarr stores, filtering via local metadata. """ + def __init__( self, folder_path: Path, @@ -392,7 +401,9 @@ def resolve( Outputs a dict of ZarrDatasets with relevant filters from local data. """ if sync_from_s3: - logger.warning("LocalEpisodeResolver does not sync from S3; ignoring sync_from_s3=True.") + logger.warning( + "LocalEpisodeResolver does not sync from S3; ignoring sync_from_s3=True." + ) filters = dict(filters) if filters is not None else {} filters.setdefault("is_deleted", False) @@ -407,25 +418,26 @@ def resolve( ) datasets = self._load_zarr_datasets( - search_path=self.folder_path, - valid_folder_names=valid_folder_names + search_path=self.folder_path, valid_folder_names=valid_folder_names ) return datasets - class MultiDataset(torch.utils.data.Dataset): """ - Self wrapping MultiDataset, can wrap zarr or multi dataset. + Self wrapping MultiDataset, can wrap zarr or multi dataset. """ - def __init__(self, + + def __init__( + self, datasets, mode="train", percent=0.1, valid_ratio=0.2, - **kwargs,): + **kwargs, + ): """ Args: datasets (dict): Dictionary mapping unique dataset hashes (str) to dataset objects. Datasets can be individual Zarr datasets or other multi-datasets; mixing different types is supported. @@ -469,7 +481,6 @@ def __init__(self, def __len__(self) -> int: return len(self.index_map) - def __getitem__(self, idx): dataset_name, local_idx = self.index_map[idx] data = self.datasets[dataset_name][local_idx] @@ -477,9 +488,9 @@ def __getitem__(self, idx): robot_name = self.datasets[dataset_name].embodiment data["metadata.robot_name"] = robot_name data["embodiment"] = robot_name - + return data - + @classmethod def _from_resolver(cls, resolver: EpisodeResolver, **kwargs): """ @@ -507,7 +518,6 @@ def _from_resolver(cls, resolver: EpisodeResolver, **kwargs): else: resolved = resolver.resolve(filters=filters) - return cls(datasets=resolved, **kwargs) @@ -531,10 +541,10 @@ def __init__( self.episode_path = Episode_path self.metadata = None self._image_keys = None # Lazy-loaded set of JPEG-encoded keys - self._json_keys = None # Lazy-loaded set of JSON-encoded keys + self._json_keys = None # Lazy-loaded set of JSON-encoded keys self._annotations = None self.init_episode() - + self.key_map = key_map self.transform = transform_list super().__init__() @@ -571,10 +581,7 @@ def _detect_json_keys(self) -> set[str]: Set of keys containing JSON payloads. """ features = self.metadata.get("features", {}) - return { - key for key, info in features.items() - if info.get("dtype") == "json" - } + return {key for key, info in features.items() if info.get("dtype") == "json"} @staticmethod def _decode_json_entry(value): @@ -618,7 +625,7 @@ def _annotation_text_for_frame(self, frame_idx: int) -> str: return str(ann.get("text", "")) return "" - def __len__(self) -> int: + def __len__(self) -> int: return self.total_frames def _pad_sequences(self, data, horizon: int | None) -> dict: @@ -656,7 +663,7 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: read_interval = (idx, None) read_dict = {zarr_key: read_interval} raw_data = self.episode_reader.read(read_dict) - self._pad_sequences(raw_data, horizon) # should be able to pad images + self._pad_sequences(raw_data, horizon) # should be able to pad images data[k] = raw_data[zarr_key] # Decode JPEG-encoded image data and normalize to [0, 1] @@ -664,7 +671,7 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: if zarr_key in self._image_keys: jpeg_bytes = data[k] # Decode JPEG bytes to numpy array (H, W, 3) - decoded = simplejpeg.decode_jpeg(jpeg_bytes, colorspace='RGB') + decoded = simplejpeg.decode_jpeg(jpeg_bytes, colorspace="RGB") # data[k] = torch.from_numpy(np.transpose(decoded, (2, 0, 1))).to(torch.float32) / 255.0 data[k] = np.transpose(decoded, (2, 0, 1)) / 255.0 elif zarr_key in self._json_keys: @@ -672,7 +679,7 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: data[k] = [self._decode_json_entry(v) for v in data[k]] else: data[k] = self._decode_json_entry(data[k]) - + # Convert all numpy arrays in data to torch tensors # TODO add the transform list code here @@ -686,7 +693,70 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: return data + def get_item_keys(self, idx: int, keys) -> dict[str, torch.Tensor]: + requested = self._normalize_keys_arg(keys) + out = {} + for k in requested: + if k not in self.key_map: + raise KeyError( + f"Unknown key '{k}'. Available keys: {list(self.key_map.keys())}" + ) + + zarr_key = self.key_map[k]["zarr_key"] + horizon = self.key_map[k].get("horizon", None) + + if horizon is not None: + end_idx = min(idx + horizon, self.total_frames) + interval = (idx, end_idx) + else: + interval = (idx, None) + + raw = self.episode_reader.read({zarr_key: interval}) + self._pad_sequences(raw, horizon) + val = raw[zarr_key] + + if zarr_key in self._image_keys: + if ( + isinstance(val, np.ndarray) + and val.dtype == object + and val.ndim == 1 + ): + decoded_seq = [] + for jpeg_bytes in val: + img = simplejpeg.decode_jpeg(jpeg_bytes, colorspace="RGB") + decoded_seq.append(np.transpose(img, (2, 0, 1)) / 255.0) + val = np.stack(decoded_seq, axis=0) + else: + img = simplejpeg.decode_jpeg(val, colorspace="RGB") + val = np.transpose(img, (2, 0, 1)) / 255.0 + + out[k] = val + + if self.transform: + for transform in self.transform or []: + out = transform.transform(out) + + for k, v in out.items(): + if isinstance(v, np.ndarray): + out[k] = torch.from_numpy(v).to(torch.float32) + + return out + + def _normalize_keys_arg(self, keys): + """ + Normalize keys argument: + None -> all dataset keys + str -> single key + Iterable[str] -> list of keys + """ + if keys is None: + return list(self.key_map.keys()) + if isinstance(keys, str): + return [keys] + if isinstance(keys, Iterable): + return list(keys) + raise TypeError(f"keys must be None, str, or iterable[str], got {type(keys)}") class ZarrEpisode: @@ -694,12 +764,14 @@ class ZarrEpisode: Lightweight wrapper around a single Zarr episode store. Designed for efficient PyTorch DataLoader usage with direct store access. """ + __slots__ = ( "_path", "_store", "metadata", "keys", ) + def __init__(self, path: str | Path): """ Initialize ZarrEpisode wrapper. @@ -707,11 +779,13 @@ def __init__(self, path: str | Path): path: Path to the .zarr episode directory """ self._path = Path(path) - self._store = zarr.open_group(str(self._path), mode='r') + self._store = zarr.open_group(str(self._path), mode="r") self.metadata = dict(self._store.attrs) self.keys = self.metadata["features"] - - def read(self, keys_with_ranges: dict[str, tuple[int, int | None]]) -> dict[str, np.ndarray]: + + def read( + self, keys_with_ranges: dict[str, tuple[int, int | None]] + ) -> dict[str, np.ndarray]: """ Read data for specified keys, each with their own index or range. Args: @@ -735,7 +809,7 @@ def read(self, keys_with_ranges: dict[str, tuple[int, int | None]]) -> dict[str, else: # Single frame read - use slicing to avoid 0D array issues with VariableLengthBytes # arr[start:start+1] gives us a 1D array, then [0] extracts the actual object - data = arr[start:start+1][0] + data = arr[start : start + 1][0] result[key] = data return result @@ -748,21 +822,25 @@ def _collect_keys(self) -> list[str]: if isinstance(self.keys, dict): return list(self.keys.keys()) return list(self.keys) + def __len__(self) -> int: """ Get total number of frames in the episode. Returns: Number of frames """ - return self.metadata['total_frames'] + return self.metadata["total_frames"] + def __repr__(self) -> str: """String representation of the episode.""" return f"ZarrEpisode(path={self._path}, frames={len(self)})" -if __name__ == '__main__': + +if __name__ == "__main__": from omegaconf import OmegaConf import hydra - dataset_cfg_path = '/nethome/paphiwetsa3/flash/projects/EgoVerse/egomimic/hydra_configs/data/test_multi_zarr.yaml' + + dataset_cfg_path = "/nethome/paphiwetsa3/flash/projects/EgoVerse/egomimic/hydra_configs/data/test_multi_zarr.yaml" # Using Hydra to load the dataset config dataset_cfg = OmegaConf.load(dataset_cfg_path) datamodule = hydra.utils.instantiate(dataset_cfg) diff --git a/egomimic/trainHydra.py b/egomimic/trainHydra.py index d28ebfcb..4e6ab79a 100644 --- a/egomimic/trainHydra.py +++ b/egomimic/trainHydra.py @@ -26,50 +26,6 @@ import os -# DEBUG -# os.environ["HYDRA_FULL_ERROR"] = '1' - -def create_data_schematic(zarr_data_cfg: DictConfig) -> DataSchematic: - schematic_dict = {} - - def populate_key_map(cfg, target_key="key_map", key_map={}): - """ - Populate key_map with the key_map configuration. - """ - if isinstance(cfg, DictConfig): - if target_key in cfg: - for k, v in cfg[target_key].items(): - key_map[k] = v - return - - for k in cfg.keys(): - v = cfg.get(k) - populate_key_map(v, target_key, key_map) - - elif isinstance(cfg, ListConfig): - for i, v in enumerate(cfg): - populate_key_map(v, target_key, key_map) - - for dataset_name in zarr_data_cfg.train_datasets: - dataset_cfg = zarr_data_cfg.train_datasets[dataset_name] - dataset_key_map = {} - populate_key_map(dataset_cfg, "key_map", dataset_key_map) - schematic_dict[dataset_name] = { - key: { - "key_type": value["key_type"], - "zarr_key": value["zarr_key"], - } - for key, value in dataset_key_map.items() - } - - viz_img_key = { - "eva_bimanual": "front_img_1", - "aria_bimanual": "front_img_1", - "mecka_bimanual": "front_img_1", - "scale_bimanual": "front_img_1", - } # TODO: figure out where to put viz keys - return DataSchematic(schematic_dict, viz_img_key, norm_mode="quantile") - @task_wrapper def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: @@ -87,8 +43,8 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: L.seed_everything(cfg.seed, workers=True) # log.info(f"Instantiating data schematic <{cfg.data_schematic._target_}>") - - data_schematic: DataSchematic = create_data_schematic(cfg.data) + + data_schematic: DataSchematic = hydra.utils.instantiate(cfg.data_schematic) # Modify dataset configs to include `data_schematic` dynamically at runtime train_datasets = {} @@ -117,7 +73,7 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: for dataset_name, dataset in datamodule.train_datasets.items(): log.info(f"Inferring shapes for dataset <{dataset_name}>") data_schematic.infer_shapes_from_batch(dataset[0]) - data_schematic.infer_norm_from_dataset_zarr(dataset, dataset_name) + data_schematic.infer_norm_from_dataset(dataset, dataset_name) # NOTE: We also pass the data_schematic_dict into the robomimic model's instatiation now that we've initialzied the shapes and norm stats. In theory, upon loading the PL checkpoint, it will remember this, but let's see. log.info(f"Instantiating model <{cfg.model._target_}>") @@ -198,7 +154,9 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: return metric_dict, object_dict -@hydra.main(version_base="1.3", config_path="./hydra_configs", config_name="train.yaml") +@hydra.main( + version_base="1.3", config_path="./hydra_configs", config_name="train_zarr.yaml" +) def main(cfg: DictConfig) -> Optional[float]: """Main entry point for training.