
Conversation

@yiwzhao (Collaborator) commented Dec 24, 2025

feat: Add a data collator to support embedding classification.

@gemini-code-assist bot (Contributor) left a comment:
Code Review

This pull request introduces support for embedding classification by adding new data collators and a sequence classification head for the Qwen3 model. The changes are well-structured, but I've identified significant code duplication in the new data collators in veomni/data/data_collator.py. My review includes suggestions to refactor these classes using inheritance to improve maintainability and reduce redundancy. I also found some dead code that should be removed for clarity.
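A minimal sketch of the inheritance-based refactor the review suggests. Class and field names here are hypothetical (they do not match the actual collators in veomni/data/data_collator.py): the shared padding and position-id logic moves into a base class, and the classification variant overrides only the delta.

```python
from dataclasses import dataclass


@dataclass
class DataCollatorWithPositionIDs:
    """Hypothetical base collator: pads input_ids and builds position_ids."""

    pad_token_id: int = 0

    def _pad(self, features):
        # Shared logic lives here once, instead of being duplicated per subclass.
        max_len = max(len(f["input_ids"]) for f in features)
        batch = {"input_ids": [], "position_ids": []}
        for f in features:
            ids = f["input_ids"]
            pad = max_len - len(ids)
            batch["input_ids"].append(ids + [self.pad_token_id] * pad)
            batch["position_ids"].append(list(range(len(ids))) + [0] * pad)
        return batch

    def __call__(self, features):
        return self._pad(features)


@dataclass
class ClassificationDataCollatorWithPositionIDs(DataCollatorWithPositionIDs):
    """Subclass adds only what differs: per-sequence classification labels."""

    def __call__(self, features):
        batch = self._pad(features)
        batch["labels"] = [f["labels"] for f in features]
        return batch
```

The same pattern extends to further collator variants: each subclass reuses `_pad` and contributes only its extra fields.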

@yiwzhao yiwzhao marked this pull request as draft December 24, 2025 02:07
@yiwzhao yiwzhao marked this pull request as ready for review January 6, 2026 01:04
@yiwzhao yiwzhao force-pushed the yiwen/data_collator branch from 99f7c42 to bd6e36f on January 6, 2026 01:23
@CLAassistant commented Jan 6, 2026

CLA assistant check
All committers have signed the CLA.

@yiwzhao yiwzhao marked this pull request as draft January 6, 2026 01:30
@yiwzhao yiwzhao force-pushed the yiwen/data_collator branch from bd6e36f to 99f7c42 on January 6, 2026 01:38
@yiwzhao yiwzhao marked this pull request as ready for review January 6, 2026 01:39


@dataclass
class ClassificationDataCollatorWithPositionIDs(DataCollator):
Collaborator:

split this to another MR

Collaborator (author):

Moved to #376


arch_name = get_model_arch_from_config(model_config)
model_type = model_config.model_type
if not force_use_huggingface:
Collaborator:

rebase this?

Collaborator (author):

Already rebased.

**kwargs,
) -> torch.Tensor:
# We don't use shift_labels
assert shift_labels is None
Collaborator:

`assert` can be skipped (e.g. when Python runs with `-O`); do not use it for validation in production code.

Collaborator (author):

resolved.

loss = None
logits = None

if labels is None:
Collaborator:

throw exception if label is None.

Collaborator (author):

Resolved: a ValueError is now raised when labels is None.

return loss, logits
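The two review points above can be sketched together. This is a hypothetical, simplified signature (not the actual VeOmni loss function): explicit `ValueError`s replace the asserts, since asserts are silently stripped under `python -O`.

```python
def seqcls_token_loss(logits, labels, shift_labels=None):
    """Sketch of explicit input validation for a sequence-classification loss."""
    # Explicit exceptions instead of `assert`, which `python -O` removes.
    if shift_labels is not None:
        raise ValueError("shift_labels is not supported for sequence classification.")
    if labels is None:
        raise ValueError("labels must be provided to compute the classification loss.")
    # ... the actual cross-entropy computation would go here ...
    return 0.0
```

Callers that forget labels now fail loudly at the call site instead of crashing later (or, with asserts disabled, silently computing garbage).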


def seqcls_token_loss_function(
Collaborator:

this file is no longer there. can we follow the new way defined in https://github.com/ByteDance-Seed/VeOmni/blob/main/veomni/ops/fused_cross_entropy/__init__.py

Collaborator (author) @yiwzhao commented Jan 7, 2026:

Yes, we can. Implemented using the latest way.

)

hidden_states = transformer_outputs.last_hidden_state
logits = self.score(hidden_states)
Collaborator:

remove this. no longer needed. we can just use the one from the loss_function.

Collaborator (author):

removed.

**kwargs: Unpack[FlashAttentionKwargs],
) -> SequenceClassifierOutputWithPast:
transformer: Qwen3Model = getattr(self, self.base_model_prefix)
transformer_outputs: BaseModelOutputWithPast = transformer(
Collaborator:

self.model(...)?

Collaborator (author):

It has been revised to a simpler version as suggested.

labels: torch.Tensor,
num_items_in_batch: Optional[int] = None,
ignore_index: int = -100,
shift_labels: Optional[torch.Tensor] = None,
Collaborator:

remove shift_labels. kwargs

Collaborator (author):

removed.


loss, logits = m.seqcls_token_loss_function(hidden_states, weight, labels=labels, ignore_index=-100)

assert loss is not None

Collaborator (author):

The unit tests for the loss function have been adjusted according to this document.

ignore_index: int = -100,
shift_labels: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
Collaborator:

add docstring.

Collaborator (author):

added.

@yiwzhao yiwzhao changed the title feat: Add a data collator to support embedding classification. feat: Add class Qwen3ForSequenceClassification and loss function to support embedding classification. Jan 7, 2026
@yiwzhao yiwzhao changed the title feat: Add class Qwen3ForSequenceClassification and loss function to support embedding classification. [model, loss] feat: Add class Qwen3ForSequenceClassification and loss function to support embedding classification. Jan 7, 2026
@yiwzhao yiwzhao changed the title [model, loss] feat: Add class Qwen3ForSequenceClassification and loss function to support embedding classification. [model, loss] feat: add Qwen3 sequence classification model and loss for embedding classification. Jan 7, 2026
@yiwzhao yiwzhao changed the title [model, loss] feat: add Qwen3 sequence classification model and loss for embedding classification. [model, ops] feat: add Qwen3 sequence classification model and loss for embedding classification. Jan 7, 2026
weights=self.score.weight,
**kwargs,
)
else:
Collaborator:

what if inference task

Collaborator (author):

Added logic to calculate logits when no labels are provided, for compatibility with inference tasks.
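A dependency-free sketch of that train/inference branch. All names here are hypothetical (the real code works on torch tensors and hands hidden states to a fused loss): with labels, the loss function consumes hidden states and the score weight directly; without labels, the logits are materialized for the caller.

```python
def classification_head(hidden_states, weight, labels=None, loss_function=None):
    """Sketch: fused loss during training, explicit logits at inference."""
    if labels is not None:
        # Training path: the fused loss takes hidden states + weight directly,
        # so the full logits tensor is never materialized.
        return loss_function(hidden_states, weight, labels), None
    # Inference path: no labels, so compute logits = hidden_states @ weight.T
    # (pure-Python matmul stands in for the tensor op).
    logits = [
        [sum(h * w for h, w in zip(row, cls)) for cls in weight]
        for row in hidden_states
    ]
    return None, logits
```

Returning `(loss, None)` in training and `(None, logits)` in inference keeps the signature uniform while avoiding the logits materialization that the fused loss exists to skip.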

@@ -0,0 +1,210 @@
import math
Collaborator:

tbh I don't understand what this test does.

@piyifan123 (Collaborator) left a comment:

@Coach257 do you want to take another look?

hidden_states = kwargs.pop("hidden_states", None)
weights = kwargs.pop("weights", None)

assert hidden_states is not None or logits is not None, "hidden_states or logits must be provided."

"""
device = torch.device("cuda")
monkeypatch.setattr(m, "get_parallel_state", lambda: _FakePS(sp_enabled=False))
ignore = -100
Collaborator:

import veomni constant IGNORE_INDEX

hidden_states=hidden_states,
weights=weights,
)
expected = math.log(float(3))
Collaborator:

can we just write down the one line torch command to do the matmul + softmax + cross entropy
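The torch reference the reviewer asks for would presumably be something like `F.cross_entropy(hidden_states @ weights.T, labels, ignore_index=-100)`. As a dependency-free check of why the test expects `math.log(3)`: zero hidden states times zero weights give all-zero logits, so the softmax over the 3 classes is uniform and the loss on the one non-ignored token is log 3.

```python
import math


def cross_entropy(logits_row, label):
    """Cross entropy for one token: log-sum-exp normalizer minus target logit."""
    m = max(logits_row)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits_row))
    return log_z - logits_row[label]


# zero hidden_states @ zero weights -> all-zero logits over 3 classes:
# uniform softmax, so the loss is log(3)
loss = cross_entropy([0.0, 0.0, 0.0], 1)
assert abs(loss - math.log(3)) < 1e-12
```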

logits = torch.zeros((1, 2, 3), device=device)
labels = torch.tensor([[ignore, 1]], device=device)
hidden_states = torch.zeros((1, 2, 5), device=device)
weights = torch.zeros((3, 5), device=device)
Collaborator:

can we make it a matrix.
