
Conversation

@yiwzhao (Collaborator) commented Jan 7, 2026

feat: Add data collators to support embedding classification.
Split out the data collator part from #322.

@github-actions github-actions bot added the ci label Jan 7, 2026
@yiwzhao yiwzhao changed the title add data collator for sequence classification and unit tests feat: Add data collator to support embedding classification. Jan 7, 2026
@yiwzhao yiwzhao changed the title feat: Add data collator to support embedding classification. feat: Add data collators to support embedding classification. Jan 7, 2026
@gemini-code-assist gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request introduces two new data collators, ClassificationDataCollatorWithPositionIDs and ClassificationTextSequenceShardCollator, specifically for sequence classification tasks. It also adds a comprehensive set of unit tests to validate their functionality. The new collators correctly handle labels for classification by omitting the label masking and shifting present in the base collators. The unit tests are well-written, covering various scenarios including sequence parallelism.

The primary concern with this PR is the significant code duplication in the implementation of the new collator classes. Both classes are nearly identical copies of existing ones, which poses a maintainability risk. I have provided detailed feedback and suggestions to refactor the code using inheritance to mitigate this issue, and also recommended a more robust long-term solution of making the base collators more flexible.
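The inheritance refactor the review suggests could look roughly like the sketch below. The class names mirror the PR, but the packing logic and the `mask_labels` hook are simplified stand-ins invented for illustration, not the actual implementation:

```python
import torch
from dataclasses import dataclass

IGNORE_INDEX = -100  # label value ignored by torch.nn.CrossEntropyLoss


@dataclass
class DataCollatorWithPositionIDs:
    """Simplified stand-in: packs features into one row and masks the
    label at the end of each packed subsequence."""

    def __call__(self, features):
        input_ids = torch.cat([f["input_ids"] for f in features])
        labels = torch.cat([f["labels"] for f in features])
        position_ids = torch.cat(
            [torch.arange(len(f["input_ids"])) for f in features]
        )
        return {
            "input_ids": input_ids.unsqueeze(0),
            "labels": self.mask_labels(labels, position_ids).unsqueeze(0),
            "position_ids": position_ids.unsqueeze(0),
        }

    def mask_labels(self, labels, position_ids):
        labels = labels.clone()
        # the last token of each subsequence sits right before a position reset
        labels[position_ids.roll(-1) == 0] = IGNORE_INDEX
        return labels


@dataclass
class ClassificationDataCollatorWithPositionIDs(DataCollatorWithPositionIDs):
    """Overrides only the masking hook; the packing logic is inherited."""

    def mask_labels(self, labels, position_ids):
        return labels  # keep every token label for classification
```

With this split, the classification variant carries no duplicated packing code; only the one behavior that differs is overridden.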

@yiwzhao yiwzhao changed the title feat: Add data collators to support embedding classification. [data, collator] feat: Add data collators to support embedding classification. Jan 7, 2026
@yiwzhao yiwzhao changed the title [data, collator] feat: Add data collators to support embedding classification. [data] feat: add data collators for embedding classification. Jan 7, 2026
@yiwzhao yiwzhao self-assigned this Jan 7, 2026


@dataclass
class ClassificationDataCollatorWithPositionIDs(DataCollator):
Collaborator

maybe reuse DataCollatorWithPositionIDs if no diff here

Collaborator Author

There is a difference: DataCollatorWithPositionIDs masks the labels corresponding to the "boundary tokens" of each packed subsequence, while ClassificationDataCollatorWithPositionIDs does not perform this masking step. However, I think we can make "whether to mask boundary labels" a configurable option, so we don't need to create a new class.

Collaborator Author

Updated.

input_ids = batch.pop("input_ids")

# CHANGED: do NOT shift labels for seq-cls token-level labels
labels = batch.pop("labels").contiguous()
Collaborator

I think we could add a parameter named `shift_labels=True` to TextSequenceShardCollator to control whether to keep the labels as-is or shift and mask them.

Collaborator Author

I agree. I added two switches to the original class, which by default maintain the old behavior while also supporting sequence classification.
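The two switches could behave along these lines. This is a hypothetical sketch: the flag names `shift_labels` and `mask_boundary_labels` come from this thread and the diff, but the function body is inferred, not the real collator code:

```python
import torch

IGNORE_INDEX = -100  # label value ignored by torch.nn.CrossEntropyLoss


def collate_labels(labels, position_ids, shift_labels=True, mask_boundary_labels=True):
    """Defaults reproduce the causal-LM behavior (mask subsequence ends,
    then shift left by one); setting both flags to False keeps the
    token-level classification labels aligned and intact."""
    labels = labels.clone()
    if mask_boundary_labels:
        # the last token of each packed subsequence precedes a position reset
        labels[position_ids.roll(-1) == 0] = IGNORE_INDEX
    if shift_labels:
        # predict token t+1 from token t; the final position has no target
        labels = torch.cat([labels[1:], labels.new_full((1,), IGNORE_INDEX)])
    return labels.contiguous()
```

For packed `position_ids = [0, 1, 2, 0, 1]` and `labels = [1, 2, 3, 4, 5]`, the defaults produce shifted, boundary-masked LM targets, while `shift_labels=False, mask_boundary_labels=False` returns the labels untouched.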


if self.mask_boundary_labels:
if self.rmpad_with_pos_ids: # mask the last token of each sequence
cu_seqlens = pos2culen(batch["position_ids"])
Collaborator

Don't we need cu_seqlens in the non-mask_boundary_labels case as well?
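For context, a minimal version of what a `pos2culen`-style helper computes — a sketch assuming flattened position_ids that reset to 0 at the start of each packed subsequence, not the PR's actual implementation:

```python
import torch


def pos2culen(position_ids: torch.Tensor) -> torch.Tensor:
    """Convert flattened position_ids into cumulative sequence lengths
    (cu_seqlens), the boundary format flash-attention-style varlen
    kernels use for packed batches: [0, end_of_seq_1, end_of_seq_2, ...]."""
    flat = position_ids.flatten()
    starts = (flat == 0).nonzero(as_tuple=True)[0]  # each reset opens a subsequence
    total = torch.tensor([flat.numel()], device=flat.device)
    return torch.cat([starts, total]).to(torch.int32)
```

For example, `position_ids = [[0, 1, 2, 0, 1]]` yields `[0, 3, 5]`: two subsequences of lengths 3 and 2.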

batch = collator(features_two_samples)

assert batch["input_ids"].shape == (1, 5)
assert "position_ids" in batch
Collaborator

Can we assert the actual values here, and in the other tests as well?
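Value-level assertions along these lines would catch alignment bugs that shape checks miss. A sketch only: the token ids and labels below are made up, not taken from the actual test fixtures:

```python
import torch


def check_packed_batch(batch):
    """Assert exact tensor contents for two packed samples of lengths
    3 and 2, not just shapes and key presence."""
    assert batch["input_ids"].tolist() == [[10, 11, 12, 20, 21]]
    # position ids restart at 0 for the second packed sample
    assert batch["position_ids"].tolist() == [[0, 1, 2, 0, 1]]
    # the classification collator keeps labels unshifted and unmasked
    assert batch["labels"].tolist() == [[1, 1, 1, 0, 0]]
```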

import torch


IGNORE_INDEX = -100
Collaborator

Can we import this constant from veomni instead of redefining it here?
