diff --git a/egomimic/algo/hpt.py b/egomimic/algo/hpt.py index 67fc4eb4..f8813161 100644 --- a/egomimic/algo/hpt.py +++ b/egomimic/algo/hpt.py @@ -971,7 +971,9 @@ def process_batch_for_training(self, batch): """ processed_batch = {} - for embodiment_id, _batch in batch.items(): + + for dataset_name, _batch in batch.items(): + embodiment_id = _batch["robot_name"][0].item() processed_batch[embodiment_id] = {} for key, value in _batch.items(): key_name = self.data_schematic.lerobot_key_to_keyname( @@ -1250,7 +1252,7 @@ def visualize_preds(self, predictions, batch): Returns: ims (np.ndarray): (B, H, W, 3) - images with actions drawn on top """ - embodiment_id = batch["embodiment"][0].item() + embodiment_id = batch["robot_name"][0].item() embodiment_name = get_embodiment(embodiment_id).lower() ac_key = self.ac_keys[embodiment_id] @@ -1430,7 +1432,7 @@ def _robomimic_to_hpt_data( data["is_6dof"] = self.is_6dof data["pad_mask"] = batch["pad_mask"] - data["embodiment"] = batch["embodiment"] + data["embodiment"] = batch["robot_name"] for aux_ac_key in aux_ac_keys: data[aux_ac_key] = batch[aux_ac_key] diff --git a/egomimic/hydra_configs/data/zarr_multi_test.yaml b/egomimic/hydra_configs/data/zarr_multi_test.yaml new file mode 100644 index 00000000..17a76a0b --- /dev/null +++ b/egomimic/hydra_configs/data/zarr_multi_test.yaml @@ -0,0 +1,37 @@ +_target_: egomimic.pl_utils.pl_data_utils.MultiDataModuleWrapper +train_datasets: + dataset1: + _target_: egomimic.rldb.zarr.MultiDataset + datasets: + dataset1: + _target_: egomimic.rldb.zarr.ZarrDataset + Episode_path: /coc/flash7/rco3/EgoVerse/egomimic/rldb/zarr/zarr/bimanual/1769461098408.zarr + action_horizon: 100 + dataset2: + _target_: egomimic.rldb.zarr.ZarrDataset + Episode_path: /coc/flash7/rco3/EgoVerse/egomimic/rldb/zarr/zarr/bimanual/1769461098408.zarr + action_horizon: 100 + mode: train +valid_datasets: + dataset1: + _target_: egomimic.rldb.zarr.MultiDataset + datasets: + dataset1: + _target_: egomimic.rldb.zarr.ZarrDataset + Episode_path: 
/coc/flash7/rco3/EgoVerse/egomimic/rldb/zarr/zarr/bimanual/1769461098408.zarr + action_horizon: 100 + dataset2: + _target_: egomimic.rldb.zarr.ZarrDataset + Episode_path: /coc/flash7/rco3/EgoVerse/egomimic/rldb/zarr/zarr/bimanual/1769461098408.zarr + action_horizon: 100 + mode: valid + +train_dataloader_params: + dataset1: + batch_size: 2 + num_workers: 10 + +valid_dataloader_params: + dataset1: + batch_size: 2 + num_workers: 10 \ No newline at end of file diff --git a/egomimic/hydra_configs/data/zarr_resolver.yaml b/egomimic/hydra_configs/data/zarr_resolver.yaml new file mode 100644 index 00000000..cd343d10 --- /dev/null +++ b/egomimic/hydra_configs/data/zarr_resolver.yaml @@ -0,0 +1,40 @@ +_target_: egomimic.pl_utils.pl_data_utils.MultiDataModuleWrapper +train_datasets: + dataset1: + _target_: egomimic.rldb.zarr.MultiDataset._from_resolver + resolver: + _target_: egomimic.rldb.zarr.zarr_dataset_multi.LocalEpisodeResolver + folder_path: + _target_: pathlib.Path + _args_: [/coc/flash7/rco3/EgoVerse/egomimic/rldb/zarr/zarr/bimanual] + embodiment: "eva_bimanual" + mode: "train" + valid_ratio: 0.2 + sync_from_s3: false + filters: + is_deleted: false + +valid_datasets: + dataset1: + _target_: egomimic.rldb.zarr.MultiDataset._from_resolver + resolver: + _target_: egomimic.rldb.zarr.zarr_dataset_multi.LocalEpisodeResolver + folder_path: + _target_: pathlib.Path + _args_: [/coc/flash7/rco3/EgoVerse/egomimic/rldb/zarr/zarr/bimanual] + embodiment: "eva_bimanual" + mode: "valid" + valid_ratio: 0.2 + sync_from_s3: false + filters: + is_deleted: false + +train_dataloader_params: + dataset1: + batch_size: 2 + num_workers: 10 + +valid_dataloader_params: + dataset1: + batch_size: 2 + num_workers: 10 \ No newline at end of file diff --git a/egomimic/hydra_configs/data/zarr_test.yaml b/egomimic/hydra_configs/data/zarr_test.yaml new file mode 100644 index 00000000..c2854536 --- /dev/null +++ b/egomimic/hydra_configs/data/zarr_test.yaml @@ -0,0 +1,23 @@ +_target_: 
egomimic.pl_utils.pl_data_utils.MultiDataModuleWrapper + +train_datasets: + dataset1: + _target_: egomimic.rldb.zarr.ZarrDataset + Episode_path: /coc/flash7/rco3/EgoVerse/egomimic/rldb/zarr/zarr/bimanual/1769461098408.zarr + action_horizon: 100 + +valid_datasets: + dataset1: + _target_: egomimic.rldb.zarr.ZarrDataset + Episode_path: /coc/flash7/rco3/EgoVerse/egomimic/rldb/zarr/zarr/bimanual/1769461098408.zarr + action_horizon: 100 + +train_dataloader_params: + dataset1: + batch_size: 32 + num_workers: 10 + +valid_dataloader_params: + dataset1: + batch_size: 32 + num_workers: 10 \ No newline at end of file diff --git a/egomimic/hydra_configs/train.yaml b/egomimic/hydra_configs/train.yaml index 828a21b6..70fcc32d 100644 --- a/egomimic/hydra_configs/train.yaml +++ b/egomimic/hydra_configs/train.yaml @@ -61,6 +61,9 @@ data_schematic: # Dynamically fill in these shapes from the dataset embodiment: key_type: metadata_keys lerobot_key: metadata.embodiment + robot_name: + key_type: metadata_keys + lerobot_key: robot_name aria_bimanual: front_img_1: key_type: camera_keys @@ -74,6 +77,9 @@ data_schematic: # Dynamically fill in these shapes from the dataset embodiment: key_type: metadata_keys lerobot_key: metadata.embodiment + robot_name: + key_type: metadata_keys + lerobot_key: robot_name mecka_bimanual: front_img_1: key_type: camera_keys diff --git a/egomimic/pl_utils/pl_data_utils.py b/egomimic/pl_utils/pl_data_utils.py index 21bc708c..eca9a5e7 100644 --- a/egomimic/pl_utils/pl_data_utils.py +++ b/egomimic/pl_utils/pl_data_utils.py @@ -1,4 +1,4 @@ -from torch.utils.data import DataLoader, random_split, default_collate +from torch.utils.data import DataLoader, random_split, default_collate, ConcatDataset from lightning.pytorch.utilities.combined_loader import CombinedLoader from lightning import LightningDataModule from transformers import AutoTokenizer @@ -81,7 +81,8 @@ def train_dataloader(self): iterables = dict() for dataset_name, dataset in self.train_datasets.items(): 
dataset_params = self.train_dataloader_params.get(dataset_name, {}) - iterables[dataset.embodiment] = DataLoader( + iterables[dataset_name] = DataLoader( + dataset, shuffle=True, collate_fn=self.collate_fn, @@ -94,7 +95,8 @@ def val_dataloader(self): iterables = dict() for dataset_name, dataset in self.valid_datasets.items(): dataset_params = self.valid_dataloader_params.get(dataset_name, {}) - iterables[dataset.embodiment] = DataLoader( + iterables[dataset_name] = DataLoader( + dataset, shuffle=False, collate_fn=self.collate_fn, diff --git a/egomimic/rldb/utils.py b/egomimic/rldb/utils.py index 86638451..6e617c48 100644 --- a/egomimic/rldb/utils.py +++ b/egomimic/rldb/utils.py @@ -1154,6 +1154,7 @@ def __init__(self, schematic_dict, viz_img_key, norm_mode="zscore"): "lerobot_key": key_info["lerobot_key"], "shape": None, "embodiment": embodiment_id, + "robot_name": embodiment_id, } ) @@ -1222,7 +1223,7 @@ def infer_shapes_from_batch(self, batch): Updates: The 'shape' column in the DataFrame is updated to match the inferred shapes (stored as tuples). 
""" - embodiment_id = int(batch["metadata.embodiment"]) + embodiment_id = int(batch["metadata.robot_name"]) for key, tensor in batch.items(): if hasattr(tensor, "shape"): shape = tuple(tensor.shape) @@ -1295,71 +1296,109 @@ def infer_norm_from_dataset(self, dataset): returns: dictionary of means and stds for proprio and action keys """ norm_columns = [] - - embodiment = dataset.embodiment - if isinstance(embodiment, str): - embodiment = get_embodiment_id(embodiment) - norm_columns.extend(self.keys_of_type("proprio_keys")) norm_columns.extend(self.keys_of_type("action_keys")) - logger.info( - f"[NormStats] Starting norm inference for embodiment={embodiment}, " - f"{len(norm_columns)} columns" - ) + def _normalize_embodiment(value): + if value is None: + return None + if isinstance(value, str): + return get_embodiment_id(value) + return int(value) + + def _collect_embodiments(ds): + embodiments = set() + + if hasattr(ds, "robot_name"): + emb = _normalize_embodiment(ds.robot_name) + if emb is not None: + embodiments.add(emb) + + if hasattr(ds, "datasets"): + for sub_ds in ds.datasets.values(): + embodiments.update(_collect_embodiments(sub_ds)) - def get_zarr_data(ds, col): + if isinstance(ds, (list, tuple)): + for sub_ds in ds: + embodiments.update(_collect_embodiments(sub_ds)) + + return embodiments + + def get_zarr_data(ds, col, embodiment): if hasattr(ds, "episode_reader"): # ZarrDataset + ds_emb = _normalize_embodiment(getattr(ds, "robot_name", None)) + if ds_emb != embodiment: + return None if col in ds.episode_reader._store: return ds.episode_reader._store[col][:] return None - elif hasattr(ds, "datasets"): + if hasattr(ds, "datasets"): # MultiDataset wrapper data_list = [] for d in ds.datasets.values(): - res = get_zarr_data(d, col) + res = get_zarr_data(d, col, embodiment) + if res is not None: + data_list.append(res) + if data_list: + return np.concatenate(data_list, axis=0) + if isinstance(ds, (list, tuple)): + data_list = [] + for d in ds: + res = 
get_zarr_data(d, col, embodiment) if res is not None: data_list.append(res) if data_list: return np.concatenate(data_list, axis=0) return None - for column in norm_columns: - if not self.is_key_with_embodiment(column, embodiment): - continue - column_name = self.keyname_to_lerobot_key(column, embodiment) - logger.info(f"[NormStats] Processing column={column_name}") + embodiments = _collect_embodiments(dataset) + if not embodiments: + raise ValueError("Could not determine any embodiments from dataset to infer norms.") - column_data = get_zarr_data(dataset, column_name) + for embodiment in sorted(embodiments): + logger.info( + f"[NormStats] Starting norm inference for embodiment={embodiment}, " + f"{len(norm_columns)} columns" + ) - if column_data is None: - logger.warning(f"Skipping {column_name}, data not found given dataset type") - continue + for column in norm_columns: + if not self.is_key_with_embodiment(column, embodiment): + continue + column_name = self.keyname_to_lerobot_key(column, embodiment) + logger.info(f"[NormStats] Processing column={column_name}") - if column_data.ndim not in (2, 3): - raise ValueError( - f"Column {column} has shape {column_data.shape}, " - "expected 2 or 3 dims" - ) + column_data = get_zarr_data(dataset, column_name, embodiment) - mean = np.mean(column_data, axis=0) - std = np.std(column_data, axis=0) - minv = np.min(column_data, axis=0) - maxv = np.max(column_data, axis=0) - median = np.median(column_data, axis=0) - q1 = np.percentile(column_data, 1, axis=0) - q99 = np.percentile(column_data, 99, axis=0) + if column_data is None: + logger.warning( + f"Skipping {column_name}, data not found for embodiment={embodiment}" + ) + continue - self.norm_stats[embodiment][column] = { - "mean": torch.from_numpy(mean).float(), - "std": torch.from_numpy(std).float(), - "min": torch.from_numpy(minv).float(), - "max": torch.from_numpy(maxv).float(), - "median": torch.from_numpy(median).float(), - "quantile_1": torch.from_numpy(q1).float(), - 
"quantile_99": torch.from_numpy(q99).float(), - } + if column_data.ndim not in (2, 3): + raise ValueError( + f"Column {column} has shape {column_data.shape}, " + "expected 2 or 3 dims" + ) + + mean = np.mean(column_data, axis=0) + std = np.std(column_data, axis=0) + minv = np.min(column_data, axis=0) + maxv = np.max(column_data, axis=0) + median = np.median(column_data, axis=0) + q1 = np.percentile(column_data, 1, axis=0) + q99 = np.percentile(column_data, 99, axis=0) + + self.norm_stats[embodiment][column] = { + "mean": torch.from_numpy(mean).float(), + "std": torch.from_numpy(std).float(), + "min": torch.from_numpy(minv).float(), + "max": torch.from_numpy(maxv).float(), + "median": torch.from_numpy(median).float(), + "quantile_1": torch.from_numpy(q1).float(), + "quantile_99": torch.from_numpy(q99).float(), + } logger.info("[NormStats] Finished norm inference") diff --git a/egomimic/rldb/zarr/__init__.py b/egomimic/rldb/zarr/__init__.py index 89ee6670..c23d6f52 100644 --- a/egomimic/rldb/zarr/__init__.py +++ b/egomimic/rldb/zarr/__init__.py @@ -7,8 +7,10 @@ MultiDataset, ZarrDataset, ZarrEpisode, + LocalEpisodeResolver, + S3EpisodeResolver, ) -from egomimic.rldb.zarr.zarr_writer import ZarrWriter +#from egomimic.rldb.zarr.zarr_writer import ZarrWriter __all__ = [ "EpisodeResolver", @@ -16,4 +18,6 @@ "ZarrDataset", "ZarrEpisode", "ZarrWriter", + "LocalEpisodeResolver", + "S3EpisodeResolver", ] diff --git a/egomimic/rldb/zarr/zarr_dataset_multi.py b/egomimic/rldb/zarr/zarr_dataset_multi.py index 5aaa4d50..2873ee65 100644 --- a/egomimic/rldb/zarr/zarr_dataset_multi.py +++ b/egomimic/rldb/zarr/zarr_dataset_multi.py @@ -429,8 +429,8 @@ def resolve( filtered_paths = self._get_local_filtered_paths(self.folder_path, filters) - valid_hashes = {hashes for _, hashes in filtered_paths} - if not valid_hashes: + valid_folder_names = {folder_name for _, folder_name in filtered_paths} + if not valid_folder_names: raise ValueError( "No valid collection names from local filtering: 
" "filters matched no episodes in the local directory." @@ -438,7 +438,7 @@ def resolve( datasets = self._load_zarr_datasets( search_path=self.folder_path, - valid_hashes=valid_hashes, + valid_folder_names=valid_folder_names, action_horizon=action_horizon, ) @@ -453,7 +453,6 @@ class MultiDataset(torch.utils.data.Dataset): """ def __init__(self, datasets, - embodiment, mode="train", percent=0.1, key_map=None, @@ -472,12 +471,6 @@ def __init__(self, self.datasets = datasets self.key_map = key_map - self.embodiment = get_embodiment_id(embodiment) - for dataset_name, dataset in self.datasets.items(): - assert dataset.embodiment == self.embodiment, ( - f"Dataset {dataset_name} has embodiment {dataset.embodiment}, expected {self.embodiment}." - ) - self.index_map = [] for dataset_name, dataset in self.datasets.items(): for local_idx in range(len(dataset)): @@ -526,7 +519,11 @@ def __getitem__(self, idx): if self.key_map and dataset_name in self.key_map: key_map = self.key_map[dataset_name] data = {key_map.get(k, k): v for k, v in data.items()} - + + robot_name = self.datasets[dataset_name].robot_name + data["metadata.robot_name"] = robot_name + data["embodiment"] = robot_name + data["robot_name"] = robot_name return data @classmethod @@ -592,7 +589,7 @@ def __init__( self._image_keys = None # Lazy-loaded set of JPEG-encoded keys self.init_episode() self.action_transform = ( - get_action_chunk_transform(self.embodiment) + get_action_chunk_transform(self.robot_name) if self.chunk_length is not None else None ) @@ -606,7 +603,7 @@ def init_episode(self): self.metadata = self.episode_reader.metadata self.total_frames = self.metadata["total_frames"] self.keys_dict = {k: (0, None) for k in self.episode_reader._collect_keys()} - self.embodiment = int(get_embodiment_id(self.metadata["robot_type"])) + self.robot_name = int(get_embodiment_id(self.metadata["robot_type"])) # Detect JPEG-encoded image keys from metadata self._image_keys = self._detect_image_keys() @@ -683,8 +680,10 
@@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: - # Add embodiment id - data["metadata.embodiment"] = self.embodiment + # Add metadata + data["metadata.robot_name"] = self.robot_name + data["embodiment"] = self.robot_name + data["robot_name"] = self.robot_name return data