GPTSFTChatDataset.collate_fn ignores pad_seq_length_to_mult #2642

@shanecmoran

Description

GPTSFTChatDataset.collate_fn hardcodes sequence padding to a multiple of 16, ignoring the configurable self.pad_seq_length_to_mult attribute. This causes a crash when using variable-length sequences (pad_to_max_length=False) with context parallelism sizes where 2 * cp_size > 16 (e.g. CP=16 requires divisibility by 32).
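
To make the divisibility problem concrete, here is a minimal arithmetic sketch. The standalone `ceil_to_nearest` helper is an assumption standing in for the dataset's `_ceil_to_nearest` (taken here to round up to the next multiple), and the batch length of 40 is a made-up value for illustration.

```python
def ceil_to_nearest(n: int, m: int) -> int:
    """Round n up to the nearest multiple of m (assumed behavior of _ceil_to_nearest)."""
    return (n + m - 1) // m * m


cp_size = 16
required_multiple = 2 * cp_size  # 32, the divisibility requirement described above

longest_in_batch = 40  # hypothetical longest sequence in a batch

# Padding with the hardcoded 16 versus the multiple that CP=16 needs.
padded_hardcoded = ceil_to_nearest(longest_in_batch, 16)
padded_configured = ceil_to_nearest(longest_in_batch, required_multiple)

print(padded_hardcoded, padded_hardcoded % required_multiple)    # 48 16
print(padded_configured, padded_configured % required_multiple)  # 64 0
```

A length of 48 is a multiple of 16 but not of 32, so a batch like this one would violate the CP=16 requirement.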

The base class GPTSFTDataset.collate_fn (line 703) and GPTSFTPackedDataset.collate_fn (line 892) both correctly use self.pad_seq_length_to_mult — only the chat variant has the bug.

Location

src/megatron/bridge/data/datasets/sft.py, line 1200:

# BUGGY — hardcoded 16
max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, 16))

Should be:

# CORRECT — uses configurable attribute
max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, self.pad_seq_length_to_mult))

How to Reproduce

  1. Create a GPTSFTChatDataset with pad_seq_length_to_mult=32 and pad_to_max_length=False
  2. Call collate_fn on a batch
  3. Observe output sequence length is padded to a multiple of 16, not 32

Expected Behavior

Output sequence length should be padded to a multiple of pad_seq_length_to_mult (32 in this case), consistent with the base class behavior.
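
A regression check for this could look roughly like the sketch below. Constructing a real GPTSFTChatDataset needs a tokenizer and a data file, so `StandInChatDataset` is a hypothetical stand-in that copies only the padding line quoted under Location; the method names and batch lengths are illustrative and not part of the library.

```python
class StandInChatDataset:
    """Hypothetical stand-in that mirrors only the padding-length computation."""

    def __init__(self, max_seq_length: int, pad_seq_length_to_mult: int):
        self.max_seq_length = max_seq_length
        self.pad_seq_length_to_mult = pad_seq_length_to_mult

    @staticmethod
    def _ceil_to_nearest(n: int, m: int) -> int:
        # Assumed behavior: round n up to the next multiple of m.
        return (n + m - 1) // m * m

    def padded_length_hardcoded(self, lengths):
        # Mirrors the line reported as buggy (multiple fixed at 16).
        max_length = max(lengths)
        return min(self.max_seq_length, self._ceil_to_nearest(max_length, 16))

    def padded_length_configured(self, lengths):
        # Mirrors the proposed fix (uses the configurable attribute).
        max_length = max(lengths)
        return min(
            self.max_seq_length,
            self._ceil_to_nearest(max_length, self.pad_seq_length_to_mult),
        )


ds = StandInChatDataset(max_seq_length=4096, pad_seq_length_to_mult=32)
lengths = [17, 33, 40]  # made-up token counts for one batch

hardcoded = ds.padded_length_hardcoded(lengths)
configured = ds.padded_length_configured(lengths)

print(hardcoded % ds.pad_seq_length_to_mult == 0)   # False
print(configured % ds.pad_seq_length_to_mult == 0)  # True
```

Against the real class, the equivalent assertion would check that the sequence dimension of the tensors returned by collate_fn is divisible by pad_seq_length_to_mult.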
