diff --git a/src/cell_load/data_modules/samplers.py b/src/cell_load/data_modules/samplers.py
index 9ebb767..4beb400 100644
--- a/src/cell_load/data_modules/samplers.py
+++ b/src/cell_load/data_modules/samplers.py
@@ -6,6 +6,7 @@ import numpy as np
 
 from torch.utils.data import Sampler, Subset
 import torch.distributed as dist
+from torch.utils.data.distributed import DistributedSampler
 
 from ..dataset import MetadataConcatDataset, PerturbationDataset
 from ..utils.data_utils import H5MetadataCache
@@ -75,9 +76,14 @@ def __init__(
 
         # Create caches for all unique H5 files.
         self.metadata_caches = {}
-        for subset in self.dataset.datasets:
-            base_dataset: PerturbationDataset = subset.dataset
-            self.metadata_caches[base_dataset.h5_path] = base_dataset.metadata_cache
+        if isinstance(self.dataset, DistributedSampler):
+            for subset in self.dataset.dataset.datasets:
+                base_dataset: PerturbationDataset = subset.dataset
+                self.metadata_caches[base_dataset.h5_path] = base_dataset.metadata_cache
+        else:
+            for subset in self.dataset.datasets:
+                base_dataset: PerturbationDataset = subset.dataset
+                self.metadata_caches[base_dataset.h5_path] = base_dataset.metadata_cache
 
         # Create batches using the code-based grouping.
         self.sentences = self._create_sentences()
@@ -253,10 +259,16 @@ def _create_sentences(self) -> list[list[int]]:
         """
         global_offset = 0
         all_batches = []
-        for subset in self.dataset.datasets:
-            subset_batches = self._process_subset(global_offset, subset)
-            all_batches.extend(subset_batches)
-            global_offset += len(subset)
+        if isinstance(self.dataset, DistributedSampler):
+            for subset in self.dataset.dataset.datasets:
+                subset_batches = self._process_subset(global_offset, subset)
+                all_batches.extend(subset_batches)
+                global_offset += len(subset)
+        else:
+            for subset in self.dataset.datasets:
+                subset_batches = self._process_subset(global_offset, subset)
+                all_batches.extend(subset_batches)
+                global_offset += len(subset)
         np.random.shuffle(all_batches)
         return all_batches
 
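
A note on the change itself: both hunks repeat the same loop body and differ only in how the underlying concat dataset is reached, because DistributedSampler keeps the wrapped dataset on its .dataset attribute. Below is a minimal sketch of a helper that would collapse that duplication; the name unwrap_datasets is hypothetical and not part of this patch.

    from torch.utils.data.distributed import DistributedSampler

    def unwrap_datasets(dataset_or_sampler):
        """Return the per-file Subset list, whether the sampler was handed
        the MetadataConcatDataset directly or a DistributedSampler that
        wraps it (a hypothetical helper, not part of this patch)."""
        if isinstance(dataset_or_sampler, DistributedSampler):
            # DistributedSampler stores the wrapped dataset on .dataset,
            # so one extra hop is needed to reach .datasets.
            return dataset_or_sampler.dataset.datasets
        return dataset_or_sampler.datasets

With something like this, both call sites reduce to a single loop, for subset in unwrap_datasets(self.dataset): ..., and any future wrapper type would only need handling in one place.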