# Change to the path to raw dataset files.
RAW_CRITEO1TB_FILE_PATH = ''

# Default hyperparameters for the Criteo 1TB click-logs dataset.
# Each example has 13 dense (integer) features + 26 sparse (categorical)
# features, hence the flat input shape of 13 + 26.
CRITEO1TB_DEFAULT_HPARAMS = config_dict.ConfigDict({
    'input_shape': (13 + 26,),
    'train_size': 4_195_197_692,
    # We assume the tie breaking example went to the validation set, because
    # the test set in the mlperf version has 89_137_318 examples.
    'valid_size': 89_137_319,
    'test_size': 89_137_318,
})

CRITEO1TB_METADATA = {
    'apply_one_hot_in_loss': True,
}
@@ -142,9 +144,9 @@ def criteo_tsv_reader(
142144 ds = ds .repeat ()
143145 ds = ds .interleave (
144146 tf .data .TextLineDataset ,
145- cycle_length = 128 ,
147+ cycle_length = 64 ,
146148 block_length = batch_size // 8 ,
147- num_parallel_calls = 128 ,
149+ num_parallel_calls = 64 ,
148150 deterministic = False )
149151 if is_training :
150152 ds = ds .shuffle (buffer_size = 524_288 * 100 , seed = data_shuffle_seed )
@@ -179,15 +181,12 @@ def _convert_to_numpy_iterator_fn(
179181 num_batches = num_batches_in_split
180182
181183 iterator = iter (tfds .as_numpy (tf_dataset ))
182- zeros_batch = None
184+ batch = None
183185 for _ in range (num_batches ):
184186 try :
185187 batch = next (iterator )
186188 except StopIteration :
187- if zeros_batch is None :
188- zeros_batch = jax .tree .map (
189- lambda x : np .zeros_like (x , dtype = x .dtype ), batch )
190- yield zeros_batch
189+ yield batch
191190 continue
192191 batch = data_utils .maybe_pad_batch (
193192 batch , desired_batch_size = per_host_eval_batch_size )
@@ -218,14 +217,25 @@ def get_criteo1tb(shuffle_rng,
218217 num_batches_to_prefetch = (hps .num_tf_data_prefetches
219218 if hps .num_tf_data_prefetches > 0 else tf .data .AUTOTUNE )
220219
220+ num_device_prefetches = hps .get ('num_device_prefetches' , 0 )
221+
221222 train_dataset = criteo_tsv_reader (
222223 split = 'train' ,
223224 shuffle_rng = shuffle_rng ,
224225 file_path = train_file_path ,
225226 num_dense_features = hps .num_dense_features ,
226227 batch_size = per_host_batch_size ,
227228 num_batches_to_prefetch = num_batches_to_prefetch )
228- train_iterator_fn = lambda : tfds .as_numpy (train_dataset )
229+ data_utils .log_rss ('train dataset created' )
230+ if num_device_prefetches > 0 :
231+ train_iterator_fn = lambda : data_utils .prefetch_iterator (
232+ tfds .as_numpy (train_dataset ), num_device_prefetches
233+ )
234+ data_utils .log_rss (
235+ f'using prefetching with { num_device_prefetches } in the train dataset'
236+ )
237+ else :
238+ train_iterator_fn = lambda : tfds .as_numpy (train_dataset )
229239 eval_train_dataset = criteo_tsv_reader (
230240 split = 'eval_train' ,
231241 shuffle_rng = None ,
@@ -238,6 +248,7 @@ def get_criteo1tb(shuffle_rng,
238248 per_host_eval_batch_size = per_host_eval_batch_size ,
239249 tf_dataset = eval_train_dataset ,
240250 split_size = hps .train_size )
251+ data_utils .log_rss ('eval_train dataset created' )
241252 validation_dataset = criteo_tsv_reader (
242253 split = 'validation' ,
243254 shuffle_rng = None ,
@@ -250,6 +261,7 @@ def get_criteo1tb(shuffle_rng,
250261 per_host_eval_batch_size = per_host_eval_batch_size ,
251262 tf_dataset = validation_dataset ,
252263 split_size = hps .valid_size )
264+ data_utils .log_rss ('validation dataset created' )
253265 test_dataset = criteo_tsv_reader (
254266 split = 'test' ,
255267 shuffle_rng = None ,
@@ -262,6 +274,18 @@ def get_criteo1tb(shuffle_rng,
262274 per_host_eval_batch_size = per_host_eval_batch_size ,
263275 tf_dataset = test_dataset ,
264276 split_size = hps .test_size )
277+ data_utils .log_rss ('test dataset created' )
278+
279+ # Cache all the eval_train/validation/test iterators to avoid re-processing the
280+ # same data files.
281+ eval_train_iterator_fn = data_utils .CachedEvalIterator (
282+ eval_train_iterator_fn , 'eval_train'
283+ )
284+ validation_iterator_fn = data_utils .CachedEvalIterator (
285+ validation_iterator_fn , 'validation'
286+ )
287+ test_iterator_fn = data_utils .CachedEvalIterator (test_iterator_fn , 'test' )
288+
265289 return data_utils .Dataset (
266290 train_iterator_fn ,
267291 eval_train_iterator_fn ,
0 commit comments