From 959d6b87c127e2a889ebb87886e502de5227e813 Mon Sep 17 00:00:00 2001 From: Ao Tang Date: Tue, 3 Feb 2026 10:18:26 -0800 Subject: [PATCH 1/3] Init version for kimi k25 Signed-off-by: Ao Tang --- .../conversion/hf_to_megatron_generate_vlm.py | 35 +-- src/megatron/bridge/models/__init__.py | 9 + .../bridge/models/conversion/auto_bridge.py | 5 +- .../bridge/models/kimi_vl/__init__.py | 24 ++ .../models/kimi_vl/kimi_k25_vl_bridge.py | 110 ++++++++ .../models/kimi_vl/kimi_k25_vl_provider.py | 110 ++++++++ .../models/kimi_vl/modeling_kimi_k25_vl.py | 249 ++++++++++++++++++ 7 files changed, 524 insertions(+), 18 deletions(-) create mode 100644 src/megatron/bridge/models/kimi_vl/__init__.py create mode 100644 src/megatron/bridge/models/kimi_vl/kimi_k25_vl_bridge.py create mode 100644 src/megatron/bridge/models/kimi_vl/kimi_k25_vl_provider.py create mode 100644 src/megatron/bridge/models/kimi_vl/modeling_kimi_k25_vl.py diff --git a/examples/conversion/hf_to_megatron_generate_vlm.py b/examples/conversion/hf_to_megatron_generate_vlm.py index 71583afb88..2bf7888cfb 100644 --- a/examples/conversion/hf_to_megatron_generate_vlm.py +++ b/examples/conversion/hf_to_megatron_generate_vlm.py @@ -152,30 +152,31 @@ def process_image_inputs(processor, image_path: Optional[str], prompt: str): { "role": "user", "content": [ - {"type": "image", "image": image_path}, + {"type": "image", "image_url": image_path}, {"type": "text", "text": prompt}, ], } ] # Process vision info - image_inputs, video_inputs = process_vision_info(messages) - - # Apply chat template - text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - - # Process inputs - inputs = processor( - text=[text], - images=image_inputs, - videos=video_inputs, - padding=True, - return_tensors="pt", - ) + # image_inputs, video_inputs = process_vision_info(messages) + + # # Apply chat template + # text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + # # Process inputs + # inputs = processor( + # text=[text], + # images=image_inputs, + # videos=video_inputs, + # padding=True, + # return_tensors="pt", + # ) + inputs = processor(messages=messages) return ( inputs.input_ids, inputs.pixel_values, - getattr(inputs, "image_grid_thw", None), + getattr(inputs, "image_grid_thw", None) or getattr(inputs, "grid_thws", None), getattr(inputs, "image_sizes", None), messages, ) @@ -209,7 +210,7 @@ def main(args) -> None: # We still need HF config for tokenizer, but we'll load the model from Megatron checkpoint # Create bridge from HF config only (no weights) - bridge = AutoBridge.from_hf_pretrained(args.hf_model_path) + bridge = AutoBridge.from_hf_pretrained(args.hf_model_path, trust_remote_code=args.trust_remote_code) # Initialize model parallel before loading model_provider = bridge.to_megatron_provider(load_weights=False) @@ -236,7 +237,7 @@ def main(args) -> None: else: # Load from HuggingFace and convert to Megatron print_rank_0(f"Loading HuggingFace model from: {args.hf_model_path}") - bridge = AutoBridge.from_hf_pretrained(args.hf_model_path) + bridge = AutoBridge.from_hf_pretrained(args.hf_model_path, trust_remote_code=args.trust_remote_code) model_provider = bridge.to_megatron_provider(load_weights=True) model_provider.tensor_model_parallel_size = tp model_provider.pipeline_model_parallel_size = pp diff --git a/src/megatron/bridge/models/__init__.py b/src/megatron/bridge/models/__init__.py index 3171e35d0f..bd99f59feb 100644 --- a/src/megatron/bridge/models/__init__.py +++ b/src/megatron/bridge/models/__init__.py @@ -75,6 +75,11 @@ GPTOSSProvider120B, ) from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.models.kimi_vl import ( + KimiK25VLBridge, + KimiK25VLModelProvider, + KimiK25VLModel, +) from megatron.bridge.models.llama import ( CodeLlamaModelProvider7B, CodeLlamaModelProvider13B, @@ -229,6 +234,10 @@ "GPTOSSProvider20B", "GPTOSSProvider120B", "T5ModelProvider", + "KimiK2Provider", + "KimiK25VLModel", + "KimiK25VLBridge", + "KimiK25VLModelProvider", "LlamaModelProvider", "Llama2ModelProvider7B", "Llama2ModelProvider13B", diff --git a/src/megatron/bridge/models/conversion/auto_bridge.py b/src/megatron/bridge/models/conversion/auto_bridge.py index 0251fb9ee3..2faeef5f37 100644 --- a/src/megatron/bridge/models/conversion/auto_bridge.py +++ b/src/megatron/bridge/models/conversion/auto_bridge.py @@ -732,9 +732,12 @@ def import_ckpt( megatron_model = bridge.to_megatron_model(wrap_with_ddp=False, use_cpu_initialization=True) # Save as Megatron checkpoint - hf_tokenizer_kwargs = None + hf_tokenizer_kwargs = {} if hasattr(bridge._model_bridge, "get_hf_tokenizer_kwargs"): hf_tokenizer_kwargs = bridge._model_bridge.get_hf_tokenizer_kwargs() + # Pass trust_remote_code to tokenizer if provided in kwargs + if kwargs.get("trust_remote_code"): + hf_tokenizer_kwargs["trust_remote_code"] = True bridge.save_megatron_model( megatron_model, megatron_path, diff --git a/src/megatron/bridge/models/kimi_vl/__init__.py b/src/megatron/bridge/models/kimi_vl/__init__.py new file mode 100644 index 0000000000..df2f2a6896 --- /dev/null +++ b/src/megatron/bridge/models/kimi_vl/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.bridge.models.kimi_vl.kimi_k25_vl_bridge import KimiK25VLBridge +from megatron.bridge.models.kimi_vl.kimi_k25_vl_provider import KimiK25VLModelProvider +from megatron.bridge.models.kimi_vl.modeling_kimi_k25_vl import KimiK25VLModel + + +__all__ = [ + "KimiK25VLModel", + "KimiK25VLBridge", + "KimiK25VLModelProvider", +] diff --git a/src/megatron/bridge/models/kimi_vl/kimi_k25_vl_bridge.py b/src/megatron/bridge/models/kimi_vl/kimi_k25_vl_bridge.py new file mode 100644 index 0000000000..a931eb5650 --- /dev/null +++ b/src/megatron/bridge/models/kimi_vl/kimi_k25_vl_bridge.py @@ -0,0 +1,110 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +from transformers import Gemma3ForConditionalGeneration + +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.conversion.param_mapping import ( + AutoMapping, + GatedMLPMapping, + QKVMapping, + ReplicatedMapping, +) +from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM +from megatron.bridge.models.kimi_vl.kimi_k25_vl_provider import KimiK25VLModelProvider +from megatron.bridge.models.kimi_vl.modeling_kimi_k25_vl import KimiK25VLModel +from megatron.bridge.models.deepseek.common import get_common_configs, get_common_mapping_list + +@MegatronModelBridge.register_bridge(source="KimiK25ForConditionalGeneration", target=KimiK25VLModel) +class KimiK25VLBridge(MegatronModelBridge): + """ + Megatron Bridge for Kimi K2.5 VL. + """ + + def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> KimiK25VLModelProvider: + hf_config = hf_pretrained.config + text_config = hf_config.text_config + vision_config = hf_config.vision_config + + # get_common_configs expects TextConfig + hf_pretrained.config = text_config + configs = get_common_configs(hf_pretrained) + + configs["make_vocab_size_divisible_by"] = 1280 + configs["moe_router_score_function"] = "sigmoid" + configs["moe_router_enable_expert_bias"] = True + # aux_loss_alpha is not set in all DSv3 HF configs + if hasattr(hf_config, "aux_loss_alpha"): + configs["moe_aux_loss_coeff"] = hf_config.aux_loss_alpha + + provider = KimiK25VLModelProvider( + # Text configuration + **configs, + # Vision configuration + _vision_config=vision_config, + # VL-specific token IDs + bos_token_id=getattr(text_config, "bos_token_id", 0), + eos_token_id=getattr(text_config, "eos_token_id", 1), + image_token_id=getattr(text_config, "media_placeholder_token_id", 151655), + # Precision configuration + fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16), + bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16), + params_dtype=self.dtype_from_hf(hf_config, default=torch.float32), + # misc + hf_model_path=hf_pretrained._model_name_or_path, + ) + + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format + # First create simple 1:1 parameter mappings using a dictionary for readability + mapping_list = get_common_mapping_list() + param_mappings = { + # expert bias + "decoder.layers.*.mlp.router.expert_bias": "model.layers.*.mlp.gate.e_score_correction_bias", + } + + for megatron_param, hf_param in param_mappings.items(): + mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) + + for mapping in mapping_list: + # in HF Kimi K2.5 VL models, language component is prefixed with "language_model.model" instead of "model" + if isinstance(mapping, AutoMapping): + mapping.hf_param = "language_model." + mapping.hf_param + mapping.megatron_param = "language_model." + mapping.megatron_param + elif isinstance(mapping, GatedMLPMapping): + mapping.megatron_param = mapping.megatron_param.replace("decoder", "language_model.decoder") + mapping.hf_param["gate"] = mapping.hf_param["gate"].replace("model", "language_model.model") + mapping.hf_param["up"] = mapping.hf_param["up"].replace("model", "language_model.model") + + + # Add Vision and MM Projector mappings + mapping_list.extend( + [ + ReplicatedMapping( + megatron_param="vision_tower.**", + hf_param="vision_tower.**", + ), + ReplicatedMapping( + megatron_param="mm_projector.**", + hf_param="mm_projector.**", + ), + ] + ) + return MegatronMappingRegistry(*mapping_list) diff --git a/src/megatron/bridge/models/kimi_vl/kimi_k25_vl_provider.py b/src/megatron/bridge/models/kimi_vl/kimi_k25_vl_provider.py new file mode 100644 index 0000000000..a931eb5650 --- /dev/null +++ b/src/megatron/bridge/models/kimi_vl/kimi_k25_vl_provider.py @@ -0,0 +1,110 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +from transformers import Gemma3ForConditionalGeneration + +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge +from megatron.bridge.models.conversion.param_mapping import ( + AutoMapping, + GatedMLPMapping, + QKVMapping, + ReplicatedMapping, +) +from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM +from megatron.bridge.models.kimi_vl.kimi_k25_vl_provider import KimiK25VLModelProvider +from megatron.bridge.models.kimi_vl.modeling_kimi_k25_vl import KimiK25VLModel +from megatron.bridge.models.deepseek.common import get_common_configs, get_common_mapping_list + +@MegatronModelBridge.register_bridge(source="KimiK25ForConditionalGeneration", target=KimiK25VLModel) +class KimiK25VLBridge(MegatronModelBridge): + """ + Megatron Bridge for Kimi K2.5 VL. + """ + + def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> KimiK25VLModelProvider: + hf_config = hf_pretrained.config + text_config = hf_config.text_config + vision_config = hf_config.vision_config + + # get_common_configs expects TextConfig + hf_pretrained.config = text_config + configs = get_common_configs(hf_pretrained) + + configs["make_vocab_size_divisible_by"] = 1280 + configs["moe_router_score_function"] = "sigmoid" + configs["moe_router_enable_expert_bias"] = True + # aux_loss_alpha is not set in all DSv3 HF configs + if hasattr(hf_config, "aux_loss_alpha"): + configs["moe_aux_loss_coeff"] = hf_config.aux_loss_alpha + + provider = KimiK25VLModelProvider( + # Text configuration + **configs, + # Vision configuration + _vision_config=vision_config, + # VL-specific token IDs + bos_token_id=getattr(text_config, "bos_token_id", 0), + eos_token_id=getattr(text_config, "eos_token_id", 1), + image_token_id=getattr(text_config, "media_placeholder_token_id", 151655), + # Precision configuration + fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16), + bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16), + params_dtype=self.dtype_from_hf(hf_config, default=torch.float32), + # misc + hf_model_path=hf_pretrained._model_name_or_path, + ) + + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format + # First create simple 1:1 parameter mappings using a dictionary for readability + mapping_list = get_common_mapping_list() + param_mappings = { + # expert bias + "decoder.layers.*.mlp.router.expert_bias": "model.layers.*.mlp.gate.e_score_correction_bias", + } + + for megatron_param, hf_param in param_mappings.items(): + mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) + + for mapping in mapping_list: + # in HF Kimi K2.5 VL models, language component is prefixed with "language_model.model" instead of "model" + if isinstance(mapping, AutoMapping): + mapping.hf_param = "language_model." + mapping.hf_param + mapping.megatron_param = "language_model." + mapping.megatron_param + elif isinstance(mapping, GatedMLPMapping): + mapping.megatron_param = mapping.megatron_param.replace("decoder", "language_model.decoder") + mapping.hf_param["gate"] = mapping.hf_param["gate"].replace("model", "language_model.model") + mapping.hf_param["up"] = mapping.hf_param["up"].replace("model", "language_model.model") + + + # Add Vision and MM Projector mappings + mapping_list.extend( + [ + ReplicatedMapping( + megatron_param="vision_tower.**", + hf_param="vision_tower.**", + ), + ReplicatedMapping( + megatron_param="mm_projector.**", + hf_param="mm_projector.**", + ), + ] + ) + return MegatronMappingRegistry(*mapping_list) diff --git a/src/megatron/bridge/models/kimi_vl/modeling_kimi_k25_vl.py b/src/megatron/bridge/models/kimi_vl/modeling_kimi_k25_vl.py new file mode 100644 index 0000000000..16db9bc873 --- /dev/null +++ b/src/megatron/bridge/models/kimi_vl/modeling_kimi_k25_vl.py @@ -0,0 +1,249 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import types +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from megatron.core.tensor_parallel.layers import ColumnParallelLinear +from megatron.core.transformer import TransformerConfig +from megatron.core.transformer.module import MegatronModule +from torch import Tensor +from transformers.dynamic_module_utils import get_class_from_dynamic_module + +from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync +from megatron.bridge.utils.import_utils import safe_import_from + + +TENorm, _ = safe_import_from("megatron.core.extensions.transformer_engine", "TENorm") + + +class KimiK25VLModel(MegatronModule): + """ + Kimi K2.5 Vision-Language (VL) model wrapper for Megatron. + Args: + config (GPTModelProvider): Model provider containing configuration for language and vision modules. + pre_process (bool, optional): Whether to construct the vision tower and projector. Default: True. + post_process (bool, optional): Whether to apply post-processing. Default: True. + vp_stage (Optional[int], optional): Pipeline stage for model parallelism. Default: None. + + Attributes: + pre_process (bool): If True, enables vision and multimodal components. + post_process (bool): If True, enables post-processing. + vp_stage (Optional[int]): Pipeline stage for model parallelism. + vision_tower (nn.Module): Vision encoder (MoonViT3d vision backbone). + mm_projector (nn.Module): PatchMergerMLP that projects vision features to language model space. + language_model (nn.Module): The underlying DeepSeek V3 language model. + get_image_features (callable): Method to extract and project image features. + + Forward Inputs: + input_ids (torch.LongTensor, optional): Tokenized input ids for the language model. + attention_mask (torch.Tensor, optional): Attention mask for the language model. + position_ids (torch.LongTensor, optional): Position ids for the language model. + inputs_embeds (torch.FloatTensor, optional): Precomputed input embeddings. + pixel_values (torch.Tensor, optional): Image tensor(s) for the vision tower. + labels (torch.Tensor, optional): Target labels for supervised training. + runtime_gather_output (bool, optional): If True, gather outputs across pipeline stages. + loss_mask (Tensor, optional): Mask for loss computation. + + Returns: + Tensor: Model output (e.g., logits or loss, depending on mode). + + Note: + - If `pre_process` is False, only the language model is constructed. + - The vision tower and projector are only active if `pre_process` is True. + - This class is intended for use within the Megatron-LM framework. + """ + + def __init__( + self, + config: GPTModelProvider, + pre_process: bool = True, + post_process: bool = True, + vp_stage: Optional[int] = None, + ) -> None: + super().__init__(config=config) + + self.pre_process = pre_process + self.post_process = post_process + self.vp_stage = vp_stage + + if config.hf_model_path is None: + raise ValueError("hf_model_path must be set.") + + KimiK25ForConditionalGeneration = get_class_from_dynamic_module( + "modeling_kimi_k25.KimiK25ForConditionalGeneration", + config.hf_model_path, + ) + if pre_process: + # Load vision tower and projector classes from the custom HuggingFace model code + MoonViT3dPretrainedModel = get_class_from_dynamic_module( + "modeling_kimi_k25.MoonViT3dPretrainedModel", + config.hf_model_path, + ) + PatchMergerMLP = get_class_from_dynamic_module( + "modeling_kimi_k25.PatchMergerMLP", + config.hf_model_path, + ) + ProjectorConfig = get_class_from_dynamic_module( + "modeling_kimi_k25.ProjectorConfig", + config.hf_model_path, + ) + VisionTowerConfig = get_class_from_dynamic_module( + "modeling_kimi_k25.VisionTowerConfig", + config.hf_model_path, + ) + self.vision_tower_config = VisionTowerConfig(config.vision_config) + self.projector_config = ProjectorConfig(config.vision_config) + self.vision_tower = MoonViT3dPretrainedModel(self.vision_tower_config) + self.mm_projector = PatchMergerMLP(self.projector_config) # TODO: support different types of mm projector + # Ensure HF visual tower params are marked for TP grad sync and future assignments are hooked. + hook_hf_module_setattr_for_tp_grad_sync(self.vision_tower) + hook_hf_module_setattr_for_tp_grad_sync(self.mm_projector) + self.language_model = self.config.provide_language_model( + pre_process=pre_process, post_process=post_process, vp_stage=vp_stage + ) + + # Finalize grad requires these to be bound with module + self.share_embeddings_and_output_weights = config.share_embeddings_and_output_weights + self.shared_embedding_or_output_weight = self.language_model.shared_embedding_or_output_weight + + self._extract_image_features = types.MethodType(KimiK25ForConditionalGeneration._extract_image_features, self) + self._merge_input_ids_with_image_features = types.MethodType(KimiK25ForConditionalGeneration._merge_input_ids_with_image_features, self) + + def set_input_tensor(self, input_tensor) -> None: + """Set model chunk input tensor.""" + self.language_model.set_input_tensor(input_tensor) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + runtime_gather_output: Optional[bool] = None, + *, + loss_mask: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + input_ids: Tokenized input ids for the language model. + attention_mask: Attention mask for the language model. + position_ids: Position ids for the language model. + inputs_embeds: Precomputed input embeddings. + pixel_values: Image tensor for the vision tower. + grid_thws: Tensor of shape (num_images, 3) containing [temporal, height, width] + for each image's grid dimensions in the LLM. + labels: Target labels for supervised training. + runtime_gather_output: If True, gather outputs across pipeline stages. + loss_mask: Mask for loss computation. + """ + if self.pre_process: + if inputs_embeds is None: + inputs_embeds = self.language_model.embedding( + input_ids=input_ids, position_ids=None + ) # [decoder_seq_len, b, h_language] + + inputs_embeds = inputs_embeds.transpose(1, 0).contiguous() # [b, decoder_seq_len, h_language] + + breakpoint() + if pixel_values is not None: + image_features = self._extract_image_features(pixel_values, image_grid_thw) + image_features = self.mm_projector(image_features) + inputs_embeds = inputs_embeds.to(image_features[0].dtype) + inputs_embeds, attention_mask, labels, position_ids = ( + self._merge_input_ids_with_image_features( + image_features, + inputs_embeds, + input_ids, + attention_mask, + labels, + )) + + inputs_embeds = inputs_embeds.transpose(1, 0).contiguous() # (B, T, D) -> (T, B, D) + + attention_mask = self._compute_attention_mask(input_ids) + + outputs = self.language_model.forward( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, # (B, 1, T, T) + decoder_input=inputs_embeds, # (T, B, D) + labels=labels, # (B, T) + loss_mask=loss_mask, + runtime_gather_output=runtime_gather_output, + ) + return outputs + + def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_vision_projection: bool): + """Freeze model modules. + + Make specific modules non-trainable by setting requires_grad to False. + + Args: + freeze_language_model (bool): Freeze the language model module. + freeze_vision_model (bool): Freeze the vision model module (patch_embed and blocks). + freeze_vision_projection (bool): Freeze the vision projection module (merger). + """ + modules = [] + + if freeze_language_model and hasattr(self, "language_model") and self.language_model is not None: + modules.append(self.language_model) + + if freeze_vision_model and hasattr(self, "vision_tower") and self.vision_tower is not None: + # Vision model consists of patch_embed and blocks + modules.append(self.vision_tower) + + if ( + freeze_vision_projection + and hasattr(self, "mm_projector") + and self.mm_projector is not None + ): + # Vision projection is the merger module + modules.append(self.mm_projector) + + for module in modules: + for param in module.parameters(): + param.requires_grad = False + + def _compute_attention_mask( + self, + input_ids: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if not self.pre_process: + return None + batch_size, seq_len = input_ids.shape + causal_mask = torch.tril(torch.ones((batch_size, 1, seq_len, seq_len))).to(input_ids.device) + + image_mask = input_ids == self.config.image_token_id + padded_mask = F.pad(image_mask, (1, 0), value=0) + boundary = padded_mask[:, 1:] > padded_mask[:, :-1] + numbered_boundary = torch.cumsum(boundary, dim=-1) + q_block_indices = image_mask * numbered_boundary + kv_block_indices = q_block_indices + bidirectional_mask = torch.logical_and( + kv_block_indices[:, None, :] == q_block_indices.unsqueeze(-1), + q_block_indices.unsqueeze(-1) > 0, + ) + # See te.DotProductAttention for the requirement of custom mask + attention_mask = ~torch.logical_or(causal_mask, bidirectional_mask.unsqueeze(1)) + return attention_mask + From da80f689e8a1be6c1f951bca386e2e7ae24350b2 Mon Sep 17 00:00:00 2001 From: Deyu Fu Date: Mon, 9 Feb 2026 13:47:10 +0800 Subject: [PATCH 2/3] Fix Kimi K2.5 VL: rewrite broken provider, fix config bugs, and correct forward pass - Rewrite kimi_k25_vl_provider.py (was a duplicate of bridge file) as a proper KimiK25VLModelProvider inheriting from KimiK2Provider - Fix aux_loss_alpha and media_placeholder_token_id read from wrong config - Fix token ID defaults to match Kimi K2.5 (not Qwen) - Fix forward pass: remove breakpoint(), skip transpose when no images, remove _compute_attention_mask (wrong shape after sequence expansion) - Add KimiK2Provider to models/__init__.py, remove unused imports Signed-off-by: Deyu Fu --- src/megatron/bridge/models/__init__.py | 3 + .../models/kimi_vl/kimi_k25_vl_bridge.py | 54 +++--- .../models/kimi_vl/kimi_k25_vl_provider.py | 178 +++++++++--------- .../models/kimi_vl/modeling_kimi_k25_vl.py | 53 ++---- 4 files changed, 133 insertions(+), 155 deletions(-) diff --git a/src/megatron/bridge/models/__init__.py b/src/megatron/bridge/models/__init__.py index bd99f59feb..efc95f512d 100644 --- a/src/megatron/bridge/models/__init__.py +++ b/src/megatron/bridge/models/__init__.py @@ -75,6 +75,9 @@ GPTOSSProvider120B, ) from megatron.bridge.models.gpt_provider import GPTModelProvider +from megatron.bridge.models.kimi import ( + KimiK2Provider, +) from megatron.bridge.models.kimi_vl import ( KimiK25VLBridge, KimiK25VLModelProvider, diff --git a/src/megatron/bridge/models/kimi_vl/kimi_k25_vl_bridge.py b/src/megatron/bridge/models/kimi_vl/kimi_k25_vl_bridge.py index a931eb5650..0ec4286841 100644 --- a/src/megatron/bridge/models/kimi_vl/kimi_k25_vl_bridge.py +++ b/src/megatron/bridge/models/kimi_vl/kimi_k25_vl_bridge.py @@ -12,28 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math - import torch -from transformers import Gemma3ForConditionalGeneration from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge from megatron.bridge.models.conversion.param_mapping import ( AutoMapping, GatedMLPMapping, - QKVMapping, ReplicatedMapping, ) +from megatron.bridge.models.deepseek.common import get_common_configs, get_common_mapping_list from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM from megatron.bridge.models.kimi_vl.kimi_k25_vl_provider import KimiK25VLModelProvider from megatron.bridge.models.kimi_vl.modeling_kimi_k25_vl import KimiK25VLModel -from megatron.bridge.models.deepseek.common import get_common_configs, get_common_mapping_list + @MegatronModelBridge.register_bridge(source="KimiK25ForConditionalGeneration", target=KimiK25VLModel) class KimiK25VLBridge(MegatronModelBridge): """ Megatron Bridge for Kimi K2.5 VL. + + Converts HuggingFace Kimi K2.5 VL models (KimiK25ForConditionalGeneration) + to Megatron format (KimiK25VLModel) and vice versa. + + The language backbone shares the same architecture as Kimi K2 (MoE with MLA). """ def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> KimiK25VLModelProvider: @@ -41,39 +43,45 @@ def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> KimiK25VLModelProvide text_config = hf_config.text_config vision_config = hf_config.vision_config - # get_common_configs expects TextConfig + # get_common_configs expects the text config hf_pretrained.config = text_config configs = get_common_configs(hf_pretrained) configs["make_vocab_size_divisible_by"] = 1280 configs["moe_router_score_function"] = "sigmoid" configs["moe_router_enable_expert_bias"] = True - # aux_loss_alpha is not set in all DSv3 HF configs - if hasattr(hf_config, "aux_loss_alpha"): - configs["moe_aux_loss_coeff"] = hf_config.aux_loss_alpha - + # aux_loss_alpha is on text_config, not the top-level KimiK25Config + if hasattr(text_config, "aux_loss_alpha"): + configs["moe_aux_loss_coeff"] = text_config.aux_loss_alpha + + # media_placeholder_token_id is on the top-level KimiK25Config, not on text_config + media_placeholder_token_id = getattr(hf_config, "media_placeholder_token_id", 163605) + provider = KimiK25VLModelProvider( - # Text configuration + # Text configuration (extracted from HF text config) **configs, - # Vision configuration - _vision_config=vision_config, + # Vision configuration (raw HF KimiK25VisionConfig) + vision_config=vision_config, # VL-specific token IDs - bos_token_id=getattr(text_config, "bos_token_id", 0), - eos_token_id=getattr(text_config, "eos_token_id", 1), - image_token_id=getattr(text_config, "media_placeholder_token_id", 151655), + bos_token_id=getattr(text_config, "bos_token_id", 163584), + eos_token_id=getattr(text_config, "eos_token_id", 163585), + image_token_id=media_placeholder_token_id, + media_placeholder_token_id=media_placeholder_token_id, + pad_token_id=getattr(hf_config, "pad_token_id", 163839), + ignore_index=getattr(hf_config, "ignore_index", -100), # Precision configuration fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16), bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16), params_dtype=self.dtype_from_hf(hf_config, default=torch.float32), - # misc + # HF model path (needed for dynamic module loading of vision components) hf_model_path=hf_pretrained._model_name_or_path, ) return provider def mapping_registry(self) -> MegatronMappingRegistry: - # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format - # First create simple 1:1 parameter mappings using a dictionary for readability + # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format. + # Start with the common mapping list for the language model. mapping_list = get_common_mapping_list() param_mappings = { # expert bias @@ -83,8 +91,9 @@ def mapping_registry(self) -> MegatronMappingRegistry: for megatron_param, hf_param in param_mappings.items(): mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) + # In HF Kimi K2.5 VL models, the language component is nested under + # "language_model.model" instead of just "model", so we need to add the prefix. for mapping in mapping_list: - # in HF Kimi K2.5 VL models, language component is prefixed with "language_model.model" instead of "model" if isinstance(mapping, AutoMapping): mapping.hf_param = "language_model." + mapping.hf_param mapping.megatron_param = "language_model." + mapping.megatron_param @@ -93,8 +102,9 @@ def mapping_registry(self) -> MegatronMappingRegistry: mapping.hf_param["gate"] = mapping.hf_param["gate"].replace("model", "language_model.model") mapping.hf_param["up"] = mapping.hf_param["up"].replace("model", "language_model.model") - - # Add Vision and MM Projector mappings + # Add Vision Tower and MM Projector mappings. + # These use ReplicatedMapping because the vision components are not sharded + # across tensor parallel ranks — they are replicated on each rank. mapping_list.extend( [ ReplicatedMapping( diff --git a/src/megatron/bridge/models/kimi_vl/kimi_k25_vl_provider.py b/src/megatron/bridge/models/kimi_vl/kimi_k25_vl_provider.py index a931eb5650..a3b3618326 100644 --- a/src/megatron/bridge/models/kimi_vl/kimi_k25_vl_provider.py +++ b/src/megatron/bridge/models/kimi_vl/kimi_k25_vl_provider.py @@ -12,99 +12,95 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math - -import torch -from transformers import Gemma3ForConditionalGeneration - -from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry -from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge -from megatron.bridge.models.conversion.param_mapping import ( - AutoMapping, - GatedMLPMapping, - QKVMapping, - ReplicatedMapping, -) -from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM -from megatron.bridge.models.kimi_vl.kimi_k25_vl_provider import KimiK25VLModelProvider -from megatron.bridge.models.kimi_vl.modeling_kimi_k25_vl import KimiK25VLModel -from megatron.bridge.models.deepseek.common import get_common_configs, get_common_mapping_list - -@MegatronModelBridge.register_bridge(source="KimiK25ForConditionalGeneration", target=KimiK25VLModel) -class KimiK25VLBridge(MegatronModelBridge): +from dataclasses import dataclass +from typing import Any, Optional + +from megatron.core.models.gpt import GPTModel as MCoreGPTModel + +from megatron.bridge.models.kimi.kimi_provider import KimiK2Provider + + +@dataclass +class KimiK25VLModelProvider(KimiK2Provider): """ - Megatron Bridge for Kimi K2.5 VL. + Model provider for Kimi K2.5 VL (Vision-Language) Models. + + Inherits language model configuration from KimiK2Provider since the + Kimi K2.5 language backbone shares the same architecture as Kimi K2 + (MoE with MLA, 384 experts, 61 layers, etc.). + + Minor config differences (rotary_scaling_factor, layernorm_epsilon, + init_method_std) between K2 and K2.5 are handled at runtime by + ``get_common_configs()`` in the bridge, which reads actual values + from the HuggingFace config. + + The vision component (MoonViT3d + PatchMergerMLP) is dynamically loaded + from the HuggingFace model repository at runtime via ``trust_remote_code``. """ - def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> KimiK25VLModelProvider: - hf_config = hf_pretrained.config - text_config = hf_config.text_config - vision_config = hf_config.vision_config - - # get_common_configs expects TextConfig - hf_pretrained.config = text_config - configs = get_common_configs(hf_pretrained) - - configs["make_vocab_size_divisible_by"] = 1280 - configs["moe_router_score_function"] = "sigmoid" - configs["moe_router_enable_expert_bias"] = True - # aux_loss_alpha is not set in all DSv3 HF configs - if hasattr(hf_config, "aux_loss_alpha"): - configs["moe_aux_loss_coeff"] = hf_config.aux_loss_alpha - - provider = KimiK25VLModelProvider( - # Text configuration - **configs, - # Vision configuration - _vision_config=vision_config, - # VL-specific token IDs - bos_token_id=getattr(text_config, "bos_token_id", 0), - eos_token_id=getattr(text_config, "eos_token_id", 1), - image_token_id=getattr(text_config, "media_placeholder_token_id", 151655), - # Precision configuration - fp16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.float16), - bf16=(self.dtype_from_hf(hf_config, default=torch.float32) == torch.bfloat16), - params_dtype=self.dtype_from_hf(hf_config, default=torch.float32), - # misc - hf_model_path=hf_pretrained._model_name_or_path, - ) + # VL models shouldn't scatter embeddings across sequence parallel regions because + # the vision embeddings are going to be inserted into the language embeddings. + scatter_embedding_sequence_parallel: bool = False + + # Vision configuration — raw HF KimiK25VisionConfig object, used to construct + # VisionTowerConfig and ProjectorConfig for the vision tower and mm_projector. + vision_config: Any = None + + # Path to HuggingFace model directory (required for dynamic module loading + # of MoonViT3d, PatchMergerMLP, and other custom model classes). + hf_model_path: Optional[str] = None - return provider - - def mapping_registry(self) -> MegatronMappingRegistry: - # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format - # First create simple 1:1 parameter mappings using a dictionary for readability - mapping_list = get_common_mapping_list() - param_mappings = { - # expert bias - "decoder.layers.*.mlp.router.expert_bias": "model.layers.*.mlp.gate.e_score_correction_bias", - } - - for megatron_param, hf_param in param_mappings.items(): - mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param)) - - for mapping in mapping_list: - # in HF Kimi K2.5 VL models, language component is prefixed with "language_model.model" instead of "model" - if isinstance(mapping, AutoMapping): - mapping.hf_param = "language_model." + mapping.hf_param - mapping.megatron_param = "language_model." + mapping.megatron_param - elif isinstance(mapping, GatedMLPMapping): - mapping.megatron_param = mapping.megatron_param.replace("decoder", "language_model.decoder") - mapping.hf_param["gate"] = mapping.hf_param["gate"].replace("model", "language_model.model") - mapping.hf_param["up"] = mapping.hf_param["up"].replace("model", "language_model.model") - - - # Add Vision and MM Projector mappings - mapping_list.extend( - [ - ReplicatedMapping( - megatron_param="vision_tower.**", - hf_param="vision_tower.**", - ), - ReplicatedMapping( - megatron_param="mm_projector.**", - hf_param="mm_projector.**", - ), - ] + # Token IDs (from Kimi K2.5 config.json) + bos_token_id: int = 163584 + eos_token_id: int = 163585 + image_token_id: int = 163605 # media_placeholder_token_id in HF config + # Fields needed by HF's _merge_input_ids_with_image_features (bound via MethodType) + media_placeholder_token_id: int = 163605 + pad_token_id: int = 163839 + ignore_index: int = -100 + + # Freeze options for fine-tuning scenarios + freeze_language_model: bool = False + freeze_vision_model: bool = False + freeze_vision_projection: bool = False + + def provide(self, pre_process=None, post_process=None, vp_stage=None): + """ + Provide a KimiK25VL model instance with vision and language components. + + Returns: + KimiK25VLModel: Configured Kimi K2.5 VL model with vision tower, + multimodal projector, and Kimi K2 language model. + """ + from megatron.bridge.models.kimi_vl.modeling_kimi_k25_vl import KimiK25VLModel + + model = KimiK25VLModel( + self, pre_process=pre_process, post_process=post_process, vp_stage=vp_stage ) - return MegatronMappingRegistry(*mapping_list) + + # Apply freeze options if any are enabled + if self.freeze_language_model or self.freeze_vision_model or self.freeze_vision_projection: + model.freeze( + freeze_language_model=self.freeze_language_model, + freeze_vision_model=self.freeze_vision_model, + freeze_vision_projection=self.freeze_vision_projection, + ) + + return model + + def provide_language_model(self, pre_process=None, post_process=None, vp_stage=None) -> MCoreGPTModel: + """ + Provide just the language model component (Kimi K2 MoE) without vision. + + This is called by KimiK25VLModel to construct only the language backbone, + while the vision tower and projector are constructed separately. + + Args: + pre_process: Whether this is the first stage in pipeline parallelism. + post_process: Whether this is the last stage in pipeline parallelism. + vp_stage: Virtual pipeline stage number. + + Returns: + MCoreGPTModel instance (language model only). + """ + return super().provide(pre_process=pre_process, post_process=post_process, vp_stage=vp_stage) diff --git a/src/megatron/bridge/models/kimi_vl/modeling_kimi_k25_vl.py b/src/megatron/bridge/models/kimi_vl/modeling_kimi_k25_vl.py index 16db9bc873..f48614790f 100644 --- a/src/megatron/bridge/models/kimi_vl/modeling_kimi_k25_vl.py +++ b/src/megatron/bridge/models/kimi_vl/modeling_kimi_k25_vl.py @@ -13,24 +13,15 @@ # limitations under the License. import types -from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional import torch -import torch.nn as nn -import torch.nn.functional as F -from megatron.core.tensor_parallel.layers import ColumnParallelLinear -from megatron.core.transformer import TransformerConfig from megatron.core.transformer.module import MegatronModule from torch import Tensor from transformers.dynamic_module_utils import get_class_from_dynamic_module from megatron.bridge.models.gpt_provider import GPTModelProvider from megatron.bridge.utils.common_utils import hook_hf_module_setattr_for_tp_grad_sync -from megatron.bridge.utils.import_utils import safe_import_from - - -TENorm, _ = safe_import_from("megatron.core.extensions.transformer_engine", "TENorm") class KimiK25VLModel(MegatronModule): @@ -48,7 +39,7 @@ class KimiK25VLModel(MegatronModule): vp_stage (Optional[int]): Pipeline stage for model parallelism. vision_tower (nn.Module): Vision encoder (MoonViT3d vision backbone). mm_projector (nn.Module): PatchMergerMLP that projects vision features to language model space. - language_model (nn.Module): The underlying DeepSeek V3 language model. + language_model (nn.Module): The underlying Kimi K2 language model. get_image_features (callable): Method to extract and project image features. Forward Inputs: @@ -150,8 +141,9 @@ def forward( position_ids: Position ids for the language model. inputs_embeds: Precomputed input embeddings. pixel_values: Image tensor for the vision tower. - grid_thws: Tensor of shape (num_images, 3) containing [temporal, height, width] - for each image's grid dimensions in the LLM. + image_grid_thw: Tensor of shape ``(num_images, 3)`` containing ``[temporal, height, width]`` + for each image's grid dimensions in the LLM. This corresponds to ``grid_thws`` in + the HF Kimi K2.5 processor output. labels: Target labels for supervised training. runtime_gather_output: If True, gather outputs across pipeline stages. loss_mask: Mask for loss computation. @@ -160,12 +152,12 @@ def forward( if inputs_embeds is None: inputs_embeds = self.language_model.embedding( input_ids=input_ids, position_ids=None - ) # [decoder_seq_len, b, h_language] + ) # (T, B, D) — Megatron convention - inputs_embeds = inputs_embeds.transpose(1, 0).contiguous() # [b, decoder_seq_len, h_language] - - breakpoint() if pixel_values is not None: + # Transpose to (B, T, D) for HF's merge function which uses batch-first convention + inputs_embeds = inputs_embeds.transpose(1, 0).contiguous() # (T, B, D) -> (B, T, D) + image_features = self._extract_image_features(pixel_values, image_grid_thw) image_features = self.mm_projector(image_features) inputs_embeds = inputs_embeds.to(image_features[0].dtype) @@ -178,9 +170,8 @@ def forward( labels, )) - inputs_embeds = inputs_embeds.transpose(1, 0).contiguous() # (B, T, D) -> (T, B, D) - - attention_mask = self._compute_attention_mask(input_ids) + # Transpose back to (T, B, D) for Megatron language model + inputs_embeds = inputs_embeds.transpose(1, 0).contiguous() # (B, T, D) -> (T, B, D) outputs = self.language_model.forward( input_ids=None, @@ -224,26 +215,4 @@ def freeze(self, freeze_language_model: bool, freeze_vision_model: bool, freeze_ for param in module.parameters(): param.requires_grad = False - def _compute_attention_mask( - self, - input_ids: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - if not self.pre_process: - return None - batch_size, seq_len = input_ids.shape - causal_mask = torch.tril(torch.ones((batch_size, 1, seq_len, seq_len))).to(input_ids.device) - - image_mask = input_ids == self.config.image_token_id - padded_mask = F.pad(image_mask, (1, 0), value=0) - boundary = padded_mask[:, 1:] > padded_mask[:, :-1] - numbered_boundary = torch.cumsum(boundary, dim=-1) - q_block_indices = image_mask * numbered_boundary - kv_block_indices = q_block_indices - bidirectional_mask = torch.logical_and( - kv_block_indices[:, None, :] == q_block_indices.unsqueeze(-1), - q_block_indices.unsqueeze(-1) > 0, - ) - # See te.DotProductAttention for the requirement of custom mask - attention_mask = ~torch.logical_or(causal_mask, bidirectional_mask.unsqueeze(1)) - return attention_mask From 013ed718da6241bea17c9edae5f4cf41a0ac87e7 Mon Sep 17 00:00:00 2001 From: Deyu Fu Date: Wed, 11 Feb 2026 13:13:16 +0800 Subject: [PATCH 3/3] toy kimi k2.5 logits matches between mcore/hf --- .../compare_hf_and_megatron/compare.py | 281 ++++++++++++++---- examples/conversion/create_hf_toy_model.py | 207 +++++++++++-- .../conversion/hf_to_megatron_generate_vlm.py | 60 ++-- src/megatron/bridge/models/deepseek/common.py | 75 +++++ .../models/deepseek/deepseek_v3_bridge.py | 18 +- .../bridge/models/hf_pretrained/vlm.py | 11 + .../models/kimi_vl/kimi_k25_vl_bridge.py | 25 +- .../models/kimi_vl/modeling_kimi_k25_vl.py | 11 + .../deepseek/test_fp8_dequantization.py | 151 ++++++++++ 9 files changed, 734 insertions(+), 105 deletions(-) create mode 100644 tests/unit_tests/models/deepseek/test_fp8_dequantization.py diff --git a/examples/conversion/compare_hf_and_megatron/compare.py b/examples/conversion/compare_hf_and_megatron/compare.py index 60ce377cd1..aed1af974a 100644 --- a/examples/conversion/compare_hf_and_megatron/compare.py +++ b/examples/conversion/compare_hf_and_megatron/compare.py @@ -58,6 +58,20 @@ --megatron_model_path "/path/to/megatron/checkpoint" \ --prompt "Hello world" + # Kimi K2.5 VL comparison (text-only, single GPU): + uv run python examples/conversion/compare_hf_and_megatron/compare.py \ + --hf_model_path /path/to/Kimi-K2.5 \ + --prompt "Hello, how are you?" \ + --trust_remote_code + + # Kimi K2.5 VL comparison (with image, multi-GPU with expert parallelism): + torchrun --nproc_per_node=8 examples/conversion/compare_hf_and_megatron/compare.py \ + --hf_model_path /path/to/Kimi-K2.5 \ + --prompt "Describe this image." \ + --image_path /path/to/test_image.jpg \ + --trust_remote_code \ + --tp 1 --ep 8 + # Enable debug hooks to inspect forward pass intermediate results: uv run python examples/conversion/compare_hf_and_megatron/compare.py \ --hf_model_path "Qwen/Qwen3-1.7B" \ @@ -183,12 +197,13 @@ def load_model_class(model_class_name: str): raise ImportError(f"Could not import model class '{model_class_name}' from transformers") -def get_model_class(model_class_name: str = None, is_vl_model: bool = False): +def get_model_class(model_class_name: str = None, is_vl_model: bool = False, trust_remote_code: bool = False): """Get the appropriate model class for loading. Args: model_class_name: Optional specific model class name is_vl_model: Whether this is a vision-language model + trust_remote_code: Whether trust_remote_code is enabled (e.g., for Kimi K2.5) Returns: Model class to use for loading @@ -198,7 +213,7 @@ def get_model_class(model_class_name: str = None, is_vl_model: bool = False): return load_model_class(model_class_name) else: # Default behavior - if is_vl_model: + if is_vl_model and not trust_remote_code: print_rank_0( "Warning: VL model detected but no model class specified. Using AutoModelForCausalLM which may not work." ) @@ -243,6 +258,7 @@ def is_vision_language_model(model_path: str, trust_remote_code: bool | None = N "qwen2_vl", "qwen_vl", "minicpm", + "kimi_k25", ] return any(indicator in model_type or indicator in arch_str for indicator in vl_indicators) @@ -250,7 +266,7 @@ def is_vision_language_model(model_path: str, trust_remote_code: bool | None = N except Exception as e: print_rank_0(f"Warning: Could not determine model type from config: {e}") # Fallback: check if qwen_vl_utils is available and model name contains vl indicators - return any(indicator in model_path.lower() for indicator in ["vl", "vision"]) + return any(indicator in model_path.lower() for indicator in ["vl", "vision", "kimi"]) class SingleBatchIterator: @@ -363,6 +379,40 @@ def pad_input_ids_to_tp_multiple(input_ids, tp_size: int, pad_token_id: int = 0) return input_ids +def _is_kimi_processor(processor) -> bool: + """Check if the processor is a Kimi K2.5 processor.""" + return processor is not None and type(processor).__name__ == "KimiK25Processor" + + +def _generate_synthetic_vision_inputs(tokenizer, prompt: str, tp_size: int = 1): + """Create random pixel_values and grid_thws for VL testing without a processor. + + Generates a synthetic 4x4-patch "image" (patch_size=14 → 56x56 pixels) and + builds input_ids that contain the text prompt followed by a single image + placeholder token, which the model's merge function will expand. + + Returns the same tuple as process_inputs: (input_ids, pixel_values, grid_thws, messages). + """ + PATCH_SIZE = 14 + GRID_H, GRID_W, GRID_T = 4, 4, 1 + MEDIA_PLACEHOLDER_TOKEN_ID = 163605 + + total_patches = GRID_T * GRID_H * GRID_W + pixel_values = torch.randn(total_patches, 3, PATCH_SIZE, PATCH_SIZE, dtype=torch.bfloat16) + grid_thws = torch.tensor([[GRID_T, GRID_H, GRID_W]], dtype=torch.long) + + text_ids = tokenizer.encode(prompt, add_special_tokens=True) + ids = text_ids + [MEDIA_PLACEHOLDER_TOKEN_ID] + input_ids = torch.tensor([ids], dtype=torch.long) + input_ids = pad_input_ids_to_tp_multiple(input_ids, tp_size, tokenizer.pad_token_id or 0) + + print_rank_0( + f"Synthetic vision inputs: pixel_values={pixel_values.shape}, " + f"grid_thws={grid_thws.tolist()}, input_ids={input_ids.shape}" + ) + return input_ids, pixel_values, grid_thws, None + + def process_inputs(tokenizer, processor, image_path: Optional[str], prompt: str, is_vl_model: bool, tp_size: int = 1): """Process inputs for both vision-language and regular LLM models. @@ -376,39 +426,54 @@ def process_inputs(tokenizer, processor, image_path: Optional[str], prompt: str, Returns: Tuple of (input_ids, pixel_values, image_grid_thw, messages) + Note: For Kimi K2.5 models, image_grid_thw contains the ``grid_thws`` output + from the Kimi processor. """ if is_vl_model and image_path: - if not QWEN_VL_UTILS_AVAILABLE: - raise ImportError("qwen_vl_utils is required for vision-language models but not installed") - - # Create messages with image and text - messages = [ - { - "role": "user", - "content": [ - {"type": "image", "image": image_path}, - {"type": "text", "text": prompt}, - ], - } - ] - - # Process vision info - image_inputs, video_inputs = process_vision_info(messages) - - # Apply chat template - text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - - # Process inputs - inputs = processor( - text=[text], - images=image_inputs, - videos=video_inputs, - padding=True, - return_tensors="pt", - ) + if _is_kimi_processor(processor): + # Kimi K2.5: use processor(messages=messages) directly + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image_url": image_path}, + {"type": "text", "text": prompt}, + ], + } + ] + inputs = processor(messages=messages) + input_ids = pad_input_ids_to_tp_multiple(inputs.input_ids, tp_size, tokenizer.pad_token_id or 0) + grid_thws = getattr(inputs, "grid_thws", None) + return input_ids, inputs.pixel_values, grid_thws, messages + elif QWEN_VL_UTILS_AVAILABLE and processor is not None: + # Qwen VL and other models: use process_vision_info + processor + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_path}, + {"type": "text", "text": prompt}, + ], + } + ] + + image_inputs, video_inputs = process_vision_info(messages) + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) - input_ids = pad_input_ids_to_tp_multiple(inputs.input_ids, tp_size, tokenizer.pad_token_id or 0) - return input_ids, inputs.pixel_values, inputs.image_grid_thw, messages + input_ids = pad_input_ids_to_tp_multiple(inputs.input_ids, tp_size, tokenizer.pad_token_id or 0) + return input_ids, inputs.pixel_values, inputs.image_grid_thw, messages + else: + # Processor unavailable -- generate synthetic vision inputs so we + # can still exercise the vision forward path with random data. + print_rank_0("Processor unavailable; generating synthetic vision inputs for testing.") + return _generate_synthetic_vision_inputs(tokenizer, prompt, tp_size) else: # Text-only processing for both VL models without images and regular LLMs if is_vl_model and processor: @@ -421,6 +486,60 @@ def process_inputs(tokenizer, processor, image_path: Optional[str], prompt: str, return input_ids, None, None, None +def _has_fp8_quantization(config) -> bool: + """Check if a config (or its text_config) specifies FP8 quantization.""" + for cfg in (config, getattr(config, "text_config", None)): + if cfg is None: + continue + qc = getattr(cfg, "quantization_config", None) + if qc is None: + continue + method = qc.get("quant_method", "") if isinstance(qc, dict) else getattr(qc, "quant_method", "") + if method == "fp8" or (isinstance(qc, dict) and qc.get("fmt") == "e4m3"): + return True + return False + + +def _load_hf_model_fp8(model_path, config, model_class, trust): + """Load an FP8-quantized HF model by dequantizing weights to bf16.""" + import copy + import glob + + from safetensors.torch import load_file + + from megatron.bridge.models.deepseek.common import maybe_dequantize_fp8_weight + + print_rank_0("Detected FP8 quantization; loading with manual dequantization...") + + stripped = copy.deepcopy(config) + for cfg in (stripped, getattr(stripped, "text_config", None)): + if cfg is not None and hasattr(cfg, "quantization_config"): + delattr(cfg, "quantization_config") + + model = model_class.from_config(stripped, torch_dtype=torch.bfloat16, trust_remote_code=trust) + + st_files = sorted(glob.glob(os.path.join(model_path, "*.safetensors"))) + raw_state: dict[str, torch.Tensor] = {} + for f in st_files: + raw_state.update(load_file(f, device="cpu")) + + dequantized: dict[str, torch.Tensor] = {} + for key, tensor in raw_state.items(): + if key.endswith("_scale_inv"): + continue + dequantized[key] = maybe_dequantize_fp8_weight(key, tensor, raw_state) + + missing, unexpected = model.load_state_dict(dequantized, strict=False) + if unexpected: + print_rank_0(f" Unexpected keys (ignored): {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}") + if missing: + print_rank_0(f" Missing keys: {missing[:5]}{'...' if len(missing) > 5 else ''}") + + model = model.to(device="cuda", dtype=torch.bfloat16).eval() + print_rank_0(f"Loaded FP8 model (dequantized to bf16) with {model_class.__name__}") + return model + + def _load_hf_model(args, is_vl_model: bool): """Load HuggingFace model on rank 0. @@ -435,18 +554,21 @@ def _load_hf_model(args, is_vl_model: bool): return None print_rank_0("Loading HuggingFace model...") - model_class = get_model_class(args.model_class, is_vl_model) - hf_model = model_class.from_pretrained( - args.hf_model_path, - torch_dtype=torch.bfloat16, - device_map="cuda", - trust_remote_code=is_safe_repo( - trust_remote_code=args.trust_remote_code, - hf_path=args.hf_model_path, - ), - ) - hf_model = hf_model.eval() - print_rank_0(f"Loaded with {model_class.__name__}") + trust = is_safe_repo(trust_remote_code=args.trust_remote_code, hf_path=args.hf_model_path) + model_class = get_model_class(args.model_class, is_vl_model, trust_remote_code=bool(args.trust_remote_code)) + + config = AutoConfig.from_pretrained(args.hf_model_path, trust_remote_code=trust) + if _has_fp8_quantization(config): + hf_model = _load_hf_model_fp8(args.hf_model_path, config, model_class, trust) + else: + hf_model = model_class.from_pretrained( + args.hf_model_path, + torch_dtype=torch.bfloat16, + device_map="cuda", + trust_remote_code=trust, + ) + hf_model = hf_model.eval() + print_rank_0(f"Loaded with {model_class.__name__}") # Register debug hooks if enabled if args.enable_debug_hooks: @@ -457,6 +579,25 @@ def _load_hf_model(args, is_vl_model: bool): return hf_model +def _ensure_custom_code_files(source_dir: str, target_dir: str) -> None: + """Copy custom modeling .py files from source to target if missing. + + This is needed for round-trip loading of VL models with trust_remote_code, + since the exported checkpoint may not always contain the custom code files. + """ + import glob + import shutil + + source_dir = os.path.abspath(source_dir) + target_dir = os.path.abspath(target_dir) + if not os.path.isdir(source_dir): + return + for py_file in glob.glob(os.path.join(source_dir, "*.py")): + target_file = os.path.join(target_dir, os.path.basename(py_file)) + if not os.path.exists(target_file): + shutil.copy2(py_file, target_file) + + def _export_and_load_roundtrip_hf_model(args, is_vl_model: bool, megatron_model, bridge): """Export HF weights from Megatron model, save, and load exported HF model for comparison. @@ -476,7 +617,7 @@ def _export_and_load_roundtrip_hf_model(args, is_vl_model: bool, megatron_model, for name, param in bridge.export_hf_weights(megatron_model, show_progress=False): if _is_rank_0(): original_param = bridge.hf_pretrained.state[name] - if torch.allclose(param, original_param.to(param.device), atol=1e-1): + if torch.allclose(param.float(), original_param.to(param.device).float(), atol=1e-1): matches += 1 else: mismatches += 1 @@ -488,8 +629,13 @@ def _export_and_load_roundtrip_hf_model(args, is_vl_model: bool, megatron_model, # Load exported HF model only on rank 0 if _is_rank_0(): + # For trust_remote_code models (e.g. Kimi VL), ensure the exported + # directory contains the custom .py files needed by AutoModel dispatch. + if args.trust_remote_code: + _ensure_custom_code_files(args.hf_model_path, save_path) + print_rank_0("Loading exported HF model for comparison...") - model_class = get_model_class(args.model_class, is_vl_model) + model_class = get_model_class(args.model_class, is_vl_model, trust_remote_code=True) hf_model = model_class.from_pretrained( save_path, torch_dtype=torch.bfloat16, device_map="cuda", trust_remote_code=True ).eval() @@ -501,7 +647,7 @@ def _export_and_load_roundtrip_hf_model(args, is_vl_model: bool, megatron_model, return None -def _run_hf_inference(hf_model, input_ids, pixel_values, image_grid_thw, tokenizer): +def _run_hf_inference(hf_model, input_ids, pixel_values, image_grid_thw, tokenizer, grid_key_name="image_grid_thw"): """Run HuggingFace model inference and return results. Args: @@ -510,6 +656,8 @@ def _run_hf_inference(hf_model, input_ids, pixel_values, image_grid_thw, tokeniz pixel_values: Pixel values for vision models (optional). image_grid_thw: Image grid dimensions (optional). tokenizer: Tokenizer for decoding. + grid_key_name: Name of the grid dimension kwarg for the HF model forward. + "image_grid_thw" for Qwen VL, "grid_thws" for Kimi K2.5. Returns: Tuple of (hf_logits, hf_next_token, hf_logits_stats, hf_top5_info, logits_shape). @@ -527,7 +675,7 @@ def _run_hf_inference(hf_model, input_ids, pixel_values, image_grid_thw, tokeniz if pixel_values is not None: hf_inputs["pixel_values"] = pixel_values if image_grid_thw is not None: - hf_inputs["image_grid_thw"] = image_grid_thw + hf_inputs[grid_key_name] = image_grid_thw hf_output = hf_model(**hf_inputs) @@ -572,7 +720,13 @@ def _load_megatron_model(args): if args.megatron_model_path: # Load from Megatron checkpoint - bridge = AutoBridge.from_hf_pretrained(args.hf_model_path) + bridge = AutoBridge.from_hf_pretrained( + args.hf_model_path, + trust_remote_code=is_safe_repo( + trust_remote_code=args.trust_remote_code, + hf_path=args.hf_model_path, + ), + ) model_provider = bridge.to_megatron_provider(load_weights=False) model_provider.tensor_model_parallel_size = tp model_provider.pipeline_model_parallel_size = pp @@ -652,8 +806,8 @@ def _setup_tokenizer_and_processor(args, is_vl_model: bool): ), ) except Exception as e: - print_rank_0(f"Warning: Could not load processor for VL model: {e}") - print_rank_0("Falling back to tokenizer-only mode") + print_rank_0(f"Warning: Could not load processor ({e})") + print_rank_0("Will use synthetic vision inputs if --image_path is provided") return tokenizer, processor @@ -702,16 +856,31 @@ def compare_models_one_step(args) -> None: # Move to GPU input_ids = input_ids.cuda() if pixel_values is not None: - pixel_values = pixel_values.cuda() + if isinstance(pixel_values, (list, tuple)): + pixel_values = [pv.cuda() for pv in pixel_values] + else: + pixel_values = pixel_values.cuda() if image_grid_thw is not None: - image_grid_thw = image_grid_thw.cuda() + if isinstance(image_grid_thw, (list, tuple)): + image_grid_thw = [g.cuda() for g in image_grid_thw] + else: + image_grid_thw = image_grid_thw.cuda() print_rank_0(f"Input shape: {input_ids.shape}") - print_rank_0(f"Pixel values shape: {pixel_values.shape if pixel_values is not None else 'None'}") + if pixel_values is not None: + pv_shape = [pv.shape for pv in pixel_values] if isinstance(pixel_values, (list, tuple)) else pixel_values.shape + print_rank_0(f"Pixel values shape: {pv_shape}") + else: + print_rank_0("Pixel values: None") + + # Determine grid key name for the HF model forward pass + # Kimi K2.5 uses "grid_thws", Qwen VL uses "image_grid_thw" + is_kimi = _is_kimi_processor(processor) or "kimi" in args.hf_model_path.lower() + grid_key_name = "grid_thws" if is_kimi else "image_grid_thw" # Run HF model forward pass hf_logits, hf_next_token, hf_logits_stats, hf_top5_info, logits_shape = _run_hf_inference( - hf_model, input_ids, pixel_values, image_grid_thw, tokenizer + hf_model, input_ids, pixel_values, image_grid_thw, tokenizer, grid_key_name=grid_key_name ) del hf_model diff --git a/examples/conversion/create_hf_toy_model.py b/examples/conversion/create_hf_toy_model.py index 0f1098cb6a..125b35b1b4 100644 --- a/examples/conversion/create_hf_toy_model.py +++ b/examples/conversion/create_hf_toy_model.py @@ -12,6 +12,7 @@ --output-dir /tmp/qwen3_toy \ --num-hidden-layers 2 \ --num-experts 4 + ``` The script works by: @@ -27,10 +28,13 @@ from __future__ import annotations import argparse +import json +import math from pathlib import Path from typing import Optional import torch +from safetensors.torch import load_file, save_file from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -85,6 +89,20 @@ def _parse_args() -> argparse.Namespace: default=1234, help="Torch seed applied before checkpoint creation.", ) + parser.add_argument( + "--quantize-fp8", + action="store_true", + default=False, + help="Post-process the saved checkpoint into FP8 (e4m3) block-wise " + "format with scale_inv tensors, matching the DeepSeek-V3 / Kimi-K2.5 " + "quantization convention.", + ) + parser.add_argument( + "--fp8-block-size", + type=int, + default=128, + help="Block size for FP8 block-wise quantization (default: 128).", + ) parser.add_argument( "--disable-remote-code-trust", action="store_false", @@ -103,32 +121,159 @@ def _adjust_config( num_experts_per_tok: Optional[int], moe_intermediate_size: Optional[int], ) -> None: - """Mutate the config in-place so it matches the requested toy topology.""" + """Mutate config(s) in-place so they match requested layer/expert topology.""" - config.num_hidden_layers = num_hidden_layers + def _adjust_one(cfg) -> None: + cfg.num_hidden_layers = num_hidden_layers - if hasattr(config, "max_window_layers"): - config.max_window_layers = min(config.max_window_layers, num_hidden_layers) + if hasattr(cfg, "max_window_layers"): + cfg.max_window_layers = min(cfg.max_window_layers, num_hidden_layers) - if hasattr(config, "layer_types"): - config.layer_types = config.layer_types[:num_hidden_layers] + if hasattr(cfg, "layer_types"): + cfg.layer_types = cfg.layer_types[:num_hidden_layers] - mlp_only_layers = getattr(config, "mlp_only_layers", []) - if isinstance(mlp_only_layers, (list, tuple)): - config.mlp_only_layers = [layer for layer in mlp_only_layers if layer < num_hidden_layers] + mlp_only_layers = getattr(cfg, "mlp_only_layers", []) + if isinstance(mlp_only_layers, (list, tuple)): + cfg.mlp_only_layers = [layer for layer in mlp_only_layers if layer < num_hidden_layers] + + # Kimi-style configs may use n_routed_experts while many others use num_experts. + for field in ("num_experts", "n_routed_experts"): + if hasattr(cfg, field): + setattr(cfg, field, num_experts) + + if hasattr(cfg, "num_experts_per_tok"): + cfg.num_experts_per_tok = ( + num_experts_per_tok + if num_experts_per_tok is not None + else min(num_experts, getattr(cfg, "num_experts_per_tok", num_experts)) + ) + + if hasattr(cfg, "router_top_k"): + cfg.router_top_k = min(num_experts, getattr(cfg, "num_experts_per_tok", num_experts)) + + if moe_intermediate_size is not None and hasattr(cfg, "moe_intermediate_size"): + cfg.moe_intermediate_size = moe_intermediate_size + + _adjust_one(config) + text_config = getattr(config, "text_config", None) + if text_config is not None: + _adjust_one(text_config) + + # Always strip quantization_config during model creation so + # from_config instantiates plain bf16 weights. If --quantize-fp8 is + # requested the checkpoint is post-processed later. + for cfg in (config, text_config): + if cfg is not None and hasattr(cfg, "quantization_config"): + del cfg.quantization_config + + +# FP8 e4m3 representable range +_FP8_E4M3_MAX = 448.0 + + +def _rebuild_safetensors_index(output_dir: Path, st_files: list[Path]) -> None: + """Regenerate model.safetensors.index.json from the current safetensors files.""" + index_path = output_dir / "model.safetensors.index.json" + if not index_path.exists(): + return + + weight_map: dict[str, str] = {} + metadata: dict[str, str] = {} + for st_path in st_files: + tensors = load_file(str(st_path), device="cpu") + for key in tensors: + weight_map[key] = st_path.name + total_bytes = sum(t.nelement() * t.element_size() for t in tensors.values()) + metadata[st_path.name] = str(total_bytes) + + index = {"metadata": {"total_size": sum(int(v) for v in metadata.values())}, "weight_map": weight_map} + index_path.write_text(json.dumps(index, indent=2) + "\n") + print(f" rebuilt {index_path.name} with {len(weight_map)} keys") - config.num_experts = num_experts - config.num_experts_per_tok = ( - num_experts_per_tok - if num_experts_per_tok is not None - else min(num_experts, getattr(config, "num_experts_per_tok", num_experts)) - ) - if hasattr(config, "router_top_k"): - config.router_top_k = min(config.num_experts, config.num_experts_per_tok) +def _quantize_checkpoint_fp8(output_dir: Path, block_size: int = 128) -> None: + """Convert saved bf16 safetensors in *output_dir* to FP8 block-wise format. - if moe_intermediate_size is not None: - config.moe_intermediate_size = moe_intermediate_size + For every 2-D weight tensor whose both dimensions are >= *block_size*, + produce: + - ``{name}`` in ``torch.float8_e4m3fn`` + - ``{name}_scale_inv`` with per-block dequantization scales (float32) + + Then inject a ``quantization_config`` into ``config.json``. + """ + st_files = sorted(output_dir.glob("*.safetensors")) + if not st_files: + print(" WARNING: no safetensors found; skipping FP8 quantization") + return + + for st_path in st_files: + tensors = load_file(str(st_path)) + new_tensors: dict[str, torch.Tensor] = {} + quantized_count = 0 + + for name, tensor in tensors.items(): + if tensor.ndim == 2 and tensor.shape[0] >= block_size and tensor.shape[1] >= block_size: + fp8_weight, scale_inv = _quantize_tensor_fp8(tensor.float(), block_size) + new_tensors[name] = fp8_weight + new_tensors[name + "_scale_inv"] = scale_inv + quantized_count += 1 + else: + new_tensors[name] = tensor + + save_file(new_tensors, str(st_path)) + print(f" {st_path.name}: quantized {quantized_count} tensors to FP8") + + # Rebuild the safetensors index so that _scale_inv keys are discoverable + # by lazy-loading state dict implementations (e.g. Megatron-Bridge). + _rebuild_safetensors_index(output_dir, st_files) + + config_path = output_dir / "config.json" + if config_path.exists(): + cfg = json.loads(config_path.read_text()) + quant_cfg = { + "quant_method": "fp8", + "fmt": "e4m3", + "weight_block_size": [block_size, block_size], + "activation_scheme": "dynamic", + } + cfg["quantization_config"] = quant_cfg + if "text_config" in cfg: + cfg["text_config"]["quantization_config"] = quant_cfg + config_path.write_text(json.dumps(cfg, indent=2) + "\n") + print(" injected quantization_config into config.json") + + +def _quantize_tensor_fp8( + tensor: torch.Tensor, block_size: int +) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize a single 2-D tensor to FP8 e4m3 with per-block scales. + + Returns ``(fp8_weight, scale_inv)``.""" + M, N = tensor.shape + num_blocks_m = math.ceil(M / block_size) + num_blocks_n = math.ceil(N / block_size) + padded_M = num_blocks_m * block_size + padded_N = num_blocks_n * block_size + + if M != padded_M or N != padded_N: + padded = torch.zeros(padded_M, padded_N, dtype=tensor.dtype, device=tensor.device) + padded[:M, :N] = tensor + else: + padded = tensor + + blocks = padded.reshape(num_blocks_m, block_size, num_blocks_n, block_size) + abs_max = blocks.abs().amax(dim=(1, 3)) # [num_blocks_m, num_blocks_n] + scale_inv = (abs_max / _FP8_E4M3_MAX).clamp(min=1e-12).to(torch.float32) + + scaled = blocks / scale_inv[:, None, :, None] + scaled = scaled.clamp(-_FP8_E4M3_MAX, _FP8_E4M3_MAX) + scaled = scaled.reshape(padded_M, padded_N) + + if M != padded_M or N != padded_N: + scaled = scaled[:M, :N].contiguous() + + fp8_weight = scaled.to(torch.float8_e4m3fn) + return fp8_weight, scale_inv def _save_tokenizer(output_dir: Path, tokenizer_id: str, *, trust_remote_code: bool) -> None: @@ -136,6 +281,18 @@ def _save_tokenizer(output_dir: Path, tokenizer_id: str, *, trust_remote_code: b tokenizer.save_pretrained(output_dir) +def _save_processor(output_dir: Path, model_id: str, *, trust_remote_code: bool) -> None: + """Save the AutoProcessor alongside the model so VL toy models can process images.""" + try: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=trust_remote_code) + processor.save_pretrained(output_dir) + print(f" Processor ({type(processor).__name__}) saved to {output_dir}") + except Exception as exc: + print(f" Processor not available for {model_id} ({exc}); skipping.") + + def main() -> None: """Main entry point.""" args = _parse_args() @@ -166,12 +323,22 @@ def main() -> None: model = model.bfloat16() model.save_pretrained(output_dir, safe_serialization=True) + if args.quantize_fp8: + print("Quantizing checkpoint to FP8 (e4m3) block-wise format...") + _quantize_checkpoint_fp8(output_dir, block_size=args.fp8_block_size) + _save_tokenizer(output_dir, tokenizer_id, trust_remote_code=trust_remote_code) + # For VL models, save the processor so image inputs work with the toy model. + if getattr(config, "vision_config", None) is not None: + _save_processor(output_dir, args.hf_model_id, trust_remote_code=trust_remote_code) + print(f"Toy HuggingFace checkpoint saved to: {output_dir}") print(f" hidden_layers={args.num_hidden_layers}") print(f" num_experts={args.num_experts}") - print(f" num_experts_per_tok={config.num_experts_per_tok}") + effective_cfg = getattr(config, "text_config", config) + print(f" num_experts_per_tok={getattr(effective_cfg, 'num_experts_per_tok', 'N/A')}") + print(f" quantize_fp8={args.quantize_fp8}") print(f" tokenizer_source={tokenizer_id}") diff --git a/examples/conversion/hf_to_megatron_generate_vlm.py b/examples/conversion/hf_to_megatron_generate_vlm.py index 2bf7888cfb..976c5149b5 100644 --- a/examples/conversion/hf_to_megatron_generate_vlm.py +++ b/examples/conversion/hf_to_megatron_generate_vlm.py @@ -147,32 +147,40 @@ def process_image_inputs(processor, image_path: Optional[str], prompt: str): Tuple of (input_ids, pixel_values, image_grid_thw, image_sizes, messages) """ if image_path: - # Create messages with image and text - messages = [ - { - "role": "user", - "content": [ - {"type": "image", "image_url": image_path}, - {"type": "text", "text": prompt}, - ], - } - ] - - # Process vision info - # image_inputs, video_inputs = process_vision_info(messages) - - # # Apply chat template - # text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - - # # Process inputs - # inputs = processor( - # text=[text], - # images=image_inputs, - # videos=video_inputs, - # padding=True, - # return_tensors="pt", - # ) - inputs = processor(messages=messages) + is_kimi = type(processor).__name__ == "KimiK25Processor" + + if is_kimi: + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image_url": image_path}, + {"type": "text", "text": prompt}, + ], + } + ] + inputs = processor(messages=messages) + else: + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image_path}, + {"type": "text", "text": prompt}, + ], + } + ] + + image_inputs, video_inputs = process_vision_info(messages) + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + return ( inputs.input_ids, inputs.pixel_values, diff --git a/src/megatron/bridge/models/deepseek/common.py b/src/megatron/bridge/models/deepseek/common.py index bc34726afa..a3b2927c38 100644 --- a/src/megatron/bridge/models/deepseek/common.py +++ b/src/megatron/bridge/models/deepseek/common.py @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math +from typing import Mapping + +import torch + from megatron.bridge.models.conversion.param_mapping import AutoMapping, GatedMLPMapping from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM @@ -144,3 +149,73 @@ def get_common_mapping_list() -> list: ) return mapping_list + + +def dequantize_fp8_blockwise( + weight: torch.Tensor, + scale_inv: torch.Tensor, + *, + dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """Dequantize an FP8 block-wise quantized weight tensor to higher precision. + + Block sizes are inferred from the shapes of *weight* and *scale_inv*: + ``block_m = ceil(M / scale_inv.shape[0])``, and likewise for the column + dimension. This matches the DeepSeek-V3 / Kimi-K2.5 FP8 convention where + ``weight_block_size = [128, 128]``. + + Args: + weight: FP8 weight tensor, shape ``[M, N]`` (``torch.float8_e4m3fn``). + scale_inv: Per-block inverse scale factors, shape + ``[ceil(M/block_m), ceil(N/block_n)]``. + dtype: Target output dtype (default ``torch.bfloat16``). + + Returns: + Dequantized tensor of shape ``[M, N]`` in *dtype*. + """ + M, N = weight.shape + scale_rows, scale_cols = scale_inv.shape + block_m = math.ceil(M / scale_rows) + block_n = math.ceil(N / scale_cols) + + padded_M = scale_rows * block_m + padded_N = scale_cols * block_n + + if M != padded_M or N != padded_N: + result = torch.zeros(padded_M, padded_N, dtype=dtype, device=weight.device) + result[:M, :N] = weight.to(dtype) + else: + result = weight.to(dtype) + + result = result.reshape(scale_rows, block_m, scale_cols, block_n) + result.mul_(scale_inv[:, None, :, None].to(dtype)) + result = result.reshape(padded_M, padded_N) + + if M != padded_M or N != padded_N: + result = result[:M, :N].contiguous() + return result + + +def maybe_dequantize_fp8_weight( + hf_param: str, + hf_weights: torch.Tensor, + hf_state_dict: Mapping[str, torch.Tensor], + dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """Return *hf_weights* dequantized to *dtype* when FP8, otherwise pass through. + + Detection heuristic: the weight has ``float8_e4m3fn`` dtype **and** a + matching ``{hf_param}_scale_inv`` key exists in *hf_state_dict*. + """ + if not hasattr(torch, "float8_e4m3fn") or hf_weights.dtype != torch.float8_e4m3fn: + return hf_weights + + scale_inv_key = hf_param + "_scale_inv" + if scale_inv_key not in hf_state_dict: + return hf_weights + + return dequantize_fp8_blockwise( + hf_weights, + hf_state_dict[scale_inv_key], + dtype=dtype, + ) diff --git a/src/megatron/bridge/models/deepseek/deepseek_v3_bridge.py b/src/megatron/bridge/models/deepseek/deepseek_v3_bridge.py index 3ad7d53500..19da5df585 100644 --- a/src/megatron/bridge/models/deepseek/deepseek_v3_bridge.py +++ b/src/megatron/bridge/models/deepseek/deepseek_v3_bridge.py @@ -20,7 +20,11 @@ from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge, WeightConversionTask from megatron.bridge.models.conversion.param_mapping import AutoMapping -from megatron.bridge.models.deepseek.common import get_common_configs, get_common_mapping_list +from megatron.bridge.models.deepseek.common import ( + get_common_configs, + get_common_mapping_list, + maybe_dequantize_fp8_weight, +) from megatron.bridge.models.deepseek.deepseek_provider import DeepSeekV3ModelProvider from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM @@ -58,6 +62,18 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> DeepSeekV3ModelP provider = DeepSeekV3ModelProvider(**configs) return provider + def maybe_modify_loaded_hf_weight( + self, hf_param: str | dict[str, str], hf_state_dict: Mapping[str, torch.Tensor] + ) -> torch.Tensor: + """Load HF weights, dequantizing FP8 block-wise tensors to bf16 when present.""" + if isinstance(hf_param, str): + hf_weights = hf_state_dict[hf_param] + return maybe_dequantize_fp8_weight(hf_param, hf_weights, hf_state_dict) + return { + k: maybe_dequantize_fp8_weight(v, hf_state_dict[v], hf_state_dict) + for k, v in hf_param.items() + } + def mapping_registry(self) -> MegatronMappingRegistry: mapping_list = get_common_mapping_list() diff --git a/src/megatron/bridge/models/hf_pretrained/vlm.py b/src/megatron/bridge/models/hf_pretrained/vlm.py index 06c96e5721..0a45a2df31 100644 --- a/src/megatron/bridge/models/hf_pretrained/vlm.py +++ b/src/megatron/bridge/models/hf_pretrained/vlm.py @@ -322,6 +322,17 @@ def model_name_or_path(self) -> Optional[Union[str, Path]]: """Return the model name or path.""" return self._model_name_or_path + @property + def auto_map_model_class(self) -> Optional[str]: + """Get the custom model class string from the config's auto_map.""" + config = self.config + auto_map = getattr(config, "auto_map", None) + if auto_map: + for key in ("AutoModelForCausalLM", "AutoModel"): + if key in auto_map: + return str(auto_map[key]) + return None + @property def model(self) -> VLMType: """Lazy load and return the underlying model.""" diff --git a/src/megatron/bridge/models/kimi_vl/kimi_k25_vl_bridge.py b/src/megatron/bridge/models/kimi_vl/kimi_k25_vl_bridge.py index 0ec4286841..17aa04c65d 100644 --- a/src/megatron/bridge/models/kimi_vl/kimi_k25_vl_bridge.py +++ b/src/megatron/bridge/models/kimi_vl/kimi_k25_vl_bridge.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Mapping + import torch from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry @@ -21,7 +23,11 @@ GatedMLPMapping, ReplicatedMapping, ) -from megatron.bridge.models.deepseek.common import get_common_configs, get_common_mapping_list +from megatron.bridge.models.deepseek.common import ( + get_common_configs, + get_common_mapping_list, + maybe_dequantize_fp8_weight, +) from megatron.bridge.models.hf_pretrained.vlm import PreTrainedVLM from megatron.bridge.models.kimi_vl.kimi_k25_vl_provider import KimiK25VLModelProvider from megatron.bridge.models.kimi_vl.modeling_kimi_k25_vl import KimiK25VLModel @@ -43,9 +49,12 @@ def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> KimiK25VLModelProvide text_config = hf_config.text_config vision_config = hf_config.vision_config - # get_common_configs expects the text config + # Temporarily swap to text_config for get_common_configs (which reads + # hf_pretrained.config), then restore the original VL config so that + # save_artifacts later writes the full config (including auto_map). hf_pretrained.config = text_config configs = get_common_configs(hf_pretrained) + hf_pretrained.config = hf_config configs["make_vocab_size_divisible_by"] = 1280 configs["moe_router_score_function"] = "sigmoid" @@ -79,6 +88,18 @@ def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> KimiK25VLModelProvide return provider + def maybe_modify_loaded_hf_weight( + self, hf_param: str | dict[str, str], hf_state_dict: Mapping[str, torch.Tensor] + ) -> torch.Tensor: + """Load HF weights, dequantizing FP8 block-wise tensors to bf16 when present.""" + if isinstance(hf_param, str): + hf_weights = hf_state_dict[hf_param] + return maybe_dequantize_fp8_weight(hf_param, hf_weights, hf_state_dict) + return { + k: maybe_dequantize_fp8_weight(v, hf_state_dict[v], hf_state_dict) + for k, v in hf_param.items() + } + def mapping_registry(self) -> MegatronMappingRegistry: # Return MegatronMappingRegistry containing parameter mappings from Megatron to HF format. # Start with the common mapping list for the language model. diff --git a/src/megatron/bridge/models/kimi_vl/modeling_kimi_k25_vl.py b/src/megatron/bridge/models/kimi_vl/modeling_kimi_k25_vl.py index f48614790f..d891d9c6ad 100644 --- a/src/megatron/bridge/models/kimi_vl/modeling_kimi_k25_vl.py +++ b/src/megatron/bridge/models/kimi_vl/modeling_kimi_k25_vl.py @@ -101,6 +101,17 @@ def __init__( ) self.vision_tower_config = VisionTowerConfig(config.vision_config) self.projector_config = ProjectorConfig(config.vision_config) + + # Patch: some versions of MoonViT3dEncoder.__init__ reference + # self.use_deterministic_attn before setting it. Inject a default + # via the class so the attribute lookup succeeds. + MoonViT3dEncoder = get_class_from_dynamic_module( + "modeling_kimi_k25.MoonViT3dEncoder", + config.hf_model_path, + ) + if not hasattr(MoonViT3dEncoder, "use_deterministic_attn"): + MoonViT3dEncoder.use_deterministic_attn = False + self.vision_tower = MoonViT3dPretrainedModel(self.vision_tower_config) self.mm_projector = PatchMergerMLP(self.projector_config) # TODO: support different types of mm projector # Ensure HF visual tower params are marked for TP grad sync and future assignments are hooked. diff --git a/tests/unit_tests/models/deepseek/test_fp8_dequantization.py b/tests/unit_tests/models/deepseek/test_fp8_dequantization.py new file mode 100644 index 0000000000..830d4cd032 --- /dev/null +++ b/tests/unit_tests/models/deepseek/test_fp8_dequantization.py @@ -0,0 +1,151 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for FP8 block-wise dequantization used by DeepSeek-V3 and Kimi-K2.5.""" + +import math + +import pytest +import torch + +from megatron.bridge.models.deepseek.common import ( + dequantize_fp8_blockwise, + maybe_dequantize_fp8_weight, +) + +_requires_fp8 = pytest.mark.skipif( + not hasattr(torch, "float8_e4m3fn"), + reason="torch.float8_e4m3fn not available in this PyTorch build", +) + +FP8_E4M3_MAX = 448.0 + + +def _quantize_to_fp8(tensor: torch.Tensor, block_size: int = 128): + """Reference quantization: bf16 -> (fp8, scale_inv). + + Mirrors the logic in ``create_hf_toy_model.py --quantize-fp8``. + """ + M, N = tensor.shape + num_blocks_m = math.ceil(M / block_size) + num_blocks_n = math.ceil(N / block_size) + padded_M = num_blocks_m * block_size + padded_N = num_blocks_n * block_size + + padded = torch.zeros(padded_M, padded_N, dtype=tensor.dtype) + padded[:M, :N] = tensor + + blocks = padded.reshape(num_blocks_m, block_size, num_blocks_n, block_size) + abs_max = blocks.abs().amax(dim=(1, 3)) + scale_inv = (abs_max / FP8_E4M3_MAX).clamp(min=1e-12).to(torch.float32) + + scaled = blocks / scale_inv[:, None, :, None] + scaled = scaled.clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).reshape(padded_M, padded_N) + if M != padded_M or N != padded_N: + scaled = scaled[:M, :N].contiguous() + + fp8_weight = scaled.to(torch.float8_e4m3fn) + return fp8_weight, scale_inv + + +class TestDequantizeFp8Blockwise: + """Tests for ``dequantize_fp8_blockwise``.""" + + @_requires_fp8 + def test_roundtrip_divisible(self): + """Quantize -> dequantize round-trip with dimensions divisible by block_size.""" + original = torch.randn(256, 512, dtype=torch.float32) + fp8, scale_inv = _quantize_to_fp8(original, block_size=128) + + recovered = dequantize_fp8_blockwise(fp8, scale_inv, dtype=torch.float32) + + assert recovered.shape == original.shape + assert recovered.dtype == torch.float32 + # FP8 e4m3 has ~3 bits of mantissa; relative error should be small + rel_err = (recovered - original).abs() / (original.abs() + 1e-8) + assert rel_err.mean() < 0.05, f"Mean relative error too large: {rel_err.mean():.4f}" + + @_requires_fp8 + def test_roundtrip_non_divisible(self): + """Round-trip where dimensions are NOT multiples of block_size.""" + original = torch.randn(200, 300, dtype=torch.float32) + fp8, scale_inv = _quantize_to_fp8(original, block_size=128) + + recovered = dequantize_fp8_blockwise(fp8, scale_inv, dtype=torch.float32) + + assert recovered.shape == (200, 300) + rel_err = (recovered - original).abs() / (original.abs() + 1e-8) + assert rel_err.mean() < 0.05 + + @_requires_fp8 + def test_output_dtype(self): + """Output dtype matches the requested dtype.""" + original = torch.randn(128, 128, dtype=torch.float32) + fp8, scale_inv = _quantize_to_fp8(original, block_size=128) + + for dtype in (torch.bfloat16, torch.float32): + result = dequantize_fp8_blockwise(fp8, scale_inv, dtype=dtype) + assert result.dtype == dtype + + @_requires_fp8 + def test_zeros_preserved(self): + """All-zero weight stays zero after round-trip.""" + original = torch.zeros(128, 256, dtype=torch.float32) + fp8, scale_inv = _quantize_to_fp8(original, block_size=128) + + recovered = dequantize_fp8_blockwise(fp8, scale_inv, dtype=torch.float32) + assert torch.allclose(recovered, original, atol=1e-6) + + +class TestMaybeDequantizeFp8Weight: + """Tests for ``maybe_dequantize_fp8_weight`` (the bridge helper).""" + + @_requires_fp8 + def test_dequantizes_when_fp8_and_scale_present(self): + """FP8 weight + matching scale_inv -> dequantized output.""" + original = torch.randn(256, 256, dtype=torch.float32) + fp8, scale_inv = _quantize_to_fp8(original, block_size=128) + + state_dict = { + "layer.weight": fp8, + "layer.weight_scale_inv": scale_inv, + } + + result = maybe_dequantize_fp8_weight("layer.weight", fp8, state_dict) + + assert result.dtype == torch.bfloat16 + assert result.shape == (256, 256) + # Should be close to original + rel_err = (result.float() - original).abs() / (original.abs() + 1e-8) + assert rel_err.mean() < 0.05 + + @_requires_fp8 + def test_passthrough_when_no_scale_inv(self): + """FP8 weight without matching scale_inv is returned as-is.""" + fp8 = torch.zeros(128, 128, dtype=torch.float8_e4m3fn) + state_dict = {"layer.weight": fp8} + + result = maybe_dequantize_fp8_weight("layer.weight", fp8, state_dict) + assert result.dtype == torch.float8_e4m3fn + + def test_passthrough_when_not_fp8(self): + """Non-FP8 weight is returned unchanged regardless of scale_inv presence.""" + bf16_weight = torch.randn(128, 128, dtype=torch.bfloat16) + state_dict = { + "layer.weight": bf16_weight, + "layer.weight_scale_inv": torch.ones(1, 1), + } + + result = maybe_dequantize_fp8_weight("layer.weight", bf16_weight, state_dict) + assert result is bf16_weight