[data] feat: add data collators for embedding classification. #376
base: main
Conversation
Code Review
This pull request introduces two new data collators, ClassificationDataCollatorWithPositionIDs and ClassificationTextSequenceShardCollator, specifically for sequence classification tasks. It also adds a comprehensive set of unit tests to validate their functionality. The new collators correctly handle labels for classification by omitting the label masking and shifting present in the base collators. The unit tests are well-written, covering various scenarios including sequence parallelism.
The primary concern with this PR is the significant code duplication in the implementation of the new collator classes. Both classes are nearly identical copies of existing ones, which poses a maintainability risk. I have provided detailed feedback and suggestions to refactor the code using inheritance to mitigate this issue, and also recommended a more robust long-term solution of making the base collators more flexible.
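As an illustration of the suggested refactor, here is a minimal sketch assuming the base collator's label handling can be factored into an overridable hook; the hook name `_postprocess_labels` is hypothetical and not part of the current veomni API.

```python
from dataclasses import dataclass

from veomni.data.data_collator import DataCollatorWithPositionIDs


@dataclass
class ClassificationDataCollatorWithPositionIDs(DataCollatorWithPositionIDs):
    """Reuses the packing and position-id logic of the base collator; only
    the label handling differs for sequence classification."""

    def _postprocess_labels(self, labels, position_ids):
        # The base class would mask boundary tokens here (and shift labels for
        # the causal-LM objective); classification keeps the labels untouched.
        return labels
```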
veomni/data/data_collator.py (Outdated)

```python
@dataclass
class ClassificationDataCollatorWithPositionIDs(DataCollator):
```
Maybe reuse DataCollatorWithPositionIDs if there is no difference here?
There is a difference: DataCollatorWithPositionIDs masks the labels corresponding to the "boundary tokens" of each packed subsequence, while ClassificationDataCollatorWithPositionIDs does not perform this masking step. However, I think we can make "whether to mask boundary labels" a configurable option, so we don't need to create a new class.
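To make the difference concrete, here is a hedged illustration of what masking boundary labels means for packed sequences. The helper below is not veomni's actual implementation, and the IGNORE_INDEX value follows the test snippet later in this conversation.

```python
import torch

IGNORE_INDEX = -100  # assumed to match veomni's ignore index


def mask_boundary_labels(labels: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
    """Mask the label of the last token of every packed subsequence.

    With packing, position_ids restart at 0 for each subsequence, e.g.
    [0, 1, 2, 0, 1] for two samples of lengths 3 and 2.  Under a shifted
    causal-LM objective the last token of each subsequence would otherwise
    be trained to predict the first token of the next one.
    """
    flat_pos = position_ids.flatten()
    # A position that drops back to 0 starts a new subsequence, so the token
    # just before it is a boundary; the final token is always a boundary.
    starts = torch.nonzero(flat_pos == 0, as_tuple=True)[0]
    boundary = torch.cat([starts[1:] - 1, starts.new_tensor([flat_pos.numel() - 1])])
    masked = labels.clone().contiguous()
    masked.view(-1)[boundary] = IGNORE_INDEX
    return masked
```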
Updated.
veomni/data/data_collator.py (Outdated)

```python
input_ids = batch.pop("input_ids")

# CHANGED: do NOT shift labels for seq-cls token-level labels
labels = batch.pop("labels").contiguous()
```
I think we could add a parameter named `shift_labels=True` in TextSequenceShardCollator to control whether to keep the labels or to shift and mask them.
I agree. I added two switches to the original class; they keep the old behavior by default while also supporting sequence classification.
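For orientation, a minimal sketch of what such switches could look like. The field and method names below mirror this discussion rather than the final implementation, and the class name is hypothetical; the real TextSequenceShardCollator has more configuration.

```python
from dataclasses import dataclass

import torch

IGNORE_INDEX = -100  # assumed to match veomni's ignore index


@dataclass
class TextSequenceShardCollatorSketch:
    shift_labels: bool = True          # causal LM: predict token t+1 at position t
    mask_boundary_labels: bool = True  # packing: ignore subsequence-boundary tokens

    def _prepare_labels(self, labels: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        if self.shift_labels:
            # Shift left by one position and pad the tail with IGNORE_INDEX.
            labels = torch.cat(
                [labels[..., 1:], torch.full_like(labels[..., :1], IGNORE_INDEX)], dim=-1
            )
        if self.mask_boundary_labels:
            # Mask the last token of each packed subsequence, e.g. via
            # cu_seqlens = pos2culen(position_ids) as in the diff below.
            ...
        # With both flags False (sequence classification), labels pass through unchanged.
        return labels.contiguous()
```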
```python
if self.mask_boundary_labels:
    if self.rmpad_with_pos_ids:  # mask the last token of each sequence
        cu_seqlens = pos2culen(batch["position_ids"])
```
Don't we need cu_seqlens for the non-mask_boundary_labels case as well?
```python
batch = collator(features_two_samples)

assert batch["input_ids"].shape == (1, 5)
assert "position_ids" in batch
```
Can we assert the actual values here and in the other tests?
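For example, the test could compare full tensors rather than only shapes. The expected values below are illustrative, assuming two packed samples of lengths 3 and 2; they would need to be adjusted to what the real fixture produces.

```python
import torch

batch = collator(features_two_samples)

assert batch["input_ids"].shape == (1, 5)
# Illustrative expected tensors -- replace with the fixture's actual values.
torch.testing.assert_close(batch["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
torch.testing.assert_close(batch["labels"], torch.tensor([[1, 1, 1, 0, 0]]))
```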
```python
import torch


IGNORE_INDEX = -100
```
Can we import the veomni constant instead of redefining it here?
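If veomni already exposes this constant, the test could import it rather than redefining it. The module path below is a guess and should be checked against the codebase.

```python
# Hypothetical import path -- verify where veomni actually defines IGNORE_INDEX.
from veomni.data.constants import IGNORE_INDEX
```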
feat: Add data collators to support embedding classification.
Split out the data collator part from #322.