From 2cd723692cb1bd4e9a92af830879730a0897a263 Mon Sep 17 00:00:00 2001
From: 柳啸峰 <34809315+InvictusL@users.noreply.github.com>
Date: Fri, 19 Sep 2025 11:38:13 +0800
Subject: [PATCH] Fix multi-GPU distributed training
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Recognize and unwrap DistributedSampler so that distributed data loading
works and the model can train normally on multiple GPUs (e.g. state).
---
 src/cell_load/data_modules/samplers.py | 26 +++++++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/src/cell_load/data_modules/samplers.py b/src/cell_load/data_modules/samplers.py
index 9ebb767..4beb400 100644
--- a/src/cell_load/data_modules/samplers.py
+++ b/src/cell_load/data_modules/samplers.py
@@ -6,6 +6,7 @@ import numpy as np
 
 from torch.utils.data import Sampler, Subset
 import torch.distributed as dist
+from torch.utils.data.distributed import DistributedSampler
 
 from ..dataset import MetadataConcatDataset, PerturbationDataset
 from ..utils.data_utils import H5MetadataCache
@@ -75,9 +76,14 @@ def __init__(
 
         # Create caches for all unique H5 files.
         self.metadata_caches = {}
-        for subset in self.dataset.datasets:
-            base_dataset: PerturbationDataset = subset.dataset
-            self.metadata_caches[base_dataset.h5_path] = base_dataset.metadata_cache
+        if isinstance(self.dataset, DistributedSampler):
+            for subset in self.dataset.dataset.datasets:
+                base_dataset: PerturbationDataset = subset.dataset
+                self.metadata_caches[base_dataset.h5_path] = base_dataset.metadata_cache
+        else:
+            for subset in self.dataset.datasets:
+                base_dataset: PerturbationDataset = subset.dataset
+                self.metadata_caches[base_dataset.h5_path] = base_dataset.metadata_cache
 
         # Create batches using the code-based grouping.
         self.sentences = self._create_sentences()
@@ -253,10 +259,16 @@ def _create_sentences(self) -> list[list[int]]:
         """
         global_offset = 0
         all_batches = []
-        for subset in self.dataset.datasets:
-            subset_batches = self._process_subset(global_offset, subset)
-            all_batches.extend(subset_batches)
-            global_offset += len(subset)
+        if isinstance(self.dataset, DistributedSampler):
+            for subset in self.dataset.dataset.datasets:
+                subset_batches = self._process_subset(global_offset, subset)
+                all_batches.extend(subset_batches)
+                global_offset += len(subset)
+        else:
+            for subset in self.dataset.datasets:
+                subset_batches = self._process_subset(global_offset, subset)
+                all_batches.extend(subset_batches)
+                global_offset += len(subset)
 
         np.random.shuffle(all_batches)
         return all_batches
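
Note on the change (not part of the patch above): the two if/else branches in __init__ and _create_sentences repeat the same loop body and differ only in how the underlying concat dataset is reached. A minimal sketch of one way to fold the unwrapping into a single helper follows; the helper name resolve_datasets is hypothetical, and the only assumptions are DistributedSampler's public .dataset attribute plus the .datasets / subset.dataset attributes already used in the patch.

    from torch.utils.data.distributed import DistributedSampler

    def resolve_datasets(dataset_or_sampler):
        """Return the per-file Subset objects behind either a bare
        MetadataConcatDataset or a DistributedSampler wrapping one."""
        if isinstance(dataset_or_sampler, DistributedSampler):
            # DistributedSampler stores the wrapped dataset on .dataset
            return dataset_or_sampler.dataset.datasets
        return dataset_or_sampler.datasets

    # __init__ could then build the caches with a single loop:
    #     for subset in resolve_datasets(self.dataset):
    #         base_dataset: PerturbationDataset = subset.dataset
    #         self.metadata_caches[base_dataset.h5_path] = base_dataset.metadata_cache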