60 changes: 57 additions & 3 deletions egomimic/rldb/zarr/zarr_dataset_multi.py
@@ -467,6 +467,7 @@ def __init__(
        mode="train",
        percent=0.1,
        valid_ratio=0.2,
        weight: float = 1.0,
        **kwargs,
    ):
        """
@@ -475,8 +476,11 @@
            mode (str, optional): Split mode to use (e.g., "train", "valid"). Defaults to "train".
            percent (float, optional): Fraction of the dataset to use from each underlying dataset. Defaults to 0.1.
            valid_ratio (float, optional): Validation split ratio for datasets that support a train/valid split.
            weight (float, optional): This dataset's sampling weight relative to its siblings when nested inside another MultiDataset. Defaults to 1.0 (even weighting).
            **kwargs: Additional keyword arguments passed to underlying dataset constructors if needed.
        """
        self.weight = weight

        self.train_collections, self.valid_collections = split_dataset_names(
            datasets.keys(), valid_ratio=valid_ratio, seed=SEED
        )
@@ -509,6 +513,56 @@ def __init__(

        super().__init__()

    def _get_sample_weights(self) -> list[float]:
        """
        Compute per-sample weights aligned with self.index_map, resolving
        weights recursively through nested MultiDatasets.

        Returns:
            List of float weights, one per entry in self.index_map, in the
            same order as the index_map.
        """
        sample_weights: list[float] = []

        for dataset in self.datasets.values():
            if isinstance(dataset, MultiDataset):
                inner = np.array(dataset._get_sample_weights(), dtype=float)
                child_w = dataset.weight
            else:
                inner = np.ones(len(dataset), dtype=float)
                child_w = 1.0

            inner_sum = inner.sum()
            if inner_sum > 0:
                inner = inner / inner_sum  # normalize to sum-to-1 within child
            inner = inner * child_w

            sample_weights.extend(inner.tolist())

        return sample_weights
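
For intuition, a standalone sketch of the normalization above, using hypothetical toy numbers that are not taken from this PR: one plain child with two samples and implicit weight 1.0 beside a nested MultiDataset with two samples and weight=3.0.

import numpy as np

# Hypothetical toy numbers, illustrating what _get_sample_weights computes.
inner_a = np.ones(2) / 2 * 1.0  # plain child: [0.5, 0.5]
inner_b = np.ones(2) / 2 * 3.0  # nested child, weight=3.0: [1.5, 1.5]
sample_weights = np.concatenate([inner_a, inner_b])

# WeightedRandomSampler normalizes internally, so each sample of the
# weighted child is drawn three times as often as each sample of the
# plain child: [0.125, 0.125, 0.375, 0.375]
print(sample_weights / sample_weights.sum())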

    def _generate_train_sampler(
        self,
        replacement: bool = True,
    ) -> torch.utils.data.WeightedRandomSampler:
        """
        Build a WeightedRandomSampler whose per-sample probabilities reflect
        the dataset weights assigned at construction time, resolved recursively
        through any nested MultiDatasets.

        Args:
            replacement: Whether to sample with replacement. Defaults to True.

        Returns:
            A torch.utils.data.WeightedRandomSampler instance.
        """
        weights = self._get_sample_weights()
        return torch.utils.data.WeightedRandomSampler(
            weights=weights,
            num_samples=len(self),
            replacement=replacement,
        )
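
A minimal usage sketch, assuming an already-constructed MultiDataset instance named dataset; the sampler/DataLoader wiring is standard PyTorch:

import torch

# `dataset` is a hypothetical, already-constructed MultiDataset.
sampler = dataset._generate_train_sampler(replacement=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)

Note that with replacement=False, WeightedRandomSampler can only draw num_samples up to the number of samples with nonzero weight, so the default of True is the safer choice here.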

    def __len__(self) -> int:
        return len(self.index_map)

@@ -523,13 +577,13 @@ def __getitem__(self, idx):
        return data

    @classmethod
-    def _from_resolver(cls, resolver: EpisodeResolver, **kwargs):
+    def _from_resolver(cls, resolver: EpisodeResolver, weight: float = 1.0, **kwargs):
"""
create a MultiDataset from an EpisodeResolver.

Args:
resolver (EpisodeResolver): The resolver instance to use for loading datasets.
-            embodiment: The embodiment identifier to use for resolving datasets.
+            weight (float, optional): Sampling weight for this dataset relative to its siblings when nested inside another MultiDataset. Defaults to 1.0.
            **kwargs: Keyword args forwarded to resolver (e.g., filters,
                sync_from_s3) and MultiDataset constructor (e.g., mode, percent,
                key_map, valid_ratio).
@@ -549,7 +603,7 @@ def _from_resolver(cls, resolver: EpisodeResolver, **kwargs):
        else:
            resolved = resolver.resolve(filters=filters)

-        return cls(datasets=resolved, **kwargs)
+        return cls(datasets=resolved, weight=weight, **kwargs)
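
For example, a sketch of nesting weighted datasets; the resolver names, filters, and constructor keyword arguments below are hypothetical, not from this PR:

# Hypothetical nesting; identifiers are illustrative only.
robot = MultiDataset._from_resolver(robot_resolver, weight=2.0, mode="train")
human = MultiDataset._from_resolver(human_resolver, weight=1.0, mode="train")

combined = MultiDataset(datasets={"robot": robot, "human": human})
sampler = combined._generate_train_sampler()

Because each child's per-sample weights are normalized before being scaled by its weight, robot samples as a group are drawn twice as often as human samples, regardless of how many episodes each side contains.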


class ZarrDataset(torch.utils.data.Dataset):