diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index 4c108fb..bf10c03 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -399,7 +399,8 @@ def _postprocess( output_weight = None if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - if self.config.is_multimodal and self.config.context_parallel_size > 1: + if self.config.is_multimodal and self.config.context_parallel_size > 1 and input_ids is not None: + # input_ids is required by MTP. input_ids = split_cp_inputs(input_ids, getattr(packed_seq_params, 'cu_seqlens_q', None), 1) if self.mtp_process and labels is not None: