Skip to content

Commit b54d9a1

Browse files
Ahmed Khaled and copybara-github
authored and committed
internal
PiperOrigin-RevId: 880961487
1 parent 38fcbb8 commit b54d9a1

File tree

3 files changed

+6
-7
lines changed

3 files changed

+6
-7
lines changed

init2winit/dataset_lib/fineweb_edu_10b_input_pipeline.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
https://github.com/mlcommons/algorithmic-efficiency/blob/main/datasets/dataset_setup.py.
2323
"""
2424

25-
# import tensorflow.compat.v2 as tf
2625
import os
2726
from absl import logging
2827
from ml_collections.config_dict import config_dict
@@ -35,8 +34,7 @@
3534
SHUFFLE_BUFFER_SIZE = 100_000
3635
VOCAB_SIZE = 50_257
3736

38-
PAD_ID = tf.constant(-1, dtype=tf.int64)
39-
# PAD_ID = -1
37+
PAD_ID = -1
4038

4139
AUTOTUNE = tf.data.experimental.AUTOTUNE
4240

@@ -61,6 +59,8 @@ def batch_with_padding(
6159

6260
# tf.data.Dataset.padded.batch pads elements in the batch so we call it
6361
# again with batch_size=1 to pad each element in original batch.
62+
if isinstance(padding_id, int):
63+
padding_id = tf.constant(padding_id, dtype=tf.int64)
6464
padded_batched_dataset = batched_dataset.padded_batch(
6565
1, padded_shapes=padded_shapes, padding_values=padding_id
6666
)
@@ -95,7 +95,6 @@ def get_fineweb_edu_dataset(
9595
train_path = os.path.join(DATA_DIR, TRAIN_DIR)
9696
val_path = os.path.join(DATA_DIR, VAL_DIR)
9797

98-
# Load datasets and cast to int32.
9998
train_dataset = tf.data.Dataset.load(train_path)
10099
val_dataset = tf.data.Dataset.load(val_path)
101100

init2winit/dataset_lib/test_fineweb_edu_10b_input_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class FinewebEdu10bInputPipelineTest(absltest.TestCase):
2626

2727
def test_batch_with_padding(self):
2828
"""Test batching with padding."""
29-
arr = np.arange(18, dtype=np.int32)
29+
arr = np.arange(18, dtype=np.int64)
3030
ds = tf.data.Dataset.from_tensor_slices(arr)
3131
ds = ds.batch(
3232
6,

init2winit/dataset_lib/test_fineweb_edu_10b_mdlm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def test_eval_batch_padding_applied(self):
133133
self.assertLen(batches, 2)
134134

135135
padded_batch = batches[1]
136-
pad_id = int(input_pipeline.PAD_ID.numpy())
136+
pad_id = input_pipeline.PAD_ID
137137

138138
# The second row of the padded batch should be all PAD_ID.
139139
np.testing.assert_array_equal(padded_batch['inputs'][1], np.full(4, pad_id))
@@ -157,7 +157,7 @@ def test_eval_batch_padding_not_in_full_batches(self):
157157

158158
batches = list(valid_ds.as_numpy_iterator())
159159
full_batch = batches[0]
160-
pad_id = int(input_pipeline.PAD_ID.numpy())
160+
pad_id = input_pipeline.PAD_ID
161161

162162
# No element in the full batch should be PAD_ID.
163163
self.assertTrue(np.all(full_batch['inputs'] != pad_id))

0 commit comments

Comments (0)