fix: GPTSFTChatDataset.collate_fn use pad_seq_length_to_mult instead of hardcoded 16#2643
fix: GPTSFTChatDataset.collate_fn use pad_seq_length_to_mult instead of hardcoded 16#2643shanecmoran wants to merge 1 commit intoNVIDIA-NeMo:mainfrom
Conversation
📝 WalkthroughWalkthroughFixed Changes
Estimated code review effort🎯 2 (Simple) | ⏱️ ~8 minutes 🚥 Pre-merge checks | ✅ 5 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
15f1282 to
0a9e6bc
Compare
There was a problem hiding this comment.
🧹 Nitpick comments (1)
tests/unit_tests/data/datasets/test_sft.py (1)
328-373: Add a unit test marker for this new test.Please annotate this method with
@pytest.mark.unitto match test categorization conventions.Suggested update
+ `@pytest.mark.unit` def test_collate_fn_respects_pad_seq_length_to_mult(self, tmp_path):As per coding guidelines,
tests/**/*.py: Usepytest.markto categorize tests (unit, integration, system).🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit_tests/data/datasets/test_sft.py` around lines 328 - 373, The test method test_collate_fn_respects_pad_seq_length_to_mult in tests/unit_tests/data/datasets/test_sft.py needs to be annotated with the unit test marker: add `@pytest.mark.unit` immediately above the def line for test_collate_fn_respects_pad_seq_length_to_mult and ensure pytest is imported at the top of the file (add "import pytest" if missing) so the decorator resolves.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tests/unit_tests/data/datasets/test_sft.py`:
- Around line 328-373: The test method
test_collate_fn_respects_pad_seq_length_to_mult in
tests/unit_tests/data/datasets/test_sft.py needs to be annotated with the unit
test marker: add `@pytest.mark.unit` immediately above the def line for
test_collate_fn_respects_pad_seq_length_to_mult and ensure pytest is imported at
the top of the file (add "import pytest" if missing) so the decorator resolves.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 5a7f0bb4-88da-44b1-9a6f-831a0e07d3b6
📒 Files selected for processing (2)
src/megatron/bridge/data/datasets/sft.pytests/unit_tests/data/datasets/test_sft.py
…of hardcoded 16 GPTSFTChatDataset.collate_fn hardcodes sequence padding to a multiple of 16 instead of using the configurable self.pad_seq_length_to_mult attribute. This is inconsistent with the base class GPTSFTDataset and GPTSFTPackedDataset, both of which correctly use the attribute. The hardcoded value causes failures when using variable-length sequences (pad_to_max_length=False) with context parallelism sizes where 2 * cp_size > 16 (e.g. CP=16 requires divisibility by 32). Fixes: NVIDIA-NeMo#2642 Signed-off-by: Shane Moran <shane.moran@shopify.com>
0a9e6bc to
47bdded
Compare
| else: | ||
| max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, 16)) | ||
| max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, self.pad_seq_length_to_mult)) | ||
| assert max_length <= self.max_seq_length |
There was a problem hiding this comment.
change to max(16, self.pad_seq_length_to_mult)?
we always want to pad to multiple of 16, otherwise will end up with a slower kernel
Description
GPTSFTChatDataset.collate_fnhardcodes sequence padding to a multiple of 16, ignoring the configurableself.pad_seq_length_to_multattribute. The base classGPTSFTDataset.collate_fnandGPTSFTPackedDataset.collate_fnboth correctly useself.pad_seq_length_to_mult— only the chat variant has the bug.This causes failures when using variable-length sequences (
pad_to_max_length=False) with context parallelism sizes where2 * cp_size > 16(e.g. CP=16 requires divisibility by 32).Changes
src/megatron/bridge/data/datasets/sft.py(line 1200):tests/unit_tests/data/datasets/test_sft.py:test_collate_fn_respects_pad_seq_length_to_multtoTestDataGPTSFTChatDataset— constructs dataset withpad_seq_length_to_mult=32, callscollate_fn, asserts output sequence length is divisible by 32.Fixes #2642
Summary by CodeRabbit
Improvements
Tests