11# Copyright (c) ModelScope Contributors. All rights reserved.
22import copy
3+ import os
34import warnings
45from functools import partial
56from typing import Callable , Optional , Type , Union
@@ -56,6 +57,7 @@ def __init__(self,
5657 self ._skip_samples = 0
5758 self ._base_batch_sampler = None
5859 self ._base_sampler = None
60+ self ._retry_sampler_seed = self ._resolve_retry_sampler_seed ()
5961 self ._set_work_init_fn ()
6062
6163 def _set_work_init_fn (self ):
@@ -65,6 +67,17 @@ def _set_work_init_fn(self):
6567 num_workers = num_workers ,
6668 rank = self .device_mesh .data_rank if self .device_mesh else 0 )
6769
@staticmethod
def _resolve_retry_sampler_seed() -> int:
    """Pick the seed that is handed to the retry samplers.

    Sources are tried in this order:

    1. The ``TWINKLE_SEED`` environment variable, parsed as an integer.
       A blank value is treated the same as an unset variable.
    2. ``twinkle.infra._seed``, converted with ``int``.
    3. The constant ``42`` when neither source yields a value.

    Returns:
        The resolved seed.

    Raises:
        ValueError: If ``TWINKLE_SEED`` is set to a non-blank value that
            is not a valid integer literal.
    """
    env_seed = os.environ.get('TWINKLE_SEED')
    # `TWINKLE_SEED=` (empty or whitespace only) carries no information, so
    # fall through to the next source instead of failing inside int('').
    if env_seed is not None and env_seed.strip():
        try:
            return int(env_seed)
        except ValueError as err:
            # Same exception type as a bare int() failure, but the message
            # names the variable so the misconfiguration is easy to locate.
            raise ValueError(
                f'TWINKLE_SEED must be an integer, got {env_seed!r}') from err
    try:
        from twinkle.infra import _seed
        return int(_seed)
    except Exception:
        # Deliberately broad: this lookup is best-effort, and any failure to
        # import or convert the value must not prevent falling back to 42.
        return 42
80+
6881 @remote_function ()
6982 def __len__ (self ):
7083 self ._lazy_init_dataloader ()
@@ -145,6 +158,7 @@ def _rebuild_sampler_stack(self):
145158 self ._base_sampler ,
146159 self .dataset ,
147160 max_retries = self .max_retries ,
161+ seed = self ._retry_sampler_seed ,
148162 )
149163 self .dataloader .batch_sampler = DeviceMeshSampler (
150164 batch_sampler ,
@@ -158,4 +172,5 @@ def _rebuild_sampler_stack(self):
158172 self .dataset ,
159173 max_retries = self .max_retries ,
160174 skip_samples = self ._skip_samples ,
175+ seed = self ._retry_sampler_seed ,
161176 )
0 commit comments