3737
3838# Change to the path to raw dataset files.
3939RAW_CRITEO1TB_FILE_PATH = ''
40- CRITEO1TB_DEFAULT_HPARAMS = config_dict .ConfigDict (dict (
41- input_shape = (13 + 26 ,),
42- train_size = 4_195_197_692 ,
43- # We assume the tie breaking example went to the validation set, because
44- # the test set in the mlperf version has 89_137_318 examples.
45- valid_size = 89_137_319 ,
46- test_size = 89_137_318 ,
47- ))
40+ CRITEO1TB_DEFAULT_HPARAMS = config_dict .ConfigDict (
41+ dict (
42+ input_shape = (13 + 26 ,),
43+ train_size = 4_195_197_692 ,
44+ # We assume the tie breaking example went to the validation set, because
45+ # the test set in the mlperf version has 89_137_318 examples.
46+ valid_size = 89_137_319 ,
47+ test_size = 89_137_318 ,
48+ )
49+ )
4850CRITEO1TB_METADATA = {
4951 'apply_one_hot_in_loss' : True ,
5052}
@@ -142,9 +144,9 @@ def criteo_tsv_reader(
142144 ds = ds .repeat ()
143145 ds = ds .interleave (
144146 tf .data .TextLineDataset ,
145- cycle_length = 128 ,
147+ cycle_length = 64 ,
146148 block_length = batch_size // 8 ,
147- num_parallel_calls = 128 ,
149+ num_parallel_calls = 64 ,
148150 deterministic = False )
149151 if is_training :
150152 ds = ds .shuffle (buffer_size = 524_288 * 100 , seed = data_shuffle_seed )
@@ -218,14 +220,25 @@ def get_criteo1tb(shuffle_rng,
218220 num_batches_to_prefetch = (hps .num_tf_data_prefetches
219221 if hps .num_tf_data_prefetches > 0 else tf .data .AUTOTUNE )
220222
223+ num_device_prefetches = hps .get ('num_device_prefetches' , 0 )
224+
221225 train_dataset = criteo_tsv_reader (
222226 split = 'train' ,
223227 shuffle_rng = shuffle_rng ,
224228 file_path = train_file_path ,
225229 num_dense_features = hps .num_dense_features ,
226230 batch_size = per_host_batch_size ,
227231 num_batches_to_prefetch = num_batches_to_prefetch )
228- train_iterator_fn = lambda : tfds .as_numpy (train_dataset )
232+ data_utils .log_rss ('train dataset created' )
233+ if num_device_prefetches > 0 :
234+ train_iterator_fn = lambda : data_utils .prefetch_iterator (
235+ tfds .as_numpy (train_dataset ), num_device_prefetches
236+ )
237+ data_utils .log_rss (
238+ f'using prefetching with { num_device_prefetches } in the train dataset'
239+ )
240+ else :
241+ train_iterator_fn = lambda : tfds .as_numpy (train_dataset )
229242 eval_train_dataset = criteo_tsv_reader (
230243 split = 'eval_train' ,
231244 shuffle_rng = None ,
@@ -238,6 +251,7 @@ def get_criteo1tb(shuffle_rng,
238251 per_host_eval_batch_size = per_host_eval_batch_size ,
239252 tf_dataset = eval_train_dataset ,
240253 split_size = hps .train_size )
254+ data_utils .log_rss ('eval_train dataset created' )
241255 validation_dataset = criteo_tsv_reader (
242256 split = 'validation' ,
243257 shuffle_rng = None ,
@@ -250,6 +264,7 @@ def get_criteo1tb(shuffle_rng,
250264 per_host_eval_batch_size = per_host_eval_batch_size ,
251265 tf_dataset = validation_dataset ,
252266 split_size = hps .valid_size )
267+ data_utils .log_rss ('validation dataset created' )
253268 test_dataset = criteo_tsv_reader (
254269 split = 'test' ,
255270 shuffle_rng = None ,
@@ -262,6 +277,18 @@ def get_criteo1tb(shuffle_rng,
262277 per_host_eval_batch_size = per_host_eval_batch_size ,
263278 tf_dataset = test_dataset ,
264279 split_size = hps .test_size )
280+ data_utils .log_rss ('test dataset created' )
281+
282+ # Cache all the eval_train/validation/test iterators to avoid re-processing the
283+ # same data files.
284+ eval_train_iterator_fn = data_utils .CachedEvalIterator (
285+ eval_train_iterator_fn , 'eval_train'
286+ )
287+ validation_iterator_fn = data_utils .CachedEvalIterator (
288+ validation_iterator_fn , 'validation'
289+ )
290+ test_iterator_fn = data_utils .CachedEvalIterator (test_iterator_fn , 'test' )
291+
265292 return data_utils .Dataset (
266293 train_iterator_fn ,
267294 eval_train_iterator_fn ,
0 commit comments