Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 27 additions & 19 deletions examples/conversion/compare_hf_and_megatron/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,8 @@ def _load_megatron_model(args):
model_provider.finalize()
megatron_model = model_provider.provide_distributed_model(wrap_with_ddp=False)

for m in megatron_model:
m.config.mtp_num_layers = None
model_components = [m.eval() for m in megatron_model]

# Register debug hooks if enabled
Expand Down Expand Up @@ -715,17 +717,22 @@ def compare_models_one_step(args) -> None:
)

del hf_model
torch.cuda.empty_cache()

# Reload Megatron model to ensure a fresh instance before comparison
megatron_model, _ = _load_megatron_model(args)

# Broadcast HF results to all ranks after Megatron initialization
# (following the pattern from generate_from_hf.py)
if torch.distributed.is_initialized():
# Create tensors for broadcasting if they don't exist on non-rank-0
# Ensure consistent dtype across ranks: rank 0 has bfloat16 logits from the HF model,
# so all ranks must use the same dtype for NCCL broadcast to work correctly.
if hf_logits is not None:
hf_logits = hf_logits.float()

if hf_next_token is None:
hf_next_token = torch.zeros(1, device=input_ids.device, dtype=torch.long)
if hf_logits is None:
# Get vocab size from tokenizer for proper tensor size
vocab_size = getattr(
tokenizer, "vocab_size", len(tokenizer.vocab) if hasattr(tokenizer, "vocab") else 32000
)
Expand All @@ -734,6 +741,8 @@ def compare_models_one_step(args) -> None:
# Broadcast from rank 0 to all ranks
torch.distributed.broadcast(hf_next_token, 0)
torch.distributed.broadcast(hf_logits, 0)
torch.distributed.barrier()
print_rank_0("HF results broadcast complete.")

# Run Megatron model forward pass
print_rank_0("=== RUNNING MEGATRON MODEL (1-STEP) ===")
Expand Down Expand Up @@ -790,27 +799,26 @@ def compare_models_one_step(args) -> None:
top5_tokens = [tokenizer.decode([idx]) for idx in top5_ids]
print(f"Megatron Top 5: {list(zip(top5_tokens, top5_vals.tolist()))}")

# Compare outputs (only where we have valid Megatron results)
# Megatron may pad vocab_size for GPU kernel efficiency — truncate
# to the HF vocab size so logits are directly comparable.
hf_vocab_size = hf_logits.shape[0]
megatron_logits_cmp = megatron_logits[:hf_vocab_size]
megatron_next_token_cmp = torch.argmax(megatron_logits_cmp, dim=-1)

# Compare outputs
print("=== COMPARISON ===")
token_match = hf_next_token.item() == megatron_next_token.item()
token_match = hf_next_token.item() == megatron_next_token_cmp.item()
token_status_emoji = "✅" if token_match else "❌"
print(f"Token match: {token_match} {token_status_emoji}")

# Compare logits if shapes match
if hf_logits.shape == megatron_logits.shape:
diff = (hf_logits - megatron_logits).abs()
print(f"Logits diff - max: {diff.max():.6f}, mean: {diff.mean():.6f}")
cosine_sim = torch.cosine_similarity(hf_logits.unsqueeze(0), megatron_logits.unsqueeze(0))
cos_val = cosine_sim.item()
percent = cos_val * 100.0
status_emoji = "✅" if cos_val >= SIMILARITY_THRESHOLD else "❌"
tolerance_text = "within ±2%" if cos_val >= SIMILARITY_THRESHOLD else "outside ±2%"
print(
f"Cosine similarity: {cos_val:.6f} ({percent:.2f}%) {status_emoji} ({tolerance_text} tolerance)"
)
else:
print(f"Shape mismatch: HF {hf_logits.shape} vs Megatron {megatron_logits.shape}")
print("Cannot compare logits directly due to shape mismatch")
diff = (hf_logits - megatron_logits_cmp).abs()
print(f"Logits diff - max: {diff.max():.6f}, mean: {diff.mean():.6f}")
cosine_sim = torch.cosine_similarity(hf_logits.unsqueeze(0), megatron_logits_cmp.unsqueeze(0))
cos_val = cosine_sim.item()
percent = cos_val * 100.0
status_emoji = "✅" if cos_val >= SIMILARITY_THRESHOLD else "❌"
tolerance_text = "within ±2%" if cos_val >= SIMILARITY_THRESHOLD else "outside ±2%"
print(f"Cosine similarity: {cos_val:.6f} ({percent:.2f}%) {status_emoji} ({tolerance_text} tolerance)")

print("=== COMPARISON COMPLETE ===")
else:
Expand Down
Loading