Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 26 additions & 21 deletions examples/conversion/compare_hf_and_megatron/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,8 @@ def _load_megatron_model(args):
model_provider.finalize()
megatron_model = model_provider.provide_distributed_model(wrap_with_ddp=False)

for m in megatron_model:
m.config.mtp_num_layers = None
model_components = [m.eval() for m in megatron_model]

# Register debug hooks if enabled
Expand Down Expand Up @@ -715,17 +717,19 @@ def compare_models_one_step(args) -> None:
)

del hf_model
# Reload Megatron model to ensure a fresh instance before comparison
megatron_model, _ = _load_megatron_model(args)
torch.cuda.empty_cache()

# Broadcast HF results to all ranks after Megatron initialization
# (following the pattern from generate_from_hf.py)
if torch.distributed.is_initialized():
# Create tensors for broadcasting if they don't exist on non-rank-0
# Ensure consistent dtype across ranks: rank 0 has bfloat16 logits from the HF model,
# so all ranks must use the same dtype for NCCL broadcast to work correctly.
if hf_logits is not None:
hf_logits = hf_logits.float()

if hf_next_token is None:
hf_next_token = torch.zeros(1, device=input_ids.device, dtype=torch.long)
if hf_logits is None:
# Get vocab size from tokenizer for proper tensor size
vocab_size = getattr(
tokenizer, "vocab_size", len(tokenizer.vocab) if hasattr(tokenizer, "vocab") else 32000
)
Expand All @@ -734,6 +738,8 @@ def compare_models_one_step(args) -> None:
# Broadcast from rank 0 to all ranks
torch.distributed.broadcast(hf_next_token, 0)
torch.distributed.broadcast(hf_logits, 0)
torch.distributed.barrier()
print_rank_0("HF results broadcast complete.")

# Run Megatron model forward pass
print_rank_0("=== RUNNING MEGATRON MODEL (1-STEP) ===")
Expand Down Expand Up @@ -790,27 +796,26 @@ def compare_models_one_step(args) -> None:
top5_tokens = [tokenizer.decode([idx]) for idx in top5_ids]
print(f"Megatron Top 5: {list(zip(top5_tokens, top5_vals.tolist()))}")

# Compare outputs (only where we have valid Megatron results)
# Megatron may pad vocab_size for GPU kernel efficiency — truncate
# to the HF vocab size so logits are directly comparable.
hf_vocab_size = hf_logits.shape[0]
megatron_logits_cmp = megatron_logits[:hf_vocab_size]
megatron_next_token_cmp = torch.argmax(megatron_logits_cmp, dim=-1)

# Compare outputs
print("=== COMPARISON ===")
token_match = hf_next_token.item() == megatron_next_token.item()
token_match = hf_next_token.item() == megatron_next_token_cmp.item()
token_status_emoji = "✅" if token_match else "❌"
print(f"Token match: {token_match} {token_status_emoji}")

# Compare logits if shapes match
if hf_logits.shape == megatron_logits.shape:
diff = (hf_logits - megatron_logits).abs()
print(f"Logits diff - max: {diff.max():.6f}, mean: {diff.mean():.6f}")
cosine_sim = torch.cosine_similarity(hf_logits.unsqueeze(0), megatron_logits.unsqueeze(0))
cos_val = cosine_sim.item()
percent = cos_val * 100.0
status_emoji = "✅" if cos_val >= SIMILARITY_THRESHOLD else "❌"
tolerance_text = "within ±2%" if cos_val >= SIMILARITY_THRESHOLD else "outside ±2%"
print(
f"Cosine similarity: {cos_val:.6f} ({percent:.2f}%) {status_emoji} ({tolerance_text} tolerance)"
)
else:
print(f"Shape mismatch: HF {hf_logits.shape} vs Megatron {megatron_logits.shape}")
print("Cannot compare logits directly due to shape mismatch")
diff = (hf_logits - megatron_logits_cmp).abs()
print(f"Logits diff - max: {diff.max():.6f}, mean: {diff.mean():.6f}")
cosine_sim = torch.cosine_similarity(hf_logits.unsqueeze(0), megatron_logits_cmp.unsqueeze(0))
cos_val = cosine_sim.item()
percent = cos_val * 100.0
status_emoji = "✅" if cos_val >= SIMILARITY_THRESHOLD else "❌"
tolerance_text = "within ±2%" if cos_val >= SIMILARITY_THRESHOLD else "outside ±2%"
print(f"Cosine similarity: {cos_val:.6f} ({percent:.2f}%) {status_emoji} ({tolerance_text} tolerance)")

print("=== COMPARISON COMPLETE ===")
else:
Expand Down
10 changes: 4 additions & 6 deletions examples/conversion/hf_megatron_roundtrip_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,12 @@ def main(
original_param = bridge.hf_pretrained.state[name]
compare_param = param
compare_original = original_param
# Cast to float32 for params with known dtype mismatches between Megatron and HF
# (e.g. Megatron keeps expert_bias in float32 while HF may use bfloat16)
if any(p in name for p in IGNORE_PRECISION_PARAMS):
# Cast to float32 when dtypes differ (e.g. fp8 HF weights vs bf16 Megatron,
# or Megatron keeping expert_bias in float32 while HF uses bfloat16)
if compare_param.dtype != compare_original.dtype or any(p in name for p in IGNORE_PRECISION_PARAMS):
compare_param = param.float()
compare_original = original_param.float()
match = torch.allclose(
compare_param, compare_original.to(compare_param.device), atol=1e-1
) # Increased tolerance for bfloat16
match = torch.allclose(compare_param, compare_original.to(compare_param.device), atol=1e-1)
all_match = all_match and match
table.add_row(
name,
Expand Down
44 changes: 44 additions & 0 deletions examples/models/minimax_m2/conversion.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/usr/bin/env bash
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -xeuo pipefail

# MiniMax-M2 (MoE: 256 experts, top-8, ~230GB fp8)
#
# Single-node (8 GPUs): use this script with TP*EP*PP <= 8.
# Multi-node (TP*EP*PP > 8): use slurm_conversion.sh instead.

WORKSPACE=${WORKSPACE:-/workspace}
MODEL_NAME=MiniMax-M2
HF_MODEL_ID="MiniMaxAI/${MODEL_NAME}"

# Step 1: Multi-GPU round-trip on a single 8-GPU node (TP=2, EP=4 -> 8 GPUs).
# All expansions are quoted (ShellCheck SC2086) so a WORKSPACE containing
# spaces cannot word-split; argv is unchanged for the default values.
uv run python -m torch.distributed.run --nproc_per_node=8 \
    examples/conversion/hf_megatron_roundtrip_multi_gpu.py \
    --hf-model-id "${HF_MODEL_ID}" \
    --tp 2 --ep 4 \
    --trust-remote-code

# Step 2: Import HF -> Megatron checkpoint (single-process).
uv run python examples/conversion/convert_checkpoints.py import \
    --hf-model "${HF_MODEL_ID}" \
    --megatron-path "${WORKSPACE}/models/${MODEL_NAME}" \
    --trust-remote-code

# Step 3: Export Megatron -> HF checkpoint (single-process).
# Reads the iter_0000000 checkpoint produced by the import step above.
uv run python examples/conversion/convert_checkpoints.py export \
    --hf-model "${HF_MODEL_ID}" \
    --megatron-path "${WORKSPACE}/models/${MODEL_NAME}/iter_0000000" \
    --hf-path "${WORKSPACE}/models/${MODEL_NAME}-hf-export"
27 changes: 27 additions & 0 deletions examples/models/minimax_m2/inference.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/usr/bin/env bash
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# MiniMax-M2 (MoE: 256 experts, top-8, ~230GB fp8)
#
# Single-node (8 GPUs): use this script with TP*EP*PP <= 8.
# Multi-node (TP*EP*PP > 8): use slurm_inference.sh instead.

# Generation settings for a single-node run (TP=2 x EP=4 = 8 ranks).
HF_MODEL_PATH="MiniMaxAI/MiniMax-M2"
PROMPT="What is artificial intelligence?"
MAX_NEW_TOKENS=100
NPROC_PER_NODE=8

# Launch one rank per GPU; the generate script shards the model TP=2, EP=4.
uv run python -m torch.distributed.run --nproc_per_node="${NPROC_PER_NODE}" \
    examples/conversion/hf_to_megatron_generate_text.py \
    --hf_model_path "${HF_MODEL_PATH}" \
    --prompt "${PROMPT}" \
    --max_new_tokens "${MAX_NEW_TOKENS}" \
    --tp 2 --ep 4 \
    --trust-remote-code
105 changes: 105 additions & 0 deletions examples/models/minimax_m2/slurm_conversion.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#!/bin/bash
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ==============================================================================
# MiniMax-M2 Checkpoint Conversion (Multi-Node via Slurm)
#
# MiniMax-M2 (MoE: 256 experts, top-8, ~230GB fp8)
# Use this script when TP * EP * PP > 8 (requires more than one 8-GPU node).
# For single-node (TP * EP * PP <= 8), use conversion.sh instead.
#
# Usage:
# 1. Modify the #SBATCH directives and CONFIGURATION section for your cluster
# 2. Submit: sbatch slurm_conversion.sh
# 3. Submit inference after conversion:
# sbatch --dependency=afterok:<job_id> slurm_inference.sh
# ==============================================================================

#SBATCH --job-name=minimax-m2-convert
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-node=8
#SBATCH --time=4:00:00
#SBATCH --account=<your-account>
#SBATCH --output=logs/minimax_m2_convert_%j.out
#SBATCH --error=logs/minimax_m2_convert_%j.err
#SBATCH --exclusive

# NOTE(review): Slurm opens the --output/--error files before this script
# executes, so the "mkdir -p logs" further down runs too late to help the
# job's own log files — ensure logs/ exists before submitting. TODO confirm
# against your Slurm version.

# ==============================================================================
# CONFIGURATION — edit these for your environment
# ==============================================================================

WORKSPACE=${WORKSPACE:-/workspace}
PROJECT_DIR=${PROJECT_DIR:-.}
MODEL_NAME=MiniMax-M2
HF_MODEL_ID=MiniMaxAI/$MODEL_NAME
GPUS_PER_NODE=8

# NOTE(review): TP/EP/PP are only echoed in the banner below; they are not
# passed to the convert_checkpoints.py import command. Verify whether the
# import is meant to honor this parallelism layout.
TP=2
EP=8
PP=1

CONTAINER_IMAGE=${CONTAINER_IMAGE:?Set CONTAINER_IMAGE to your container path}
CONTAINER_MOUNTS="/lustre:/lustre,${PROJECT_DIR}:/opt/Megatron-Bridge"
CONTAINER_WORKDIR=/opt/Megatron-Bridge

# ==============================================================================
# Environment Setup
# ==============================================================================

export TORCH_NCCL_AVOID_RECORD_STREAMS=1
export NCCL_NVLS_ENABLE=0

# ==============================================================================
# Job Execution
# ==============================================================================

# First hostname in the allocation. Used only in the informational banner
# here — the single-process import below does not need a rendezvous endpoint.
MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
MASTER_PORT=${MASTER_PORT:-29500}

echo "======================================"
echo "MiniMax-M2 Checkpoint Conversion"
echo "======================================"
echo "Job ID: $SLURM_JOB_ID"
echo "Nodes: $SLURM_JOB_NUM_NODES"
echo "GPUs per node: $GPUS_PER_NODE"
echo "Parallelism: TP=$TP, EP=$EP, PP=$PP"
echo "Total GPUs: $((TP * EP * PP))"
echo "Master: $MASTER_ADDR:$MASTER_PORT"
echo "======================================"

mkdir -p logs

# One containerized task per node; expansions happen now, at assignment time.
SRUN_CMD="srun --ntasks-per-node=1 --no-container-mount-home \
    --container-image=$CONTAINER_IMAGE \
    --container-mounts=$CONTAINER_MOUNTS"

# NOTE(review): with --ntasks-per-node=1, SLURM_LOCALID is 0 in every task,
# so each node runs "uv sync" concurrently — confirm this is safe when the
# project directory lives on a shared filesystem.
echo ""
echo "Importing HF -> Megatron checkpoint ..."
$SRUN_CMD bash -c "cd $CONTAINER_WORKDIR && \
    if [ \$SLURM_LOCALID -eq 0 ]; then uv sync; else sleep 10; fi && \
    uv run --no-sync python examples/conversion/convert_checkpoints.py import \
    --hf-model $HF_MODEL_ID \
    --megatron-path ${WORKSPACE}/models/$MODEL_NAME \
    --trust-remote-code"
# No "set -e" in this script: the exit status is checked explicitly.
IMPORT_EXIT=$?
if [ $IMPORT_EXIT -ne 0 ]; then
    echo "ERROR: Import failed (exit $IMPORT_EXIT)"
    exit $IMPORT_EXIT
fi

echo "======================================"
echo "Conversion completed successfully"
echo "======================================"
113 changes: 113 additions & 0 deletions examples/models/minimax_m2/slurm_inference.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#!/bin/bash
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ==============================================================================
# MiniMax-M2 Inference (Multi-Node via Slurm)
#
# MiniMax-M2 (MoE: 256 experts, top-8, ~230GB fp8)
# Use this script when TP * EP * PP > 8 (requires more than one 8-GPU node).
# For single-node (TP * EP * PP <= 8), use inference.sh instead.
#
# Usage:
# 1. Modify the #SBATCH directives and CONFIGURATION section for your cluster
# 2. Run conversion first: sbatch slurm_conversion.sh
# 3. Submit with dependency: sbatch --dependency=afterok:<convert_job_id> slurm_inference.sh
# ==============================================================================

#SBATCH --job-name=minimax-m2-inference
#SBATCH --nodes=8
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-node=8
#SBATCH --time=4:00:00
#SBATCH --account=<your-account>
#SBATCH --output=logs/minimax_m2_inference_%j.out
#SBATCH --error=logs/minimax_m2_inference_%j.err
#SBATCH --exclusive

# NOTE(review): Slurm opens the --output/--error files before this script
# executes, so the "mkdir -p logs" further down runs too late to help the
# job's own log files — ensure logs/ exists before submitting. TODO confirm
# against your Slurm version.

# ==============================================================================
# CONFIGURATION — edit these for your environment
# ==============================================================================

WORKSPACE=${WORKSPACE:-/workspace}
PROJECT_DIR=${PROJECT_DIR:-.}
MODEL_NAME=MiniMax-M2
HF_MODEL_ID=MiniMaxAI/$MODEL_NAME
# Checkpoint produced by slurm_conversion.sh (convert_checkpoints.py import).
MEGATRON_CKPT=${WORKSPACE}/models/${MODEL_NAME}/iter_0000000
GPUS_PER_NODE=8
PROMPT="What is artificial intelligence?"
MAX_NEW_TOKENS=100

# MiniMax-M2 needs EP=32 (8 nodes) to fit 256 experts in memory.
# Increasing TP does NOT reduce expert memory — increase EP instead.
TP=2
EP=32
PP=1

CONTAINER_IMAGE=${CONTAINER_IMAGE:?Set CONTAINER_IMAGE to your container path}
CONTAINER_MOUNTS="/lustre:/lustre,${PROJECT_DIR}:/opt/Megatron-Bridge"
CONTAINER_WORKDIR=/opt/Megatron-Bridge

# ==============================================================================
# Environment Setup
# ==============================================================================

export TORCH_NCCL_AVOID_RECORD_STREAMS=1
export NCCL_NVLS_ENABLE=0

# ==============================================================================
# Job Execution
# ==============================================================================

# First hostname in the allocation becomes the torchrun rendezvous master.
MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
MASTER_PORT=${MASTER_PORT:-29500}

echo "======================================"
echo "MiniMax-M2 Inference"
echo "======================================"
echo "Job ID: $SLURM_JOB_ID"
echo "Nodes: $SLURM_JOB_NUM_NODES"
echo "GPUs per node: $GPUS_PER_NODE"
echo "Parallelism: TP=$TP, EP=$EP, PP=$PP"
echo "Total GPUs: $((TP * EP * PP))"
echo "Master: $MASTER_ADDR:$MASTER_PORT"
echo "======================================"

mkdir -p logs

# One containerized task per node; expansions happen now, at assignment time.
SRUN_CMD="srun --ntasks-per-node=1 --no-container-mount-home \
    --container-image=$CONTAINER_IMAGE \
    --container-mounts=$CONTAINER_MOUNTS"

# Escaped \$ variables are resolved per-task inside the container; unescaped
# ones are expanded here at submit/run time on the batch host.
# NOTE(review): with --ntasks-per-node=1, SLURM_LOCALID is 0 in every task,
# so each node runs "uv sync" concurrently — confirm this is safe when the
# project directory lives on a shared filesystem.
# NOTE(review): PROMPT is expanded by the outer double quotes and wrapped in
# inner single quotes; a prompt containing a single quote would break the
# bash -c quoting.
echo ""
echo "Running inference ..."
$SRUN_CMD bash -c "cd $CONTAINER_WORKDIR && \
    if [ \$SLURM_LOCALID -eq 0 ]; then uv sync; else sleep 10; fi && \
    uv run --no-sync python -m torch.distributed.run \
    --nnodes=\$SLURM_JOB_NUM_NODES \
    --nproc_per_node=$GPUS_PER_NODE \
    --node_rank=\$SLURM_NODEID \
    --master_addr=$MASTER_ADDR \
    --master_port=$MASTER_PORT \
    examples/conversion/hf_to_megatron_generate_text.py \
    --hf_model_path $HF_MODEL_ID \
    --megatron_model_path $MEGATRON_CKPT \
    --prompt '$PROMPT' \
    --max_new_tokens $MAX_NEW_TOKENS \
    --tp $TP --ep $EP \
    --trust-remote-code"

echo "======================================"
echo "Inference completed"
echo "======================================"
Loading