From 93707c5fd0d541945506e9f00919cad3bc0fab36 Mon Sep 17 00:00:00 2001 From: Aniketh Cheluva Date: Mon, 9 Mar 2026 18:04:29 -0400 Subject: [PATCH] added functionality to load precomputed norm stats --- egomimic/rldb/zarr/utils.py | 92 +++++++++++++++++++++++++------------ egomimic/trainHydra.py | 24 ++++++---- 2 files changed, 78 insertions(+), 38 deletions(-) diff --git a/egomimic/rldb/zarr/utils.py b/egomimic/rldb/zarr/utils.py index 43bc8ab6..8d53c210 100644 --- a/egomimic/rldb/zarr/utils.py +++ b/egomimic/rldb/zarr/utils.py @@ -151,6 +151,38 @@ def infer_shapes_from_batch(self, batch): self.shapes_infered = True + def load_norm_from_cache(self, cache_dir: str) -> None: + """ + Load precomputed norm stats from a cache directory (containing norm_stats.json). + Skips infer_norm_from_dataset when using this for future runs. + + Args: + cache_dir: Path to directory containing norm_stats.json and embodiment/key/ stat .pt files. + """ + cache_file = os.path.join(cache_dir, "norm_stats.json") + if not os.path.isfile(cache_file): + raise FileNotFoundError( + f"norm_stats.json not found at {cache_file}. " + "Ensure cache_dir is the directory containing norm_stats.json." + ) + with open(cache_file) as f: + data = json.load(f) + stats = data.get("stats", {}) + for emb_str, keys_dict in stats.items(): + embodiment = int(emb_str) + if embodiment not in self.embodiments: + self.embodiments.add(embodiment) + if embodiment not in self.norm_stats: + self.norm_stats[embodiment] = {} + for key_name, paths in keys_dict.items(): + self.norm_stats[embodiment][key_name] = { + k: torch.load(v, map_location="cpu", weights_only=True) + for k, v in paths.items() + } + logger.info( + f"[NormStats] Loaded precomputed stats for embodiment={embodiment} key={key_name}" + ) + def infer_norm_from_dataset( self, dataset, @@ -159,8 +191,8 @@ def infer_norm_from_dataset( seed: int = 42, max_samples: int | None = None, batch_size: int = 512, - num_workers: int = 10, - benchmark_dir: str | None = None, + num_workers: int = 4, + cache_dir: str | None = None, ): """ Args: @@ -175,12 +207,12 @@ def infer_norm_from_dataset( if isinstance(embodiment, str): embodiment = get_embodiment_id(embodiment) - benchmark_stats = None - if benchmark_dir is not None: - os.makedirs(benchmark_dir, exist_ok=True) - benchmark_file = os.path.join(benchmark_dir, "benchmark.json") - benchmark_stats = {} - benchmark_stats["stats"] = {} + cache_stats = None + if cache_dir is not None: + os.makedirs(cache_dir, exist_ok=True) + cache_file = os.path.join(cache_dir, "norm_stats.json") + cache_stats = {} + cache_stats["stats"] = {} norm_keys = [] norm_keys.extend(self.keys_of_type("proprio_keys", embodiment)) @@ -252,8 +284,8 @@ def infer_norm_from_dataset( loading_end_time = time.time() loading_time = loading_end_time - loading_start_time - if benchmark_stats is not None: - benchmark_stats["loading_time"] = loading_time + if cache_stats is not None: + cache_stats["loading_time"] = loading_time computing_start_time = time.time() for k in norm_keys: @@ -284,22 +316,22 @@ def infer_norm_from_dataset( ).float(), } - if benchmark_stats is not None: + if cache_stats is not None: os.makedirs( - os.path.join(benchmark_dir, str(embodiment), k), exist_ok=True + os.path.join(cache_dir, str(embodiment), k), exist_ok=True ) - mean_path = os.path.join(benchmark_dir, str(embodiment), k, "mean.pt") - std_path = os.path.join(benchmark_dir, str(embodiment), k, "std.pt") - min_path = os.path.join(benchmark_dir, str(embodiment), k, "min.pt") - max_path = os.path.join(benchmark_dir, str(embodiment), k, "max.pt") + mean_path = os.path.join(cache_dir, str(embodiment), k, "mean.pt") + std_path = os.path.join(cache_dir, str(embodiment), k, "std.pt") + min_path = os.path.join(cache_dir, str(embodiment), k, "min.pt") + max_path = os.path.join(cache_dir, str(embodiment), k, "max.pt") median_path = os.path.join( - benchmark_dir, str(embodiment), k, "median.pt" + cache_dir, str(embodiment), k, "median.pt" ) quantile_1_path = os.path.join( - benchmark_dir, str(embodiment), k, "quantile_1.pt" + cache_dir, str(embodiment), k, "quantile_1.pt" ) quantile_99_path = os.path.join( - benchmark_dir, str(embodiment), k, "quantile_99.pt" + cache_dir, str(embodiment), k, "quantile_99.pt" ) torch.save(self.norm_stats[embodiment][k]["mean"], mean_path) torch.save(self.norm_stats[embodiment][k]["std"], std_path) @@ -312,11 +344,11 @@ def infer_norm_from_dataset( torch.save( self.norm_stats[embodiment][k]["quantile_99"], quantile_99_path ) - if benchmark_stats["stats"].get(embodiment, None) is None: - benchmark_stats["stats"][embodiment] = {} - if benchmark_stats["stats"][embodiment].get(k, None) is None: - benchmark_stats["stats"][embodiment][k] = {} - benchmark_stats["stats"][embodiment][k] = { + if cache_stats["stats"].get(embodiment, None) is None: + cache_stats["stats"][embodiment] = {} + if cache_stats["stats"][embodiment].get(k, None) is None: + cache_stats["stats"][embodiment][k] = {} + cache_stats["stats"][embodiment][k] = { "mean": mean_path, "std": std_path, "min": min_path, @@ -332,16 +364,16 @@ def infer_norm_from_dataset( computing_end_time = time.time() computing_time = computing_end_time - computing_start_time - if benchmark_stats is not None: - benchmark_stats["computing_time"] = computing_time - benchmark_stats["frames"] = n_samples + if cache_stats is not None: + cache_stats["computing_time"] = computing_time + cache_stats["frames"] = n_samples logger.info( f"[NormStats] Finished norm inference, loading_time={loading_time:.2f}s, computing_time={computing_time:.2f}s" ) - if benchmark_stats is not None: - with open(benchmark_file, "w") as f: - json.dump(benchmark_stats, f, indent=4) + if cache_stats is not None: + with open(cache_file, "w") as f: + json.dump(cache_stats, f, indent=4) def viz_img_key(self): """ diff --git a/egomimic/trainHydra.py b/egomimic/trainHydra.py index 8e4540e1..2313a593 100644 --- a/egomimic/trainHydra.py +++ b/egomimic/trainHydra.py @@ -71,6 +71,11 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: data_schematic: DataSchematic = hydra.utils.instantiate(cfg.data_schematic) + norm_stats_path = OmegaConf.select(cfg, "norm_stats_path", default=None) + if norm_stats_path: + log.info(f"Loading precomputed norm stats from {norm_stats_path}") + data_schematic.load_norm_from_cache(norm_stats_path) + # Modify dataset configs to include `data_schematic` dynamically at runtime train_datasets = {} for dataset_name in cfg.data.train_datasets: @@ -111,14 +116,17 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: instantiate_copy.resolver.key_map = km norm_dataset = hydra.utils.instantiate(instantiate_copy) - data_schematic.infer_norm_from_dataset( - norm_dataset, - dataset_name, - sample_frac=0.005, - benchmark_dir=os.path.join( - cfg.trainer.default_root_dir, "benchmark_stats.json" - ), - ) + if not norm_stats_path: + data_schematic.infer_norm_from_dataset( + norm_dataset, + dataset_name, + sample_frac=0.005, + cache_dir=cfg.trainer.default_root_dir, + ) + else: + log.info( + f"Skipping norm inference for {dataset_name} (using precomputed stats)" + ) # NOTE: We also pass the data_schematic_dict into the robomimic model's instatiation now that we've initialzied the shapes and norm stats. In theory, upon loading the PL checkpoint, it will remember this, but let's see. log.info(f"Instantiating model <{cfg.model._target_}>")