Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 62 additions & 30 deletions egomimic/rldb/zarr/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,38 @@ def infer_shapes_from_batch(self, batch):

self.shapes_infered = True

def load_norm_from_cache(self, cache_dir: str) -> None:
    """
    Load precomputed normalization stats from a cache directory.

    Reads ``norm_stats.json`` (written by ``infer_norm_from_dataset`` when a
    ``cache_dir`` is supplied) plus the per-stat ``.pt`` tensor files it
    references, populating ``self.norm_stats`` and ``self.embodiments`` so
    that ``infer_norm_from_dataset`` can be skipped on later runs.

    Args:
        cache_dir: Directory containing ``norm_stats.json`` and the
            ``<embodiment>/<key>/<stat>.pt`` tensor files it points to.

    Raises:
        FileNotFoundError: If ``norm_stats.json`` is not present in
            ``cache_dir``.
    """
    cache_file = os.path.join(cache_dir, "norm_stats.json")
    if not os.path.isfile(cache_file):
        raise FileNotFoundError(
            f"norm_stats.json not found at {cache_file}. "
            "Ensure cache_dir is the directory containing norm_stats.json."
        )
    with open(cache_file, encoding="utf-8") as f:
        data = json.load(f)
    # json.dump stringifies int dict keys, so embodiment ids come back as
    # strings and must be converted back to int.
    for emb_str, keys_dict in data.get("stats", {}).items():
        embodiment = int(emb_str)
        self.embodiments.add(embodiment)  # set.add is a no-op if already present
        emb_stats = self.norm_stats.setdefault(embodiment, {})
        for key_name, paths in keys_dict.items():
            # Each value in `paths` is a filesystem path to a tensor saved by
            # torch.save; weights_only=True avoids arbitrary pickle execution.
            emb_stats[key_name] = {
                stat: torch.load(path, map_location="cpu", weights_only=True)
                for stat, path in paths.items()
            }
            logger.info(
                "[NormStats] Loaded precomputed stats for embodiment=%s key=%s",
                embodiment,
                key_name,
            )

def infer_norm_from_dataset(
self,
dataset,
Expand All @@ -159,8 +191,8 @@ def infer_norm_from_dataset(
seed: int = 42,
max_samples: int | None = None,
batch_size: int = 512,
num_workers: int = 10,
benchmark_dir: str | None = None,
num_workers: int = 4,
cache_dir: str | None = None,
):
"""
Args:
Expand All @@ -175,12 +207,12 @@ def infer_norm_from_dataset(
if isinstance(embodiment, str):
embodiment = get_embodiment_id(embodiment)

benchmark_stats = None
if benchmark_dir is not None:
os.makedirs(benchmark_dir, exist_ok=True)
benchmark_file = os.path.join(benchmark_dir, "benchmark.json")
benchmark_stats = {}
benchmark_stats["stats"] = {}
cache_stats = None
if cache_dir is not None:
os.makedirs(cache_dir, exist_ok=True)
cache_file = os.path.join(cache_dir, "norm_stats.json")
cache_stats = {}
cache_stats["stats"] = {}

norm_keys = []
norm_keys.extend(self.keys_of_type("proprio_keys", embodiment))
Expand Down Expand Up @@ -252,8 +284,8 @@ def infer_norm_from_dataset(

loading_end_time = time.time()
loading_time = loading_end_time - loading_start_time
if benchmark_stats is not None:
benchmark_stats["loading_time"] = loading_time
if cache_stats is not None:
cache_stats["loading_time"] = loading_time

computing_start_time = time.time()
for k in norm_keys:
Expand Down Expand Up @@ -284,22 +316,22 @@ def infer_norm_from_dataset(
).float(),
}

if benchmark_stats is not None:
if cache_stats is not None:
os.makedirs(
os.path.join(benchmark_dir, str(embodiment), k), exist_ok=True
os.path.join(cache_dir, str(embodiment), k), exist_ok=True
)
mean_path = os.path.join(benchmark_dir, str(embodiment), k, "mean.pt")
std_path = os.path.join(benchmark_dir, str(embodiment), k, "std.pt")
min_path = os.path.join(benchmark_dir, str(embodiment), k, "min.pt")
max_path = os.path.join(benchmark_dir, str(embodiment), k, "max.pt")
mean_path = os.path.join(cache_dir, str(embodiment), k, "mean.pt")
std_path = os.path.join(cache_dir, str(embodiment), k, "std.pt")
min_path = os.path.join(cache_dir, str(embodiment), k, "min.pt")
max_path = os.path.join(cache_dir, str(embodiment), k, "max.pt")
median_path = os.path.join(
benchmark_dir, str(embodiment), k, "median.pt"
cache_dir, str(embodiment), k, "median.pt"
)
quantile_1_path = os.path.join(
benchmark_dir, str(embodiment), k, "quantile_1.pt"
cache_dir, str(embodiment), k, "quantile_1.pt"
)
quantile_99_path = os.path.join(
benchmark_dir, str(embodiment), k, "quantile_99.pt"
cache_dir, str(embodiment), k, "quantile_99.pt"
)
torch.save(self.norm_stats[embodiment][k]["mean"], mean_path)
torch.save(self.norm_stats[embodiment][k]["std"], std_path)
Expand All @@ -312,11 +344,11 @@ def infer_norm_from_dataset(
torch.save(
self.norm_stats[embodiment][k]["quantile_99"], quantile_99_path
)
if benchmark_stats["stats"].get(embodiment, None) is None:
benchmark_stats["stats"][embodiment] = {}
if benchmark_stats["stats"][embodiment].get(k, None) is None:
benchmark_stats["stats"][embodiment][k] = {}
benchmark_stats["stats"][embodiment][k] = {
if cache_stats["stats"].get(embodiment, None) is None:
cache_stats["stats"][embodiment] = {}
if cache_stats["stats"][embodiment].get(k, None) is None:
cache_stats["stats"][embodiment][k] = {}
cache_stats["stats"][embodiment][k] = {
"mean": mean_path,
"std": std_path,
"min": min_path,
Expand All @@ -332,16 +364,16 @@ def infer_norm_from_dataset(

computing_end_time = time.time()
computing_time = computing_end_time - computing_start_time
if benchmark_stats is not None:
benchmark_stats["computing_time"] = computing_time
benchmark_stats["frames"] = n_samples
if cache_stats is not None:
cache_stats["computing_time"] = computing_time
cache_stats["frames"] = n_samples

logger.info(
f"[NormStats] Finished norm inference, loading_time={loading_time:.2f}s, computing_time={computing_time:.2f}s"
)
if benchmark_stats is not None:
with open(benchmark_file, "w") as f:
json.dump(benchmark_stats, f, indent=4)
if cache_stats is not None:
with open(cache_file, "w") as f:
json.dump(cache_stats, f, indent=4)

def viz_img_key(self):
"""
Expand Down
24 changes: 16 additions & 8 deletions egomimic/trainHydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:

data_schematic: DataSchematic = hydra.utils.instantiate(cfg.data_schematic)

norm_stats_path = OmegaConf.select(cfg, "norm_stats_path", default=None)
if norm_stats_path:
log.info(f"Loading precomputed norm stats from {norm_stats_path}")
data_schematic.load_norm_from_cache(norm_stats_path)

# Modify dataset configs to include `data_schematic` dynamically at runtime
train_datasets = {}
for dataset_name in cfg.data.train_datasets:
Expand Down Expand Up @@ -111,14 +116,17 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:

instantiate_copy.resolver.key_map = km
norm_dataset = hydra.utils.instantiate(instantiate_copy)
data_schematic.infer_norm_from_dataset(
norm_dataset,
dataset_name,
sample_frac=0.005,
benchmark_dir=os.path.join(
cfg.trainer.default_root_dir, "benchmark_stats.json"
),
)
if not norm_stats_path:
data_schematic.infer_norm_from_dataset(
norm_dataset,
dataset_name,
sample_frac=0.005,
cache_dir=cfg.trainer.default_root_dir,
)
else:
log.info(
f"Skipping norm inference for {dataset_name} (using precomputed stats)"
)

# NOTE: We also pass the data_schematic_dict into the robomimic model's instatiation now that we've initialzied the shapes and norm stats. In theory, upon loading the PL checkpoint, it will remember this, but let's see.
log.info(f"Instantiating model <{cfg.model._target_}>")
Expand Down