diff --git a/src/megatron/bridge/data/datasets/sft.py b/src/megatron/bridge/data/datasets/sft.py
index f29900223b..b12d564474 100644
--- a/src/megatron/bridge/data/datasets/sft.py
+++ b/src/megatron/bridge/data/datasets/sft.py
@@ -1197,7 +1197,7 @@ def collate_fn(self, batch):
         if self.pad_to_max_length:
             max_length = self.max_seq_length
         else:
-            max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, 16))
+            max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, self.pad_seq_length_to_mult))
         assert max_length <= self.max_seq_length
 
         position_ids = [list(range(max_length)) for _ in batch]
diff --git a/tests/unit_tests/data/datasets/test_sft.py b/tests/unit_tests/data/datasets/test_sft.py
index c819a7668a..0eb3b67f0a 100755
--- a/tests/unit_tests/data/datasets/test_sft.py
+++ b/tests/unit_tests/data/datasets/test_sft.py
@@ -325,6 +325,52 @@ def test_collate_fn(self, tmp_path):
         ]
         dataset.collate_fn(batch)
 
+    def test_collate_fn_respects_pad_seq_length_to_mult(self, tmp_path):
+        datasets_dir = tmp_path / "datasets"
+        datasets_dir.mkdir(exist_ok=True)
+        path = str(datasets_dir / "sft.jsonl")
+        line = {"input": "hi how are you?", "output": "I'm fine, thanks."}
+        with open(path, "w") as f:
+            for _ in range(10):
+                f.write(json.dumps(line) + "\n")
+
+        tokenizer = create_mock_tokenizer()
+        dataset = GPTSFTChatDataset(
+            file_path=path,
+            tokenizer=tokenizer,
+            label_key="output",
+            prompt_template="{input}\n\n### Response:\n{output}",
+            truncation_field="output",
+            pad_seq_length_to_mult=32,
+        )
+        batch = [
+            {
+                "input_ids": np.array([101, 102, 103, 104, 105]),
+                "context_ids": np.array([101, 102]),
+                "answer_start_idx": 2,
+                "context_length": 2,
+                "answer_ids": np.array([104, 105]),
+                "seq_boundaries": (0, 3),
+                "loss_mask": np.array([0, 0, 0, 1, 1]),
+                "metadata": {},
+                "token_count": 5,
+            },
+            {
+                "input_ids": np.array([201, 202, 203, 204]),
+                "context_ids": np.array([201]),
+                "answer_start_idx": 1,
+                "context_length": 1,
+                "answer_ids": np.array([203, 204]),
+                "seq_boundaries": (0, 2),
+                "loss_mask": np.array([0, 0, 1, 1]),
+                "metadata": {},
+                "token_count": 4,
+            },
+        ]
+        result = dataset.collate_fn(batch)
+        seq_length = result["tokens"].shape[1]
+        assert seq_length % 32 == 0, f"Expected sequence length divisible by 32, got {seq_length}"
+
     def test_build_samples_mapping(self, tmp_path):
         dataset, _ = get_gpt_sft(tmp_path, dataset_type="chat")
         dataset._build_samples_mapping()