Description
GPTSFTChatDataset.collate_fn hardcodes sequence padding to a multiple of 16, ignoring the configurable self.pad_seq_length_to_mult attribute. This can cause a crash when using variable-length sequences (pad_to_max_length=False) with context parallelism sizes where 2 * cp_size > 16, because a length padded to a multiple of 16 is not guaranteed to be divisible by 2 * cp_size (e.g. CP=16 requires divisibility by 32).
The base class GPTSFTDataset.collate_fn (line 703) and GPTSFTPackedDataset.collate_fn (line 892) both correctly use self.pad_seq_length_to_mult — only the chat variant has the bug.
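The arithmetic behind the failure can be sketched in isolation. This is a standalone illustration, not the library code: `ceil_to_nearest` is a hypothetical stand-in that assumes `_ceil_to_nearest` rounds up to the next multiple, the batch length of 40 and `max_seq_length` of 4096 are made-up values, and the `2 * cp_size` divisibility requirement is taken from the description above.

```python
def ceil_to_nearest(n: int, m: int) -> int:
    """Round n up to the next multiple of m (assumed behavior of _ceil_to_nearest)."""
    return (n + m - 1) // m * m


def padded_length(batch_max_len: int, max_seq_length: int, multiple: int) -> int:
    """Mirror the length computation quoted in this issue, with the multiple as a parameter."""
    return min(max_seq_length, ceil_to_nearest(batch_max_len, multiple))


cp_size = 16
required_multiple = 2 * cp_size  # 32, per the description above

# Illustrative values: longest sequence in the batch is 40 tokens.
hardcoded = padded_length(40, 4096, 16)                  # current behavior
configured = padded_length(40, 4096, required_multiple)  # proposed behavior

print(hardcoded, hardcoded % required_multiple == 0)
print(configured, configured % required_multiple == 0)
```

With these values the hardcoded path yields a length that is a multiple of 16 but not of 32, while the configured path yields a length that satisfies the CP=16 requirement.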
Location
src/megatron/bridge/data/datasets/sft.py, line 1200:
    # BUGGY: hardcoded 16
    max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, 16))
Should be:
    # CORRECT: uses configurable attribute
    max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, self.pad_seq_length_to_mult))
How to Reproduce
- Create a GPTSFTChatDataset with pad_seq_length_to_mult=32 and pad_to_max_length=False
- Call collate_fn on a batch
- Observe that the output sequence length is padded to a multiple of 16, not 32
Expected Behavior
Output sequence length should be padded to a multiple of pad_seq_length_to_mult (32 in this case), consistent with the base class behavior.
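A regression check for the expected behavior might look like the sketch below. `ToyChatCollator` is a hypothetical, dependency-free stand-in written for this illustration. It is not the real `GPTSFTChatDataset`, it pads plain Python lists rather than tensors, and its constructor arguments and `pad_id` are assumptions. Only the `max_length` line follows the corrected expression quoted in this issue.

```python
from typing import List


class ToyChatCollator:
    """Hypothetical stand-in for the length logic in a chat collate_fn."""

    def __init__(self, max_seq_length: int, pad_seq_length_to_mult: int = 16, pad_id: int = 0):
        self.max_seq_length = max_seq_length
        self.pad_seq_length_to_mult = pad_seq_length_to_mult
        self.pad_id = pad_id  # assumed padding token id

    @staticmethod
    def _ceil_to_nearest(n: int, m: int) -> int:
        # Assumed semantics: round n up to the next multiple of m.
        return (n + m - 1) // m * m

    def collate_fn(self, batch: List[List[int]]) -> List[List[int]]:
        max_length = max(len(ids) for ids in batch)
        # Corrected line: use the configurable multiple instead of a literal 16.
        max_length = min(
            self.max_seq_length,
            self._ceil_to_nearest(max_length, self.pad_seq_length_to_mult),
        )
        # Truncate to max_length, then right-pad with pad_id.
        return [
            ids[:max_length] + [self.pad_id] * (max_length - len(ids))
            for ids in batch
        ]


collator = ToyChatCollator(max_seq_length=4096, pad_seq_length_to_mult=32)
out = collator.collate_fn([[1] * 40, [2] * 17])
print([len(row) for row in out])
```

Against the real dataset class, the equivalent assertion would be that the collated sequence dimension is divisible by `pad_seq_length_to_mult` whenever it is below `max_seq_length`.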