
Commit 3ba115f

init2winit Team authored, copybara-github committed

internal

PiperOrigin-RevId: 873892308
1 parent: 8d9c777

2 files changed: +136 -16 lines


init2winit/dataset_lib/criteo_terabyte_dataset.py

Lines changed: 40 additions & 16 deletions
@@ -37,14 +37,16 @@
 
 # Change to the path to raw dataset files.
 RAW_CRITEO1TB_FILE_PATH = ''
-CRITEO1TB_DEFAULT_HPARAMS = config_dict.ConfigDict(dict(
-    input_shape=(13 + 26,),
-    train_size=4_195_197_692,
-    # We assume the tie breaking example went to the validation set, because
-    # the test set in the mlperf version has 89_137_318 examples.
-    valid_size=89_137_319,
-    test_size=89_137_318,
-))
+CRITEO1TB_DEFAULT_HPARAMS = config_dict.ConfigDict(
+    dict(
+        input_shape=(13 + 26,),
+        train_size=4_195_197_692,
+        # We assume the tie breaking example went to the validation set, because
+        # the test set in the mlperf version has 89_137_318 examples.
+        valid_size=89_137_319,
+        test_size=89_137_318,
+    )
+)
 CRITEO1TB_METADATA = {
     'apply_one_hot_in_loss': True,
 }
@@ -142,9 +144,9 @@ def criteo_tsv_reader(
   ds = ds.repeat()
   ds = ds.interleave(
       tf.data.TextLineDataset,
-      cycle_length=128,
+      cycle_length=64,
       block_length=batch_size // 8,
-      num_parallel_calls=128,
+      num_parallel_calls=64,
       deterministic=False)
   if is_training:
     ds = ds.shuffle(buffer_size=524_288 * 100, seed=data_shuffle_seed)
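Note: in `tf.data`, `cycle_length` caps how many input files `interleave` reads from concurrently, and `num_parallel_calls` sets how many of those reads run in parallel, so halving both from 128 to 64 trades some read throughput for a smaller host-memory footprint. A minimal standalone sketch of the same interleave pattern, with a hypothetical file glob and batch size:

import tensorflow as tf

batch_size = 1024  # hypothetical per-host batch size
file_paths = tf.data.Dataset.list_files('/data/criteo/day_*.tsv')  # hypothetical glob

ds = file_paths.repeat()
ds = ds.interleave(
    tf.data.TextLineDataset,       # stream each shard as lines of text
    cycle_length=64,               # read at most 64 shards at once
    block_length=batch_size // 8,  # consecutive lines taken per shard visit
    num_parallel_calls=64,         # threads driving the reads
    deterministic=False)           # allow reordering for throughput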
@@ -179,15 +181,12 @@ def _convert_to_numpy_iterator_fn(
     num_batches = num_batches_in_split
 
   iterator = iter(tfds.as_numpy(tf_dataset))
-  zeros_batch = None
+  batch = None
   for _ in range(num_batches):
     try:
       batch = next(iterator)
     except StopIteration:
-      if zeros_batch is None:
-        zeros_batch = jax.tree.map(
-            lambda x: np.zeros_like(x, dtype=x.dtype), batch)
-      yield zeros_batch
+      yield batch
       continue
     batch = data_utils.maybe_pad_batch(
         batch, desired_batch_size=per_host_eval_batch_size)
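The removed fallback synthesized an all-zeros clone of the last batch once the split ran dry; the new code simply re-yields the last real batch until `num_batches` is reached. A small self-contained illustration of the new tail behavior (`pad_with_last` is a hypothetical name, not part of this commit):

def pad_with_last(source_iter, num_batches):
  # Yield num_batches items, repeating the final real item after exhaustion.
  batch = None
  for _ in range(num_batches):
    try:
      batch = next(source_iter)
    except StopIteration:
      yield batch  # last real batch, or None if the source was empty
      continue
    yield batch

assert list(pad_with_last(iter([1, 2]), 4)) == [1, 2, 2, 2]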
@@ -218,14 +217,25 @@ def get_criteo1tb(shuffle_rng,
   num_batches_to_prefetch = (hps.num_tf_data_prefetches
                              if hps.num_tf_data_prefetches > 0 else tf.data.AUTOTUNE)
 
+  num_device_prefetches = hps.get('num_device_prefetches', 0)
+
   train_dataset = criteo_tsv_reader(
       split='train',
       shuffle_rng=shuffle_rng,
       file_path=train_file_path,
       num_dense_features=hps.num_dense_features,
       batch_size=per_host_batch_size,
       num_batches_to_prefetch=num_batches_to_prefetch)
-  train_iterator_fn = lambda: tfds.as_numpy(train_dataset)
+  data_utils.log_rss('train dataset created')
+  if num_device_prefetches > 0:
+    train_iterator_fn = lambda: data_utils.prefetch_iterator(
+        tfds.as_numpy(train_dataset), num_device_prefetches
+    )
+    data_utils.log_rss(
+        f'using prefetching with {num_device_prefetches} in the train dataset'
+    )
+  else:
+    train_iterator_fn = lambda: tfds.as_numpy(train_dataset)
   eval_train_dataset = criteo_tsv_reader(
       split='eval_train',
       shuffle_rng=None,
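`num_device_prefetches` is read with `hps.get(..., 0)`, so behavior is unchanged unless the flag is set. Despite the name, the prefetching happens on the host: `data_utils.prefetch_iterator` (added below) buffers numpy batches in a background thread rather than staging them on an accelerator. A hedged sketch of enabling it via an hparam override, assuming standard `ml_collections` semantics and an illustrative depth of 4:

from ml_collections import config_dict

hps = config_dict.ConfigDict(CRITEO1TB_DEFAULT_HPARAMS)
hps.num_device_prefetches = 4  # keep up to 4 host batches queued ahead of training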
@@ -238,6 +248,7 @@ def get_criteo1tb(shuffle_rng,
       per_host_eval_batch_size=per_host_eval_batch_size,
       tf_dataset=eval_train_dataset,
       split_size=hps.train_size)
+  data_utils.log_rss('eval_train dataset created')
   validation_dataset = criteo_tsv_reader(
       split='validation',
       shuffle_rng=None,
@@ -250,6 +261,7 @@ def get_criteo1tb(shuffle_rng,
       per_host_eval_batch_size=per_host_eval_batch_size,
       tf_dataset=validation_dataset,
       split_size=hps.valid_size)
+  data_utils.log_rss('validation dataset created')
   test_dataset = criteo_tsv_reader(
       split='test',
       shuffle_rng=None,
@@ -262,6 +274,18 @@ def get_criteo1tb(shuffle_rng,
       per_host_eval_batch_size=per_host_eval_batch_size,
       tf_dataset=test_dataset,
       split_size=hps.test_size)
+  data_utils.log_rss('test dataset created')
+
+  # Cache all the eval_train/validation/test iterators to avoid re-processing the
+  # same data files.
+  eval_train_iterator_fn = data_utils.CachedEvalIterator(
+      eval_train_iterator_fn, 'eval_train'
+  )
+  validation_iterator_fn = data_utils.CachedEvalIterator(
+      validation_iterator_fn, 'validation'
+  )
+  test_iterator_fn = data_utils.CachedEvalIterator(test_iterator_fn, 'test')
+
   return data_utils.Dataset(
       train_iterator_fn,
       eval_train_iterator_fn,
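With these wrappers in place, each eval split's TSV files are parsed at most once per process: the first evaluation fills an in-memory cache, and later evaluations replay batches from host memory (see `CachedEvalIterator` in `data_utils.py` below).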

init2winit/dataset_lib/data_utils.py

Lines changed: 96 additions & 0 deletions
@@ -16,7 +16,12 @@
 """Common code used by different models."""
 
 import collections
+import queue
+import resource
+import threading
+from typing import Iterator
 
+from absl import logging
 import flax.linen as nn
 import jax
 from jax.nn import one_hot
@@ -33,6 +38,97 @@
 ])
 
 
+def log_rss(msg: str):
+  """Logs the current memory usage and prints the given message."""
+  rss_mb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024
+  logging.info('%s — RSS: %.1f MB', msg, rss_mb)
+
+
+def prefetch_iterator(source_iter: Iterator[jax.typing.ArrayLike],
+                      num_prefetch: int) -> Iterator[jax.typing.ArrayLike]:
+  """Wraps the given iterator with prefetching.
+
+  Args:
+    source_iter: The iterator to wrap.
+    num_prefetch: The number of items to prefetch.
+
+  Yields:
+    Prefetched items from `source_iter`.
+  """
+  buf = queue.Queue(maxsize=num_prefetch)
+  sentinel = object()  # Used to signal end of iterator
+
+  def producer():
+    try:
+      for item in source_iter:
+        buf.put(item)
+    except Exception as e:  # pylint: disable=broad-except
+      buf.put(e)
+    buf.put(sentinel)
+
+  t = threading.Thread(target=producer, daemon=True)
+  t.start()
+
+  while True:
+    item = buf.get()
+    if item is sentinel:
+      return
+    if isinstance(item, Exception):
+      raise item
+    yield item
+
+
+class CachedEvalIterator:
+  """Lazily caches eval batches, which are typically small enough to fit into host memory."""
+
+  def __init__(self, iterator_factory, split_name='eval'):
+    self._factory = iterator_factory
+    self._split_name = split_name
+    self._cache = []
+    self._iterator = None
+    self._fully_cached = False
+
+  def __call__(self, num_batches=None):
+    yielded = 0
+
+    limit = (
+        len(self._cache)
+        if num_batches is None
+        else min(len(self._cache), num_batches)
+    )
+    for i in range(limit):
+      yield self._cache[i]
+      yielded += 1
+
+    if num_batches is not None and yielded >= num_batches:
+      return
+
+    if self._fully_cached:
+      return
+
+    if self._iterator is None:
+      logging.info('Building %s cache lazily...', self._split_name)
+      self._iterator = iter(self._factory(None))
+
+    for batch in self._iterator:
+      self._cache.append(batch)
+      yield batch
+      yielded += 1
+      if num_batches is not None and yielded >= num_batches:
+        return
+
+    self._fully_cached = True
+    self._factory = None
+    self._iterator = None
+    rss_mb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024
+    logging.info(
+        '%s cache complete: %d batches — RSS: %.1f MB',
+        self._split_name,
+        len(self._cache),
+        rss_mb,
+    )
+
+
 def iterator_as_numpy(iterator):
   for x in iterator:
     yield jax.tree.map(lambda y: y._numpy(), x)  # pylint: disable=protected-access
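To see how the two new helpers compose, here is a hedged usage sketch; the toy batch generator is hypothetical, while `prefetch_iterator` and `CachedEvalIterator` are the definitions added above:

import numpy as np

def toy_batches():
  # Stand-in for tfds.as_numpy(dataset): any iterable of numpy batches.
  for i in range(5):
    yield {'inputs': np.full((4, 39), i), 'targets': np.ones((4,))}

# Host-side prefetching: a daemon thread fills a bounded queue so the next
# batch is usually ready before the consumer asks for it.
for batch in prefetch_iterator(toy_batches(), num_prefetch=2):
  pass  # a training step would consume `batch` here

# Eval caching: the first pass streams from the factory and stores every
# batch; later passes replay from host memory without re-reading input files.
eval_iter = CachedEvalIterator(lambda _: toy_batches(), 'toy_eval')
first_pass = list(eval_iter())    # builds the cache (5 batches)
second_pass = list(eval_iter(2))  # served from the cache, first 2 batches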
