fix: GPTSFTChatDataset.collate_fn use pad_seq_length_to_mult instead of hardcoded 16 #2643

Open

shanecmoran wants to merge 1 commit into NVIDIA-NeMo:main from shanecmoran:fix/chat-collate-padding

Conversation


@shanecmoran shanecmoran commented Mar 4, 2026

Description

GPTSFTChatDataset.collate_fn hardcodes sequence padding to a multiple of 16, ignoring the configurable self.pad_seq_length_to_mult attribute. Both the base class GPTSFTDataset.collate_fn and GPTSFTPackedDataset.collate_fn correctly use self.pad_seq_length_to_mult; only the chat variant has the bug.

This causes failures when using variable-length sequences (pad_to_max_length=False) with context-parallelism sizes where 2 * cp_size > 16 (e.g. CP=16 requires sequence lengths divisible by 32).
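The rounding mismatch can be sketched outside the codebase. The helper below mirrors what `_ceil_to_nearest` presumably does based on its name; the values are illustrative, not taken from the repository:

```python
import math

def ceil_to_nearest(n: int, mult: int) -> int:
    # Round n up to the next multiple of mult (assumed behavior of the
    # dataset's _ceil_to_nearest helper).
    return mult * math.ceil(n / mult)

max_seq_length = 4096
pad_seq_length_to_mult = 32   # 2 * cp_size for CP=16
batch_max = 1000              # longest sequence in this batch

# Hardcoded 16 (the bug): pads 1000 -> 1008, which is not divisible by 32
padded_bug = min(max_seq_length, ceil_to_nearest(batch_max, 16))

# Configurable multiple (the fix): pads 1000 -> 1024, divisible by 32
padded_fix = min(max_seq_length, ceil_to_nearest(batch_max, pad_seq_length_to_mult))
```

With CP=16, the hardcoded path produces lengths like 1008 that context parallelism cannot split evenly, which is the failure mode described above.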

Changes

src/megatron/bridge/data/datasets/sft.py (line 1200):

# Before — hardcoded 16
max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, 16))

# After — uses configurable attribute, consistent with base class
max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, self.pad_seq_length_to_mult))

tests/unit_tests/data/datasets/test_sft.py:

  • Added test_collate_fn_respects_pad_seq_length_to_mult to TestDataGPTSFTChatDataset — constructs dataset with pad_seq_length_to_mult=32, calls collate_fn, asserts output sequence length is divisible by 32.
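The behavior under test can be illustrated with a simplified collate sketch. This is illustrative only: the real test constructs a full dataset and calls its collate_fn, and the function name below is invented for the sketch:

```python
import math

def collate_sketch(batch, pad_mult, pad_id=0, max_seq_length=4096):
    # Pad every sequence to a shared length that is a multiple of
    # pad_mult, capped at max_seq_length.
    longest = max(len(seq) for seq in batch)
    target = min(max_seq_length, pad_mult * math.ceil(longest / pad_mult))
    return [seq + [pad_id] * (target - len(seq)) for seq in batch]

batch = [[1] * 37, [2] * 50, [3] * 12]
padded = collate_sketch(batch, pad_mult=32)
# longest sequence is 50, so every row is padded to 64
```

The test's assertion is the same shape: the collated sequence length must be divisible by the configured multiple.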

Fixes #2642

Summary by CodeRabbit

  • Improvements

    • Sequence padding now uses the configurable pad_seq_length_to_mult instead of a hardcoded value, so each dataset instance can specify its own padding granularity.
  • Tests

    • Added a unit test verifying that collated batches respect the configured padding granularity during batch processing.

@copy-pr-bot

copy-pr-bot bot commented Mar 4, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@shanecmoran shanecmoran marked this pull request as ready for review March 4, 2026 17:48

coderabbitai bot commented Mar 4, 2026

📝 Walkthrough

Fixed GPTSFTChatDataset.collate_fn to use configurable pad_seq_length_to_mult instead of hardcoded 16 for sequence padding, aligning behavior with other dataset classes. Added unit test to validate padding respects the parameter.

Changes

Sequence Padding Fix: src/megatron/bridge/data/datasets/sft.py
Replaced the hardcoded padding granularity of 16 with the configurable self.pad_seq_length_to_mult in GPTSFTChatDataset.collate_fn, matching the base class implementations.

Test Coverage: tests/unit_tests/data/datasets/test_sft.py
Added test_collate_fn_respects_pad_seq_length_to_mult to verify that collated token sequences are padded to multiples of the configured pad_seq_length_to_mult value.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Docstring Coverage (⚠️ Warning): docstring coverage is 25.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
✅ Passed checks (5 passed)
Description Check (✅ Passed): check skipped; CodeRabbit's high-level summary is enabled.
Title Check (✅ Passed): the title accurately describes the main fix: replacing hardcoded 16 with pad_seq_length_to_mult in GPTSFTChatDataset.collate_fn.
Linked Issues Check (✅ Passed): the PR fully addresses issue #2642: it fixes the hardcoded 16 in GPTSFTChatDataset.collate_fn, adds a test verifying pad_seq_length_to_mult is respected, and ensures consistency with the other collate_fn implementations.
Out of Scope Changes Check (✅ Passed): all changes are scoped to issue #2642: a one-line fix in sft.py and a corresponding test addition, with no unrelated modifications.
Test Results For Major Changes (✅ Passed): the PR is a minor bug fix changing a hardcoded value to a configurable attribute, with appropriate test coverage added.



@shanecmoran shanecmoran force-pushed the fix/chat-collate-padding branch from 15f1282 to 0a9e6bc Compare March 4, 2026 19:00

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/unit_tests/data/datasets/test_sft.py (1)

328-373: Add a unit test marker for this new test.

Please annotate this method with @pytest.mark.unit to match test categorization conventions.

Suggested update:

    +    @pytest.mark.unit
         def test_collate_fn_respects_pad_seq_length_to_mult(self, tmp_path):

As per coding guidelines, tests/**/*.py: Use pytest.mark to categorize tests (unit, integration, system).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/data/datasets/test_sft.py` around lines 328 - 373, The test
method test_collate_fn_respects_pad_seq_length_to_mult in
tests/unit_tests/data/datasets/test_sft.py needs to be annotated with the unit
test marker: add `@pytest.mark.unit` immediately above the def line for
test_collate_fn_respects_pad_seq_length_to_mult and ensure pytest is imported at
the top of the file (add "import pytest" if missing) so the decorator resolves.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 5a7f0bb4-88da-44b1-9a6f-831a0e07d3b6

📥 Commits

Reviewing files that changed from the base of the PR and between dd0a1d9 and 0a9e6bc.

📒 Files selected for processing (2)
  • src/megatron/bridge/data/datasets/sft.py
  • tests/unit_tests/data/datasets/test_sft.py

…of hardcoded 16

GPTSFTChatDataset.collate_fn hardcodes sequence padding to a multiple
of 16 instead of using the configurable self.pad_seq_length_to_mult
attribute. This is inconsistent with the base class GPTSFTDataset and
GPTSFTPackedDataset, both of which correctly use the attribute.

The hardcoded value causes failures when using variable-length sequences
(pad_to_max_length=False) with context parallelism sizes where
2 * cp_size > 16 (e.g. CP=16 requires divisibility by 32).

Fixes: NVIDIA-NeMo#2642
Signed-off-by: Shane Moran <shane.moran@shopify.com>
@shanecmoran shanecmoran force-pushed the fix/chat-collate-padding branch from 0a9e6bc to 47bdded Compare March 5, 2026 11:57
        else:
-           max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, 16))
+           max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, self.pad_seq_length_to_mult))
        assert max_length <= self.max_seq_length
Contributor


change to max(16, self.pad_seq_length_to_mult)?

We always want to pad to a multiple of 16; otherwise we will end up with a slower kernel.
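A sketch of what that suggestion might look like (an assumption about the reviewer's intent, not merged code): clamp the multiple at 16 from below, so the configurable value can raise the alignment but never lower it.

```python
import math

def pad_target(batch_max: int, pad_seq_length_to_mult: int, max_seq_length: int = 4096) -> int:
    # Never pad to a multiple smaller than 16, so sequence lengths stay
    # aligned for the faster kernels even with a small configured value.
    mult = max(16, pad_seq_length_to_mult)
    return min(max_seq_length, mult * math.ceil(batch_max / mult))

pad_target(1000, 8)    # multiple clamps to 16 -> 1008
pad_target(1000, 32)   # multiple stays 32 -> 1024
```

This keeps the PR's fix for large context-parallelism sizes while preserving the 16-alignment the reviewer wants for small configured multiples.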


Development

Successfully merging this pull request may close these issues.

GPTSFTChatDataset.collate_fn ignores pad_seq_length_to_mult
