[model, ops] feat: add Qwen3 sequence classification model and loss for embedding classification. #322
base: main
Conversation
Code Review
This pull request introduces support for embedding classification by adding new data collators and a sequence classification head for the Qwen3 model. The changes are well-structured, but I've identified significant code duplication in the new data collators in veomni/data/data_collator.py. My review includes suggestions to refactor these classes using inheritance to improve maintainability and reduce redundancy. I also found some dead code that should be removed for clarity.
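For illustration, a minimal sketch of the suggested inheritance refactor. The base class name, field names, and padding logic here are assumptions for the example; only `ClassificationDataCollatorWithPositionIDs` appears in the actual diff:

```python
from dataclasses import dataclass
from typing import Any, Dict, List

import torch


@dataclass
class DataCollatorWithPositionIDs:
    """Hypothetical base: pads input_ids and derives position_ids."""

    pad_token_id: int = 0

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        max_len = max(len(f["input_ids"]) for f in features)
        input_ids, attention_mask = [], []
        for f in features:
            pad = max_len - len(f["input_ids"])
            input_ids.append(list(f["input_ids"]) + [self.pad_token_id] * pad)
            attention_mask.append([1] * len(f["input_ids"]) + [0] * pad)
        batch = {
            "input_ids": torch.tensor(input_ids),
            "attention_mask": torch.tensor(attention_mask),
        }
        # Positions count only over attended tokens; padding clamps to 0.
        batch["position_ids"] = (batch["attention_mask"].cumsum(-1) - 1).clamp(min=0)
        return batch


@dataclass
class ClassificationDataCollatorWithPositionIDs(DataCollatorWithPositionIDs):
    """Subclass adds only sequence-level labels; padding is inherited."""

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        batch = super().__call__(features)
        batch["labels"] = torch.tensor([f["label"] for f in features])
        return batch
```

The subclass overrides a single method and delegates the shared batching to the base class, which is the duplication the review flags.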
Force-pushed 99f7c42 to bd6e36f, then bd6e36f to 99f7c42.
veomni/data/data_collator.py
Outdated
```python
@dataclass
class ClassificationDataCollatorWithPositionIDs(DataCollator):
```
Split this out into a separate MR.
Moved to #376
veomni/models/loader.py
Outdated
```python
arch_name = get_model_arch_from_config(model_config)
model_type = model_config.model_type
if not force_use_huggingface:
```
rebase this?
Already rebased.
veomni/ops/loss.py
Outdated
```python
    **kwargs,
) -> torch.Tensor:
    # We don't use shift_labels
    assert shift_labels is None
```
`assert` statements are skipped when Python runs with `-O`; do not rely on them for validation in production code.
resolved.
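A minimal sketch of the fix, consistent with the ValueError discussed in the next thread (the helper name here is mine, not the author's):

```python
from typing import Optional

import torch


def _validate_no_shift_labels(shift_labels: Optional[torch.Tensor]) -> None:
    # Unlike `assert`, this check survives `python -O`.
    if shift_labels is not None:
        raise ValueError("seqcls_token_loss_function does not use shift_labels.")
```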
veomni/ops/loss.py
Outdated
```python
loss = None
logits = None

if labels is None:
```
Throw an exception if labels is None.
Done: a ValueError is now raised when labels is None.
veomni/ops/loss.py
Outdated
```python
    return loss, logits


def seqcls_token_loss_function(
```
This file no longer exists on main. Can we follow the new pattern defined in https://github.com/ByteDance-Seed/VeOmni/blob/main/veomni/ops/fused_cross_entropy/__init__.py?
Yes, we can. Reimplemented following the new pattern.
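The actual mechanism lives in veomni/ops/fused_cross_entropy/__init__.py; as a purely hypothetical sketch of the general registry pattern (all names below are invented for illustration and likely differ from VeOmni's):

```python
from typing import Callable, Dict

import torch
import torch.nn.functional as F

# Hypothetical registry; the real one is defined in
# veomni/ops/fused_cross_entropy/__init__.py and may look different.
_LOSS_FUNCTIONS: Dict[str, Callable[..., torch.Tensor]] = {}


def register_loss_function(name: str):
    def wrapper(fn: Callable[..., torch.Tensor]):
        _LOSS_FUNCTIONS[name] = fn
        return fn
    return wrapper


@register_loss_function("seqcls_token")
def seqcls_token_loss_function(hidden_states, weights, labels, ignore_index=-100, **kwargs):
    # Project onto the classification head, then cross-entropy per token.
    logits = hidden_states @ weights.T
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=ignore_index
    )
```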
```python
)

hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
```
Remove this; it's no longer needed. We can just use the logits computed by the loss_function.
removed.
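The point of the change is to avoid materializing the full `[batch, seq, num_labels]` logits tensor eagerly. A sketch of the shape of the fix, based on the `weights=self.score.weight` call visible later in this diff (the exact `self.loss_function` signature is an assumption):

```python
# Before: logits computed eagerly, even when only the loss is needed.
# hidden_states = transformer_outputs.last_hidden_state
# logits = self.score(hidden_states)

# After (sketch): hand the classifier weight to the (possibly fused)
# loss function and let it produce loss and logits internally.
loss, logits = self.loss_function(
    labels=labels,
    hidden_states=transformer_outputs.last_hidden_state,
    weights=self.score.weight,
    **kwargs,
)
```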
```python
    **kwargs: Unpack[FlashAttentionKwargs],
) -> SequenceClassifierOutputWithPast:
    transformer: Qwen3Model = getattr(self, self.base_model_prefix)
    transformer_outputs: BaseModelOutputWithPast = transformer(
```
self.model(...)?
It has been revised to a simpler version as suggested.
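Roughly the following, assuming the base model attribute is named `model` (i.e. `base_model_prefix == "model"`, the usual convention for these classes); the forwarded argument names are taken from the surrounding signature:

```python
# Before: indirect lookup through base_model_prefix.
# transformer: Qwen3Model = getattr(self, self.base_model_prefix)
# transformer_outputs: BaseModelOutputWithPast = transformer(...)

# After: direct attribute access, same behavior.
transformer_outputs: BaseModelOutputWithPast = self.model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    **kwargs,
)
```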
veomni/ops/loss.py
Outdated
```python
    labels: torch.Tensor,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    shift_labels: Optional[torch.Tensor] = None,
```
Remove shift_labels from the signature; let it fall through **kwargs.
removed.
tests/data/test_seqcls_loss.py
Outdated
```python
loss, logits = m.seqcls_token_loss_function(hidden_states, weight, labels=labels, ignore_index=-100)

assert loss is not None
```
The unit tests for the loss function have been adjusted according to the linked document.
veomni/ops/loss.py
Outdated
```python
    ignore_index: int = -100,
    shift_labels: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
```
add docstring.
added.
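For example, a docstring along these lines (wording mine, not the author's; the signature follows the diff, after the shift_labels removal above):

```python
from typing import Optional

import torch


def seqcls_token_loss_function(
    hidden_states: torch.Tensor,
    weights: torch.Tensor,
    labels: torch.Tensor,
    num_items_in_batch: Optional[int] = None,
    ignore_index: int = -100,
    **kwargs,
) -> torch.Tensor:
    """Token-level sequence-classification loss.

    Projects `hidden_states` onto the classification head `weights` and
    computes cross-entropy against `labels`, skipping positions equal to
    `ignore_index`. If `num_items_in_batch` is given, the summed loss is
    normalized by it (useful under gradient accumulation) instead of by
    the count of non-ignored tokens.

    Args:
        hidden_states: [batch, seq_len, hidden] transformer outputs.
        weights: [num_labels, hidden] classification head weight.
        labels: [batch, seq_len] class indices; `ignore_index` marks padding.
        num_items_in_batch: optional global normalizer.
        ignore_index: label value excluded from the loss.

    Returns:
        Scalar loss tensor.
    """
```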
```python
        weights=self.score.weight,
        **kwargs,
    )
else:
```
What about inference tasks, where no labels are provided?
Added logic to calculate logits when no labels are provided, for compatibility with inference tasks.
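A sketch of the resulting control flow; the fused branch follows the `weights=self.score.weight` snippet above, while the exact no-label branch is my assumption about the fix:

```python
if labels is not None:
    # Training: the fused path computes the loss from hidden states
    # without materializing the full logits tensor up front.
    loss, logits = self.loss_function(
        labels=labels,
        hidden_states=hidden_states,
        weights=self.score.weight,
        **kwargs,
    )
else:
    # Inference: no labels, so materialize logits for the caller.
    loss = None
    logits = self.score(hidden_states)
```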
```
@@ -0,0 +1,210 @@
import math
```
TBH, I don't understand what this test does.
piyifan123 left a comment:
@Coach257 do you want to take another look?
```python
hidden_states = kwargs.pop("hidden_states", None)
weights = kwargs.pop("weights", None)

assert hidden_states is not None or logits is not None, "hidden_states or logits must be provided."
```
| """ | ||
| device = torch.device("cuda") | ||
| monkeypatch.setattr(m, "get_parallel_state", lambda: _FakePS(sp_enabled=False)) | ||
| ignore = -100 |
Import the veomni constant IGNORE_INDEX instead of hardcoding -100.
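Something like the following; the module path for IGNORE_INDEX is an assumption and should be checked against the veomni source tree:

```python
# Hypothetical import path; verify where veomni actually defines it.
from veomni.data.constants import IGNORE_INDEX  # == -100

labels = torch.tensor([[IGNORE_INDEX, 1]], device=device)
```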
```python
    hidden_states=hidden_states,
    weights=weights,
)
expected = math.log(float(3))
```
Can we just write down the one-line torch command that does the matmul + softmax + cross-entropy, instead of a hardcoded expected value?
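For example, reusing the test-local names from the snippet below and assuming the loss under test is plain projection plus cross-entropy (which applies log-softmax internally):

```python
import torch.nn.functional as F

# Reference in one expression: logits = hidden_states @ W^T, then
# cross-entropy over the non-ignored positions.
ref = F.cross_entropy(
    (hidden_states @ weights.T).view(-1, weights.size(0)),
    labels.view(-1),
    ignore_index=ignore,
)
# With all-zero inputs this reduces to log(num_classes) = log(3),
# matching the hardcoded `expected = math.log(float(3))`.
torch.testing.assert_close(loss, ref)
```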
```python
logits = torch.zeros((1, 2, 3), device=device)
labels = torch.tensor([[ignore, 1]], device=device)
hidden_states = torch.zeros((1, 2, 5), device=device)
weights = torch.zeros((3, 5), device=device)
```
Can we make it a matrix?
```
@@ -0,0 +1,210 @@
import math
```
Check whether we need to add the tests/ops folder to https://github.com/ByteDance-Seed/VeOmni/blob/main/.github/workflows/gpu_unit_tests.yml so these tests run in CI.
feat: Add a data collator to support embedding classification.