Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/mcore_bridge/model/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,8 @@ def _postprocess(
output_weight = None
if self.share_embeddings_and_output_weights:
output_weight = self.shared_embedding_or_output_weight()
if self.config.is_multimodal and self.config.context_parallel_size > 1:
if self.config.is_multimodal and self.config.context_parallel_size > 1 and input_ids is not None:
# input_ids is required by MTP.
Comment on lines +402 to +403
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The addition of the input_ids is not None check prevents a crash in split_cp_inputs when input_ids are missing (e.g., during inference with decoder_input provided). However, the comment on line 403 states that input_ids is required by MTP. If this is the case, the MTP block starting at line 406 might still fail if input_ids is None while labels are present. It would be more robust to also check for input_ids at line 406 or clarify if MTP can indeed function without them.

input_ids = split_cp_inputs(input_ids, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)

if self.mtp_process and labels is not None:
Expand Down
Loading