|
34 | 34 | import tensorflow as tf |
35 | 35 | import tensorflow_datasets as tfds |
36 | 36 |
|
37 | | - |
38 | 37 | # Change to the path to raw dataset files. |
39 | 38 | RAW_CRITEO1TB_FILE_PATH = '' |
| 39 | +PREPROCESSED_CRITEO1TB_FILE_PATH = '' # pylint: disable=invalid-name |
40 | 40 | CRITEO1TB_DEFAULT_HPARAMS = config_dict.ConfigDict( |
41 | 41 | dict( |
42 | 42 | input_shape=(13 + 26,), |
@@ -157,6 +157,94 @@ def criteo_tsv_reader( |
157 | 157 | return ds |
158 | 158 |
|
159 | 159 |
|
# Feature spec for the serialized tf.train.Examples stored in the ArrayRecord
# files: a flat float vector of 13 + 26 = 39 features (matching the model's
# input_shape) plus a single float label, presumably the 13 numeric + 26
# categorical Criteo fields — confirm against the preprocessing job.
_ARRAYRECORD_FEATURE_SPEC = {
    'inputs': tf.io.FixedLenFeature([13 + 26], tf.float32),
    'targets': tf.io.FixedLenFeature([1], tf.float32),
}
| 164 | + |
| 165 | + |
@tf.function
def _parse_arrayrecord_example_fn(serialized_examples):
  """Decode a batch of serialized tf.train.Examples into dense tensors.

  Args:
    serialized_examples: string tensor of serialized tf.train.Examples.

  Returns:
    Dict with 'inputs' (dense feature matrix) and 'targets' (labels with the
    trailing singleton dimension removed).
  """
  features = tf.io.parse_example(
      serialized_examples, _ARRAYRECORD_FEATURE_SPEC)
  inputs = features['inputs']
  targets = tf.squeeze(features['targets'], axis=-1)
  return {'inputs': inputs, 'targets': targets}
| 175 | + |
def criteo_arrayrecord_reader(
    split, shuffle_rng, file_path, batch_size, num_batches_to_prefetch
):
  """Input reader for preprocessed Criteo ArrayRecord data.

  Builds a batched tf.data pipeline over .array_record files, sharding the
  file list across hosts and interleaving already-batched, already-parsed
  per-file sub-pipelines.

  Args:
    split: one of {'train', 'eval_train', 'validation', 'test'}.
    shuffle_rng: jax.random.PRNGKey for shuffling (train). Only read when
      split == 'train'; may be None otherwise.
    file_path: glob pattern for .array_record files.
    batch_size: per-host batch size.
    num_batches_to_prefetch: number of batches to prefetch.

  Returns:
    A tf.data.Dataset object.

  Raises:
    ValueError: if `split` is not a known split name, or no files match
      `file_path`.
  """
  # Validate the split name early so callers get a clear error.
  # NOTE(review): `ar_dataset` is referenced below but not imported locally;
  # it must be imported at module level.
  if split not in ['train', 'eval_train', 'validation', 'test']:
    raise ValueError(f'Invalid split name {split}.')
  data_shuffle_seed = None

  is_training = split == 'train'
  if is_training:
    # Derive a deterministic TF-compatible seed from the JAX PRNG key; only
    # the train split is shuffled.
    _, data_shuffle_seed = jax.random.split(shuffle_rng, 2)
    data_shuffle_seed = data_utils.convert_jax_to_tf_random_seed(
        data_shuffle_seed
    )

  # Discover all matching files. Sorting makes the host assignment below
  # deterministic across restarts.
  all_files = sorted(tf.io.gfile.glob(file_path))
  if not all_files:
    raise ValueError(f'No ArrayRecord files found matching: {file_path}')

  # Shard files across hosts: host i takes every num_hosts-th file.
  index = jax.process_index()
  num_hosts = jax.process_count()
  host_files = all_files[index::num_hosts]

  # Interleave per-file datasets, with batch+parse inside each file's
  # sub-pipeline. This is critical for performance: interleaving dense float
  # tensors (post-parse) is much faster than interleaving raw byte strings
  # and batching them later.
  file_ds = tf.data.Dataset.from_tensor_slices(host_files)
  if is_training:
    # repeat() before shuffle(): the file order is reshuffled continuously
    # across the infinite repeated stream, not once per epoch.
    file_ds = file_ds.repeat()
    file_ds = file_ds.shuffle(
        buffer_size=2 * len(host_files), seed=data_shuffle_seed
    )

  ds = file_ds.interleave(
      # Each inner element is a full parsed batch. drop_remainder applies
      # only in training; eval splits keep the final partial batch.
      lambda f: (
          ar_dataset.ArrayRecordDataset([f])
          .batch(
              batch_size,
              drop_remainder=is_training,
              num_parallel_calls=tf.data.AUTOTUNE,
              deterministic=False,
          )
          .map(
              _parse_arrayrecord_example_fn,
              num_parallel_calls=tf.data.AUTOTUNE,
              deterministic=False,
          )
      ),
      cycle_length=64,
      # NOTE(review): elements here are whole batches, so block_length is
      # measured in batches (batch_size // 8 batches per file per turn), not
      # examples — confirm this is intended, and that batch_size >= 8 so
      # block_length is nonzero.
      block_length=batch_size // 8,
      num_parallel_calls=64,
      deterministic=False,
  )
  ds = ds.prefetch(num_batches_to_prefetch)
  return ds
| 246 | + |
| 247 | + |
160 | 248 | def _eval_numpy_iterator( |
161 | 249 | num_batches, per_host_eval_batch_size, tf_dataset, split_size |
162 | 250 | ): |
@@ -222,14 +310,145 @@ def get_criteo1tb(shuffle_rng, |
222 | 310 | per_host_eval_batch_size = eval_batch_size // process_count |
223 | 311 | per_host_batch_size = batch_size // process_count |
224 | 312 |
|
| 313 | + use_raw_tsv = hps.get('use_raw_tsv', False) |
| 314 | + num_batches_to_prefetch = ( |
| 315 | + hps.num_tf_data_prefetches |
| 316 | + if hps.num_tf_data_prefetches > 0 |
| 317 | + else tf.data.AUTOTUNE |
| 318 | + ) |
| 319 | + num_device_prefetches = hps.get('num_device_prefetches', 0) |
| 320 | + |
| 321 | + if use_raw_tsv: |
| 322 | + return _get_criteo1tb_tsv( |
| 323 | + shuffle_rng, |
| 324 | + per_host_batch_size, |
| 325 | + per_host_eval_batch_size, |
| 326 | + hps, |
| 327 | + num_batches_to_prefetch, |
| 328 | + num_device_prefetches, |
| 329 | + ) |
| 330 | + else: |
| 331 | + return _get_criteo1tb_arrayrecord( |
| 332 | + shuffle_rng, |
| 333 | + per_host_batch_size, |
| 334 | + per_host_eval_batch_size, |
| 335 | + hps, |
| 336 | + num_batches_to_prefetch, |
| 337 | + num_device_prefetches, |
| 338 | + ) |
| 339 | + |
| 340 | + |
def _make_cached_eval_iterator(
    split,
    file_path,
    per_host_eval_batch_size,
    num_batches_to_prefetch,
    split_size,
):
  """Build a cached numpy iterator factory for one unshuffled eval split.

  Args:
    split: split name passed to the reader, used in the log message and as
      the CachedIteratorFactory tag.
    file_path: glob pattern of the split's ArrayRecord files.
    per_host_eval_batch_size: per-host batch size for evaluation.
    num_batches_to_prefetch: tf.data prefetch depth.
    split_size: total example count of the split, forwarded to
      _eval_numpy_iterator.

  Returns:
    A data_utils.CachedIteratorFactory over the split.
  """
  dataset = criteo_arrayrecord_reader(
      split=split,
      shuffle_rng=None,  # Eval splits are never shuffled.
      file_path=file_path,
      batch_size=per_host_eval_batch_size,
      num_batches_to_prefetch=num_batches_to_prefetch,
  )
  # num_batches=None: iterate the full split.
  iterator_fn = functools.partial(
      _eval_numpy_iterator,
      per_host_eval_batch_size=per_host_eval_batch_size,
      tf_dataset=dataset,
      split_size=split_size,
  )
  data_utils.log_rss(f'{split} arrayrecord dataset created')
  return data_utils.CachedIteratorFactory(iterator_fn(None), split)


def _get_criteo1tb_arrayrecord(
    shuffle_rng,
    per_host_batch_size,
    per_host_eval_batch_size,
    hps,
    num_batches_to_prefetch,
    num_device_prefetches,
):
  """Load Criteo 1TB from preprocessed ArrayRecord files.

  Args:
    shuffle_rng: jax.random.PRNGKey used to shuffle the train split.
    per_host_batch_size: per-host train batch size.
    per_host_eval_batch_size: per-host eval batch size.
    hps: hyperparameter ConfigDict; reads preprocessed_data_path (optional),
      train_size, valid_size, and test_size.
    num_batches_to_prefetch: tf.data prefetch depth.
    num_device_prefetches: if > 0, wrap the train iterator in
      data_utils.prefetch_iterator with this depth.

  Returns:
    A data_utils.Dataset of (train, eval_train, validation, test) iterator
    factories.
  """
  base = hps.get('preprocessed_data_path', PREPROCESSED_CRITEO1TB_FILE_PATH)
  train_file_path = os.path.join(base, 'train', '*')
  validation_file_path = os.path.join(
      base, 'val_set_second_half_of_day23_not_used', '*'
  )
  test_file_path = os.path.join(base, 'eval', '*')

  train_dataset = criteo_arrayrecord_reader(
      split='train',
      shuffle_rng=shuffle_rng,
      file_path=train_file_path,
      batch_size=per_host_batch_size,
      num_batches_to_prefetch=num_batches_to_prefetch,
  )
  data_utils.log_rss('train arrayrecord dataset created')

  if num_device_prefetches > 0:
    train_iterator_fn = lambda: data_utils.prefetch_iterator(
        tfds.as_numpy(train_dataset), num_device_prefetches
    )
    data_utils.log_rss(
        f'using prefetching with {num_device_prefetches} in the train dataset'
    )
  else:
    train_iterator_fn = lambda: tfds.as_numpy(train_dataset)

  # eval_train re-reads the training files, unshuffled, at eval batch size.
  eval_train_iterator_fn = _make_cached_eval_iterator(
      'eval_train',
      train_file_path,
      per_host_eval_batch_size,
      num_batches_to_prefetch,
      hps.train_size,
  )
  validation_iterator_fn = _make_cached_eval_iterator(
      'validation',
      validation_file_path,
      per_host_eval_batch_size,
      num_batches_to_prefetch,
      hps.valid_size,
  )
  test_iterator_fn = _make_cached_eval_iterator(
      'test',
      test_file_path,
      per_host_eval_batch_size,
      num_batches_to_prefetch,
      hps.test_size,
  )

  return data_utils.Dataset(
      train_iterator_fn,
      eval_train_iterator_fn,
      validation_iterator_fn,
      test_iterator_fn,
  )
| 437 | + |
| 438 | + |
| 439 | +def _get_criteo1tb_tsv( |
| 440 | + shuffle_rng, |
| 441 | + per_host_batch_size, |
| 442 | + per_host_eval_batch_size, |
| 443 | + hps, |
| 444 | + num_batches_to_prefetch, |
| 445 | + num_device_prefetches, |
| 446 | +): |
| 447 | + """Load Criteo 1TB from raw TSV files (legacy path).""" |
225 | 448 | train_file_path = os.path.join(RAW_CRITEO1TB_FILE_PATH, 'train/*/*') |
226 | 449 | validation_file_path = os.path.join( |
227 | 450 | RAW_CRITEO1TB_FILE_PATH, 'val_set_second_half_of_day23_not_used/*') |
228 | 451 | test_file_path = os.path.join(RAW_CRITEO1TB_FILE_PATH, 'eval/day_23/*') |
229 | | - num_batches_to_prefetch = (hps.num_tf_data_prefetches |
230 | | - if hps.num_tf_data_prefetches > 0 else tf.data.AUTOTUNE) |
231 | | - |
232 | | - num_device_prefetches = hps.get('num_device_prefetches', 0) |
233 | 452 |
|
234 | 453 | train_dataset = criteo_tsv_reader( |
235 | 454 | split='train', |
|
0 commit comments