Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 225 additions & 56 deletions examples/conversion/compare_hf_and_megatron/compare.py

Large diffs are not rendered by default.

207 changes: 187 additions & 20 deletions examples/conversion/create_hf_toy_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
--output-dir /tmp/qwen3_toy \
--num-hidden-layers 2 \
--num-experts 4

```

The script works by:
Expand All @@ -27,10 +28,13 @@
from __future__ import annotations

import argparse
import json
import math
from pathlib import Path
from typing import Optional

import torch
from safetensors.torch import load_file, save_file
from transformers import (
AutoConfig,
AutoModelForCausalLM,
Expand Down Expand Up @@ -85,6 +89,20 @@ def _parse_args() -> argparse.Namespace:
default=1234,
help="Torch seed applied before checkpoint creation.",
)
parser.add_argument(
"--quantize-fp8",
action="store_true",
default=False,
help="Post-process the saved checkpoint into FP8 (e4m3) block-wise "
"format with scale_inv tensors, matching the DeepSeek-V3 / Kimi-K2.5 "
"quantization convention.",
)
parser.add_argument(
"--fp8-block-size",
type=int,
default=128,
help="Block size for FP8 block-wise quantization (default: 128).",
)
parser.add_argument(
"--disable-remote-code-trust",
action="store_false",
Expand All @@ -103,39 +121,178 @@ def _adjust_config(
num_experts_per_tok: Optional[int],
moe_intermediate_size: Optional[int],
) -> None:
"""Mutate the config in-place so it matches the requested toy topology."""
"""Mutate config(s) in-place so they match requested layer/expert topology."""

config.num_hidden_layers = num_hidden_layers
def _adjust_one(cfg) -> None:
cfg.num_hidden_layers = num_hidden_layers

if hasattr(config, "max_window_layers"):
config.max_window_layers = min(config.max_window_layers, num_hidden_layers)
if hasattr(cfg, "max_window_layers"):
cfg.max_window_layers = min(cfg.max_window_layers, num_hidden_layers)

if hasattr(config, "layer_types"):
config.layer_types = config.layer_types[:num_hidden_layers]
if hasattr(cfg, "layer_types"):
cfg.layer_types = cfg.layer_types[:num_hidden_layers]

mlp_only_layers = getattr(config, "mlp_only_layers", [])
if isinstance(mlp_only_layers, (list, tuple)):
config.mlp_only_layers = [layer for layer in mlp_only_layers if layer < num_hidden_layers]
mlp_only_layers = getattr(cfg, "mlp_only_layers", [])
if isinstance(mlp_only_layers, (list, tuple)):
cfg.mlp_only_layers = [layer for layer in mlp_only_layers if layer < num_hidden_layers]

# Kimi-style configs may use n_routed_experts while many others use num_experts.
for field in ("num_experts", "n_routed_experts"):
if hasattr(cfg, field):
setattr(cfg, field, num_experts)

if hasattr(cfg, "num_experts_per_tok"):
cfg.num_experts_per_tok = (
num_experts_per_tok
if num_experts_per_tok is not None
else min(num_experts, getattr(cfg, "num_experts_per_tok", num_experts))
)

if hasattr(cfg, "router_top_k"):
cfg.router_top_k = min(num_experts, getattr(cfg, "num_experts_per_tok", num_experts))

if moe_intermediate_size is not None and hasattr(cfg, "moe_intermediate_size"):
cfg.moe_intermediate_size = moe_intermediate_size

_adjust_one(config)
text_config = getattr(config, "text_config", None)
if text_config is not None:
_adjust_one(text_config)

# Always strip quantization_config during model creation so
# from_config instantiates plain bf16 weights. If --quantize-fp8 is
# requested the checkpoint is post-processed later.
for cfg in (config, text_config):
if cfg is not None and hasattr(cfg, "quantization_config"):
del cfg.quantization_config


# FP8 e4m3 representable range
_FP8_E4M3_MAX = 448.0


def _rebuild_safetensors_index(output_dir: Path, st_files: list[Path]) -> None:
    """Regenerate model.safetensors.index.json from the current safetensors files.

    The index maps every tensor key to the shard file that contains it and
    records the total checkpoint size; it is a no-op when the checkpoint was
    saved unsharded (no index file present).
    """
    index_path = output_dir / "model.safetensors.index.json"
    if not index_path.exists():
        return

    weight_map: dict[str, str] = {}
    shard_bytes: dict[str, int] = {}
    for shard in st_files:
        payload = load_file(str(shard), device="cpu")
        weight_map.update({key: shard.name for key in payload})
        shard_bytes[shard.name] = sum(
            t.nelement() * t.element_size() for t in payload.values()
        )

    index = {
        "metadata": {"total_size": sum(shard_bytes.values())},
        "weight_map": weight_map,
    }
    index_path.write_text(json.dumps(index, indent=2) + "\n")
    print(f" rebuilt {index_path.name} with {len(weight_map)} keys")

def _quantize_checkpoint_fp8(output_dir: Path, block_size: int = 128) -> None:
    """Convert saved bf16 safetensors in *output_dir* to FP8 block-wise format.

    For every 2-D floating-point weight tensor whose both dimensions are
    >= *block_size*, produce:
      - ``{name}`` in ``torch.float8_e4m3fn``
      - ``{name}_scale_inv`` with per-block dequantization scales (float32)

    Then inject a ``quantization_config`` into ``config.json``.
    """
    st_files = sorted(output_dir.glob("*.safetensors"))
    if not st_files:
        print(" WARNING: no safetensors found; skipping FP8 quantization")
        return

    for st_path in st_files:
        tensors = load_file(str(st_path))
        new_tensors: dict[str, torch.Tensor] = {}
        quantized_count = 0

        for name, tensor in tensors.items():
            # Only quantize large floating-point matrices. Small tensors
            # (norms, biases) and any integer tensors keep their native dtype.
            if (
                tensor.ndim == 2
                and tensor.is_floating_point()
                and tensor.shape[0] >= block_size
                and tensor.shape[1] >= block_size
            ):
                fp8_weight, scale_inv = _quantize_tensor_fp8(tensor.float(), block_size)
                new_tensors[name] = fp8_weight
                new_tensors[name + "_scale_inv"] = scale_inv
                quantized_count += 1
            else:
                new_tensors[name] = tensor

        save_file(new_tensors, str(st_path))
        print(f" {st_path.name}: quantized {quantized_count} tensors to FP8")

    # Rebuild the safetensors index so that _scale_inv keys are discoverable
    # by lazy-loading state dict implementations (e.g. Megatron-Bridge).
    _rebuild_safetensors_index(output_dir, st_files)

    # Advertise the quantization scheme so downstream loaders dequantize
    # correctly (DeepSeek-V3 / Kimi-K2.5 convention).
    config_path = output_dir / "config.json"
    if config_path.exists():
        cfg = json.loads(config_path.read_text())
        quant_cfg = {
            "quant_method": "fp8",
            "fmt": "e4m3",
            "weight_block_size": [block_size, block_size],
            "activation_scheme": "dynamic",
        }
        cfg["quantization_config"] = quant_cfg
        # VL configs nest the language model under text_config; mirror it there.
        if "text_config" in cfg:
            cfg["text_config"]["quantization_config"] = quant_cfg
        config_path.write_text(json.dumps(cfg, indent=2) + "\n")
        print(" injected quantization_config into config.json")


def _quantize_tensor_fp8(
tensor: torch.Tensor, block_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize a single 2-D tensor to FP8 e4m3 with per-block scales.

Returns ``(fp8_weight, scale_inv)``."""
M, N = tensor.shape
num_blocks_m = math.ceil(M / block_size)
num_blocks_n = math.ceil(N / block_size)
padded_M = num_blocks_m * block_size
padded_N = num_blocks_n * block_size

if M != padded_M or N != padded_N:
padded = torch.zeros(padded_M, padded_N, dtype=tensor.dtype, device=tensor.device)
padded[:M, :N] = tensor
else:
padded = tensor

blocks = padded.reshape(num_blocks_m, block_size, num_blocks_n, block_size)
abs_max = blocks.abs().amax(dim=(1, 3)) # [num_blocks_m, num_blocks_n]
scale_inv = (abs_max / _FP8_E4M3_MAX).clamp(min=1e-12).to(torch.float32)

scaled = blocks / scale_inv[:, None, :, None]
scaled = scaled.clamp(-_FP8_E4M3_MAX, _FP8_E4M3_MAX)
scaled = scaled.reshape(padded_M, padded_N)

if M != padded_M or N != padded_N:
scaled = scaled[:M, :N].contiguous()

fp8_weight = scaled.to(torch.float8_e4m3fn)
return fp8_weight, scale_inv


def _save_tokenizer(output_dir: Path, tokenizer_id: str, *, trust_remote_code: bool) -> None:
    """Fetch the tokenizer for *tokenizer_id* and persist it into *output_dir*."""
    AutoTokenizer.from_pretrained(
        tokenizer_id, trust_remote_code=trust_remote_code
    ).save_pretrained(output_dir)


def _save_processor(output_dir: Path, model_id: str, *, trust_remote_code: bool) -> None:
    """Save the AutoProcessor alongside the model so VL toy models can process images."""
    # Best-effort: text-only models have no processor, and some remote-code
    # repos fail to expose one — in either case just report and move on.
    try:
        from transformers import AutoProcessor

        proc = AutoProcessor.from_pretrained(model_id, trust_remote_code=trust_remote_code)
        proc.save_pretrained(output_dir)
        print(f" Processor ({type(proc).__name__}) saved to {output_dir}")
    except Exception as exc:
        print(f" Processor not available for {model_id} ({exc}); skipping.")


def main() -> None:
"""Main entry point."""
args = _parse_args()
Expand Down Expand Up @@ -166,12 +323,22 @@ def main() -> None:
model = model.bfloat16()
model.save_pretrained(output_dir, safe_serialization=True)

if args.quantize_fp8:
print("Quantizing checkpoint to FP8 (e4m3) block-wise format...")
_quantize_checkpoint_fp8(output_dir, block_size=args.fp8_block_size)

_save_tokenizer(output_dir, tokenizer_id, trust_remote_code=trust_remote_code)

# For VL models, save the processor so image inputs work with the toy model.
if getattr(config, "vision_config", None) is not None:
_save_processor(output_dir, args.hf_model_id, trust_remote_code=trust_remote_code)

print(f"Toy HuggingFace checkpoint saved to: {output_dir}")
print(f" hidden_layers={args.num_hidden_layers}")
print(f" num_experts={args.num_experts}")
print(f" num_experts_per_tok={config.num_experts_per_tok}")
effective_cfg = getattr(config, "text_config", config)
print(f" num_experts_per_tok={getattr(effective_cfg, 'num_experts_per_tok', 'N/A')}")
print(f" quantize_fp8={args.quantize_fp8}")
print(f" tokenizer_source={tokenizer_id}")


Expand Down
65 changes: 37 additions & 28 deletions examples/conversion/hf_to_megatron_generate_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,35 +147,44 @@ def process_image_inputs(processor, image_path: Optional[str], prompt: str):
Tuple of (input_ids, pixel_values, image_grid_thw, image_sizes, messages)
"""
if image_path:
# Create messages with image and text
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image_path},
{"type": "text", "text": prompt},
],
}
]

# Process vision info
image_inputs, video_inputs = process_vision_info(messages)

# Apply chat template
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Process inputs
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
is_kimi = type(processor).__name__ == "KimiK25Processor"

if is_kimi:
messages = [
{
"role": "user",
"content": [
{"type": "image", "image_url": image_path},
{"type": "text", "text": prompt},
],
}
]
inputs = processor(messages=messages)
else:
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": image_path},
{"type": "text", "text": prompt},
],
}
]

image_inputs, video_inputs = process_vision_info(messages)
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)

return (
inputs.input_ids,
inputs.pixel_values,
getattr(inputs, "image_grid_thw", None),
getattr(inputs, "image_grid_thw", None) or getattr(inputs, "grid_thws", None),
getattr(inputs, "image_sizes", None),
messages,
)
Expand Down Expand Up @@ -209,7 +218,7 @@ def main(args) -> None:

# We still need HF config for tokenizer, but we'll load the model from Megatron checkpoint
# Create bridge from HF config only (no weights)
bridge = AutoBridge.from_hf_pretrained(args.hf_model_path)
bridge = AutoBridge.from_hf_pretrained(args.hf_model_path, trust_remote_code=args.trust_remote_code)

# Initialize model parallel before loading
model_provider = bridge.to_megatron_provider(load_weights=False)
Expand All @@ -236,7 +245,7 @@ def main(args) -> None:
else:
# Load from HuggingFace and convert to Megatron
print_rank_0(f"Loading HuggingFace model from: {args.hf_model_path}")
bridge = AutoBridge.from_hf_pretrained(args.hf_model_path)
bridge = AutoBridge.from_hf_pretrained(args.hf_model_path, trust_remote_code=args.trust_remote_code)
model_provider = bridge.to_megatron_provider(load_weights=True)
model_provider.tensor_model_parallel_size = tp
model_provider.pipeline_model_parallel_size = pp
Expand Down
12 changes: 12 additions & 0 deletions src/megatron/bridge/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@
GPTOSSProvider120B,
)
from megatron.bridge.models.gpt_provider import GPTModelProvider
from megatron.bridge.models.kimi import (
KimiK2Provider,
)
from megatron.bridge.models.kimi_vl import (
KimiK25VLBridge,
KimiK25VLModelProvider,
KimiK25VLModel,
)
from megatron.bridge.models.llama import (
CodeLlamaModelProvider7B,
CodeLlamaModelProvider13B,
Expand Down Expand Up @@ -229,6 +237,10 @@
"GPTOSSProvider20B",
"GPTOSSProvider120B",
"T5ModelProvider",
"KimiK2Provider",
"KimiK25VLModel",
"KimiK25VLBridge",
"KimiK25VLModelProvider",
"LlamaModelProvider",
"Llama2ModelProvider7B",
"Llama2ModelProvider13B",
Expand Down
5 changes: 4 additions & 1 deletion src/megatron/bridge/models/conversion/auto_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,9 +732,12 @@ def import_ckpt(
megatron_model = bridge.to_megatron_model(wrap_with_ddp=False, use_cpu_initialization=True)

# Save as Megatron checkpoint
hf_tokenizer_kwargs = None
hf_tokenizer_kwargs = {}
if hasattr(bridge._model_bridge, "get_hf_tokenizer_kwargs"):
hf_tokenizer_kwargs = bridge._model_bridge.get_hf_tokenizer_kwargs()
# Pass trust_remote_code to tokenizer if provided in kwargs
if kwargs.get("trust_remote_code"):
hf_tokenizer_kwargs["trust_remote_code"] = True
bridge.save_megatron_model(
megatron_model,
megatron_path,
Expand Down
Loading
Loading