Skip to content

Commit 4521339

Browse files
Ahmed Khaled and copybara-github
authored and committed
internal
PiperOrigin-RevId: 876590170
1 parent 05bec4b commit 4521339

File tree

1 file changed

+224
-5
lines changed

1 file changed

+224
-5
lines changed

init2winit/dataset_lib/criteo_terabyte_dataset.py

Lines changed: 224 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
import tensorflow as tf
3535
import tensorflow_datasets as tfds
3636

37-
3837
# Change to the path to raw dataset files.
3938
RAW_CRITEO1TB_FILE_PATH = ''
39+
PREPROCESSED_CRITEO1TB_FILE_PATH = '' # pylint: disable=invalid-name
4040
CRITEO1TB_DEFAULT_HPARAMS = config_dict.ConfigDict(
4141
dict(
4242
input_shape=(13 + 26,),
@@ -157,6 +157,94 @@ def criteo_tsv_reader(
157157
return ds
158158

159159

160+
# Schema of the serialized tf.train.Examples stored in the ArrayRecord files:
# 13 dense integer features + 26 categorical features, and a scalar label.
_ARRAYRECORD_FEATURE_SPEC = {
    'inputs': tf.io.FixedLenFeature([13 + 26], tf.float32),
    'targets': tf.io.FixedLenFeature([1], tf.float32),
}


@tf.function
def _parse_arrayrecord_example_fn(serialized_examples):
  """Deserialize a batch of tf.train.Examples read from ArrayRecord files.

  Args:
    serialized_examples: a batch of serialized tf.train.Example protos.

  Returns:
    A dict with 'inputs' (shape [batch, 39]) and 'targets' (shape [batch],
    the trailing length-1 feature dimension squeezed away).
  """
  features = tf.io.parse_example(
      serialized_examples, _ARRAYRECORD_FEATURE_SPEC
  )
  labels = tf.squeeze(features['targets'], axis=-1)
  return {'inputs': features['inputs'], 'targets': labels}
174+
175+
176+
def criteo_arrayrecord_reader(
    split, shuffle_rng, file_path, batch_size, num_batches_to_prefetch
):
  """Input reader for preprocessed Criteo ArrayRecord data.

  Args:
    split: one of {'train', 'eval_train', 'validation', 'test'}.
    shuffle_rng: jax.random.PRNGKey for shuffling. Only used when
      split == 'train'; may be None for the eval splits.
    file_path: glob pattern for .array_record files.
    batch_size: per-host batch size; must be positive.
    num_batches_to_prefetch: number of batches to prefetch.

  Returns:
    A tf.data.Dataset object yielding dicts of 'inputs'/'targets' batches.

  Raises:
    ValueError: if `split` is not a known split name, or if no files match
      `file_path`.
  """
  if split not in ['train', 'eval_train', 'validation', 'test']:
    raise ValueError(f'Invalid split name {split}.')
  data_shuffle_seed = None

  is_training = split == 'train'
  if is_training:
    # Derive a TF-compatible integer seed from the JAX PRNGKey so the file
    # shuffle below is reproducible per training run.
    _, data_shuffle_seed = jax.random.split(shuffle_rng, 2)
    data_shuffle_seed = data_utils.convert_jax_to_tf_random_seed(
        data_shuffle_seed
    )

  # Discover all matching files.
  all_files = sorted(tf.io.gfile.glob(file_path))
  if not all_files:
    raise ValueError(f'No ArrayRecord files found matching: {file_path}')

  # Shard files across hosts: host i takes every num_hosts-th file.
  index = jax.process_index()
  num_hosts = jax.process_count()
  host_files = all_files[index::num_hosts]

  # Interleave per-file datasets, with batch+parse inside each file's
  # sub-pipeline. This is critical for performance: interleaving dense float
  # tensors (post-parse) is much faster than interleaving raw byte strings
  # and batching them later.
  file_ds = tf.data.Dataset.from_tensor_slices(host_files)
  if is_training:
    file_ds = file_ds.repeat()
    file_ds = file_ds.shuffle(
        buffer_size=2 * len(host_files), seed=data_shuffle_seed
    )

  ds = file_ds.interleave(
      lambda f: (
          ar_dataset.ArrayRecordDataset([f])
          .batch(
              batch_size,
              # Only training requires static batch shapes; eval keeps the
              # final partial batch.
              drop_remainder=is_training,
              num_parallel_calls=tf.data.AUTOTUNE,
              deterministic=False,
          )
          .map(
              _parse_arrayrecord_example_fn,
              num_parallel_calls=tf.data.AUTOTUNE,
              deterministic=False,
          )
      ),
      cycle_length=64,
      # Guard against block_length == 0 (invalid for tf.data) when
      # batch_size < 8.
      block_length=max(1, batch_size // 8),
      num_parallel_calls=64,
      deterministic=False,
  )
  ds = ds.prefetch(num_batches_to_prefetch)
  return ds
246+
247+
160248
def _eval_numpy_iterator(
161249
num_batches, per_host_eval_batch_size, tf_dataset, split_size
162250
):
@@ -222,14 +310,145 @@ def get_criteo1tb(shuffle_rng,
222310
per_host_eval_batch_size = eval_batch_size // process_count
223311
per_host_batch_size = batch_size // process_count
224312

313+
use_raw_tsv = hps.get('use_raw_tsv', False)
314+
num_batches_to_prefetch = (
315+
hps.num_tf_data_prefetches
316+
if hps.num_tf_data_prefetches > 0
317+
else tf.data.AUTOTUNE
318+
)
319+
num_device_prefetches = hps.get('num_device_prefetches', 0)
320+
321+
if use_raw_tsv:
322+
return _get_criteo1tb_tsv(
323+
shuffle_rng,
324+
per_host_batch_size,
325+
per_host_eval_batch_size,
326+
hps,
327+
num_batches_to_prefetch,
328+
num_device_prefetches,
329+
)
330+
else:
331+
return _get_criteo1tb_arrayrecord(
332+
shuffle_rng,
333+
per_host_batch_size,
334+
per_host_eval_batch_size,
335+
hps,
336+
num_batches_to_prefetch,
337+
num_device_prefetches,
338+
)
339+
340+
341+
def _get_criteo1tb_arrayrecord(
    shuffle_rng,
    per_host_batch_size,
    per_host_eval_batch_size,
    hps,
    num_batches_to_prefetch,
    num_device_prefetches,
):
  """Load Criteo 1TB from preprocessed ArrayRecord files.

  Args:
    shuffle_rng: jax.random.PRNGKey used to shuffle the training data.
    per_host_batch_size: per-host training batch size.
    per_host_eval_batch_size: per-host batch size shared by all eval splits.
    hps: hyperparameter ConfigDict; reads `preprocessed_data_path`,
      `train_size`, `valid_size`, and `test_size`.
    num_batches_to_prefetch: tf.data prefetch depth for each pipeline.
    num_device_prefetches: if > 0, number of batches the train iterator
      prefetches via data_utils.prefetch_iterator.

  Returns:
    A data_utils.Dataset holding the train, eval_train, validation, and test
    iterator factories.
  """
  base = hps.get('preprocessed_data_path', PREPROCESSED_CRITEO1TB_FILE_PATH)
  train_file_path = os.path.join(base, 'train', '*')
  validation_file_path = os.path.join(
      base, 'val_set_second_half_of_day23_not_used', '*'
  )
  test_file_path = os.path.join(base, 'eval', '*')

  train_dataset = criteo_arrayrecord_reader(
      split='train',
      shuffle_rng=shuffle_rng,
      file_path=train_file_path,
      batch_size=per_host_batch_size,
      num_batches_to_prefetch=num_batches_to_prefetch,
  )
  data_utils.log_rss('train arrayrecord dataset created')

  if num_device_prefetches > 0:
    train_iterator_fn = lambda: data_utils.prefetch_iterator(
        tfds.as_numpy(train_dataset), num_device_prefetches
    )
    data_utils.log_rss(
        f'using prefetching with {num_device_prefetches} in the train dataset'
    )
  else:
    train_iterator_fn = lambda: tfds.as_numpy(train_dataset)

  def _cached_eval_iterator_fn(split, file_path, split_size):
    """Build the cached numpy-iterator factory for one (unshuffled) split."""
    eval_dataset = criteo_arrayrecord_reader(
        split=split,
        shuffle_rng=None,
        file_path=file_path,
        batch_size=per_host_eval_batch_size,
        num_batches_to_prefetch=num_batches_to_prefetch,
    )
    iterator_fn = functools.partial(
        _eval_numpy_iterator,
        per_host_eval_batch_size=per_host_eval_batch_size,
        tf_dataset=eval_dataset,
        split_size=split_size,
    )
    data_utils.log_rss(f'{split} arrayrecord dataset created')
    # num_batches=None -> iterate the full split; the factory caches and
    # replays the resulting iterator.
    return data_utils.CachedIteratorFactory(iterator_fn(None), split)

  # eval_train evaluates on (a prefix of) the training files.
  eval_train_iterator_fn = _cached_eval_iterator_fn(
      'eval_train', train_file_path, hps.train_size
  )
  validation_iterator_fn = _cached_eval_iterator_fn(
      'validation', validation_file_path, hps.valid_size
  )
  test_iterator_fn = _cached_eval_iterator_fn(
      'test', test_file_path, hps.test_size
  )

  return data_utils.Dataset(
      train_iterator_fn,
      eval_train_iterator_fn,
      validation_iterator_fn,
      test_iterator_fn,
  )
437+
438+
439+
def _get_criteo1tb_tsv(
440+
shuffle_rng,
441+
per_host_batch_size,
442+
per_host_eval_batch_size,
443+
hps,
444+
num_batches_to_prefetch,
445+
num_device_prefetches,
446+
):
447+
"""Load Criteo 1TB from raw TSV files (legacy path)."""
225448
train_file_path = os.path.join(RAW_CRITEO1TB_FILE_PATH, 'train/*/*')
226449
validation_file_path = os.path.join(
227450
RAW_CRITEO1TB_FILE_PATH, 'val_set_second_half_of_day23_not_used/*')
228451
test_file_path = os.path.join(RAW_CRITEO1TB_FILE_PATH, 'eval/day_23/*')
229-
num_batches_to_prefetch = (hps.num_tf_data_prefetches
230-
if hps.num_tf_data_prefetches > 0 else tf.data.AUTOTUNE)
231-
232-
num_device_prefetches = hps.get('num_device_prefetches', 0)
233452

234453
train_dataset = criteo_tsv_reader(
235454
split='train',

0 commit comments

Comments
 (0)