From bf0996fb2e26a85d0f1d292dfe1f44f34becd2c3 Mon Sep 17 00:00:00 2001
From: Aidan Gao
Date: Mon, 2 Mar 2026 16:47:40 -0500
Subject: [PATCH] multi dataset weighting

---
 egomimic/rldb/zarr/zarr_dataset_multi.py | 60 ++++++++++++++++++++++--
 1 file changed, 57 insertions(+), 3 deletions(-)

diff --git a/egomimic/rldb/zarr/zarr_dataset_multi.py b/egomimic/rldb/zarr/zarr_dataset_multi.py
index 409204d2..89e547f2 100644
--- a/egomimic/rldb/zarr/zarr_dataset_multi.py
+++ b/egomimic/rldb/zarr/zarr_dataset_multi.py
@@ -467,6 +467,7 @@ def __init__(
         mode="train",
         percent=0.1,
         valid_ratio=0.2,
+        weight: float = 1.0,
         **kwargs,
     ):
         """
@@ -475,8 +476,11 @@ def __init__(
             mode (str, optional): Split mode to use (e.g., "train", "valid"). Defaults to "train".
             percent (float, optional): Fraction of the dataset to use from each underlying dataset. Defaults to 0.1.
             valid_ratio (float, optional): Validation split ratio for datasets that support a train/valid split.
+            weight (float, optional): This dataset's sampling weight relative to its siblings when nested inside another MultiDataset. Defaults to 1.0 (even weighting).
             **kwargs: Additional keyword arguments passed to underlying dataset constructors if needed.
         """
+        self.weight = weight
+
         self.train_collections, self.valid_collections = split_dataset_names(
             datasets.keys(), valid_ratio=valid_ratio, seed=SEED
         )
@@ -509,6 +513,56 @@ def __init__(
 
         super().__init__()
 
+    def _get_sample_weights(self) -> list[float]:
+        """
+        Compute per-sample weights aligned with self.index_map, resolving
+        weights recursively through nested MultiDatasets.
+
+        Returns:
+            List of float weights, one per entry in self.index_map, in the
+            same order as the index_map.
+        """
+        sample_weights: list[float] = []
+
+        for dataset in self.datasets.values():
+            if isinstance(dataset, MultiDataset):
+                inner = np.array(dataset._get_sample_weights(), dtype=float)
+                child_w = dataset.weight
+            else:
+                inner = np.ones(len(dataset), dtype=float)
+                child_w = 1.0
+
+            inner_sum = inner.sum()
+            if inner_sum > 0:
+                inner = inner / inner_sum  # normalize to sum-to-1 within child
+            inner = inner * child_w
+
+            sample_weights.extend(inner.tolist())
+
+        return sample_weights
+
+    def _generate_train_sampler(
+        self,
+        replacement: bool = True,
+    ) -> torch.utils.data.WeightedRandomSampler:
+        """
+        Build a WeightedRandomSampler whose per-sample probabilities reflect
+        the dataset weights assigned at construction time, resolved recursively
+        through any nested MultiDatasets.
+
+        Args:
+            replacement: Whether to sample with replacement. Defaults to True.
+
+        Returns:
+            A torch.utils.data.WeightedRandomSampler instance.
+        """
+        weights = self._get_sample_weights()
+        return torch.utils.data.WeightedRandomSampler(
+            weights=weights,
+            num_samples=len(self),
+            replacement=replacement,
+        )
+
     def __len__(self) -> int:
         return len(self.index_map)
 
@@ -523,13 +577,13 @@ def __getitem__(self, idx):
         return data
 
     @classmethod
-    def _from_resolver(cls, resolver: EpisodeResolver, **kwargs):
+    def _from_resolver(cls, resolver: EpisodeResolver, weight: float = 1.0, **kwargs):
         """
         create a MultiDataset from an EpisodeResolver.
 
         Args:
             resolver (EpisodeResolver): The resolver instance to use for loading datasets.
-            embodiment: The embodiment identifier to use for resolving datasets.
+            weight (float, optional): Sampling weight for this dataset relative to its siblings when nested inside another MultiDataset. Defaults to 1.0.
             **kwargs: Keyword args forwarded to resolver (e.g., filters, sync_from_s3)
                 and MultiDataset constructor (e.g., mode, percent, key_map, valid_ratio).

@@ -549,7 +603,7 @@ def _from_resolver(cls, resolver: EpisodeResolver, **kwargs):
         else:
             resolved = resolver.resolve(filters=filters)
 
-        return cls(datasets=resolved, **kwargs)
+        return cls(datasets=resolved, weight=weight, **kwargs)
 
 
 class ZarrDataset(torch.utils.data.Dataset):
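
Usage sketch (note, not part of the patch): a minimal example of how the new
weight argument and _generate_train_sampler are meant to compose when
MultiDatasets are nested. Only MultiDataset, weight, and
_generate_train_sampler come from this change; the collection dicts
(robot_collections, human_collections) and the batch size are hypothetical,
and it is assumed each dict maps collection names to already-constructed
dataset objects, as the constructor's datasets argument expects.

    import torch
    from egomimic.rldb.zarr.zarr_dataset_multi import MultiDataset

    # Hypothetical children; weight is declared on the child and read by the
    # parent MultiDataset when it resolves sample weights recursively.
    robot = MultiDataset(datasets=robot_collections, weight=3.0, mode="train")
    human = MultiDataset(datasets=human_collections, weight=1.0, mode="train")
    combined = MultiDataset(datasets={"robot": robot, "human": human}, mode="train")

    # _get_sample_weights normalizes each child to unit mass before scaling by
    # its weight, so batches draw from robot vs. human roughly 3:1 regardless
    # of how many samples each child actually contains.
    sampler = combined._generate_train_sampler(replacement=True)
    loader = torch.utils.data.DataLoader(combined, batch_size=32, sampler=sampler)

Declaring the weight on the child rather than the parent is what lets the
scheme recurse: a nested MultiDataset carries its own weight wherever it is
embedded, and the parent only has to normalize and scale.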