From 260a7519f994572fe6c935550f65511897d066c3 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 9 Apr 2026 11:21:17 +0800 Subject: [PATCH] fix mm mtp --- src/mcore_bridge/model/gpt_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index 4c108fb..bf10c03 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -399,7 +399,8 @@ def _postprocess( output_weight = None if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - if self.config.is_multimodal and self.config.context_parallel_size > 1: + if self.config.is_multimodal and self.config.context_parallel_size > 1 and input_ids is not None: + # input_ids is required by MTP. input_ids = split_cp_inputs(input_ids, getattr(packed_seq_params, 'cu_seqlens_q', None), 1) if self.mtp_process and labels is not None: