Skip to content

Commit a222b5b

Browse files
committed
fix
1 parent 505a75c commit a222b5b

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

src/twinkle/dataloader/dataloader.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
22
import copy
3+
import os
34
import warnings
45
from functools import partial
56
from typing import Callable, Optional, Type, Union
@@ -56,6 +57,7 @@ def __init__(self,
5657
self._skip_samples = 0
5758
self._base_batch_sampler = None
5859
self._base_sampler = None
60+
self._retry_sampler_seed = self._resolve_retry_sampler_seed()
5961
self._set_work_init_fn()
6062

6163
def _set_work_init_fn(self):
@@ -65,6 +67,17 @@ def _set_work_init_fn(self):
6567
num_workers=num_workers,
6668
rank=self.device_mesh.data_rank if self.device_mesh else 0)
6769

70+
@staticmethod
71+
def _resolve_retry_sampler_seed() -> int:
72+
env_seed = os.environ.get('TWINKLE_SEED')
73+
if env_seed is not None:
74+
return int(env_seed)
75+
try:
76+
from twinkle.infra import _seed
77+
return int(_seed)
78+
except Exception:
79+
return 42
80+
6881
@remote_function()
6982
def __len__(self):
7083
self._lazy_init_dataloader()
@@ -145,6 +158,7 @@ def _rebuild_sampler_stack(self):
145158
self._base_sampler,
146159
self.dataset,
147160
max_retries=self.max_retries,
161+
seed=self._retry_sampler_seed,
148162
)
149163
self.dataloader.batch_sampler = DeviceMeshSampler(
150164
batch_sampler,
@@ -158,4 +172,5 @@ def _rebuild_sampler_stack(self):
158172
self.dataset,
159173
max_retries=self.max_retries,
160174
skip_samples=self._skip_samples,
175+
seed=self._retry_sampler_seed,
161176
)

src/twinkle/dataloader/retry_sampler.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,17 @@ class RetrySampler(Sampler):
1414
max_retries: The maximum number of retries.
1515
"""
1616

17-
def __init__(self, original_sampler: Sampler, dataset: Dataset, max_retries=20, skip_samples: int = 0):
17+
def __init__(self,
18+
original_sampler: Sampler,
19+
dataset: Dataset,
20+
max_retries=20,
21+
skip_samples: int = 0,
22+
seed: int = 42):
1823
self.original_sampler = original_sampler
1924
self.dataset = dataset
2025
self.max_retries = max_retries
2126
self.skip_samples = skip_samples
27+
self.seed = int(seed)
2228

2329
def __iter__(self):
2430
emitted = 0
@@ -48,9 +54,9 @@ def __iter__(self):
4854
if emitted >= target_total:
4955
return
5056

51-
for idx in np.random.RandomState().permutation(len(self.dataset)).tolist():
57+
for idx in np.random.RandomState(self.seed).permutation(len(self.dataset)).tolist():
5258
if emitted >= target_total:
53-
raise StopIteration
59+
return
5460
for _ in range(self.max_retries):
5561
try:
5662
# Skip None values and raises

0 commit comments

Comments
 (0)