Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 27 additions & 39 deletions examples/conversion/compare_hf_and_megatron/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@
"""

import argparse
import gc
import importlib
import os
import sys
Expand Down Expand Up @@ -319,13 +318,7 @@ def vlm_forward_step(data_iterator, model, **kwargs) -> torch.Tensor:
def loss_func(x, **kwargs):
return x

model_output = model(**forward_args)
if isinstance(model_output, tuple):
output_tensor, _ = model_output
else:
output_tensor = model_output

return output_tensor, loss_func
return model(**forward_args), loss_func
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Handle tuple model outputs before returning from vlm_forward_step.

At Line 321, returning raw model(**forward_args) can propagate a tuple output (e.g., (output_tensor, loss_mask)), while downstream code assumes a tensor and will fail on tensor ops/indexing.

Proposed fix
-    return model(**forward_args), loss_func
+    model_output = model(**forward_args)
+    if isinstance(model_output, tuple):
+        output_tensor, _ = model_output
+    else:
+        output_tensor = model_output
+    return output_tensor, loss_func
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/conversion/compare_hf_and_megatron/compare.py` at line 321,
vlm_forward_step currently returns model(**forward_args) which may be a tuple
(e.g., (output_tensor, loss_mask)) and breaks downstream tensor ops; update
vlm_forward_step to detect if model(**forward_args) is a tuple or list and
extract the primary output tensor (e.g., first element) before returning, so
return (output_tensor, loss_func) instead of the raw tuple; reference the call
site model(**forward_args) and the returned loss_func when making this change.



def load_image(image_path: str) -> Image.Image:
Expand Down Expand Up @@ -616,11 +609,8 @@ def _load_megatron_model(args):
model_provider.finalize()
megatron_model = model_provider.provide_distributed_model(wrap_with_ddp=False)

# Workaround: disable MTP for inference (causes hangs on NCCL collectives)
for m in megatron_model:
m.config.mtp_num_layers = None
m.config.grad_scale_func = None

model_components = [m.eval() for m in megatron_model]

# Register debug hooks if enabled
Expand Down Expand Up @@ -727,27 +717,29 @@ def compare_models_one_step(args) -> None:
)

del hf_model
gc.collect()
torch.cuda.empty_cache()

# Broadcast HF results to all ranks
# Broadcast HF results to all ranks after Megatron initialization
# (following the pattern from generate_from_hf.py)
if torch.distributed.is_initialized():
# Create tensors for broadcasting if they don't exist on non-rank-0
# Ensure consistent dtype across ranks: rank 0 has bfloat16 logits from the HF model,
# so all ranks must use the same dtype for NCCL broadcast to work correctly.
if hf_logits is not None:
hf_logits = hf_logits.float()

if hf_next_token is None:
hf_next_token = torch.zeros(1, device=input_ids.device, dtype=torch.long)
if hf_logits is None:
# Get vocab size from tokenizer for proper tensor size
vocab_size = getattr(
tokenizer, "vocab_size", len(tokenizer.vocab) if hasattr(tokenizer, "vocab") else 32000
)
hf_logits = torch.zeros(vocab_size, device=input_ids.device, dtype=torch.float32)

# Ensure consistent dtype across ranks before broadcast
hf_logits = hf_logits.float()

# Broadcast from rank 0 to all ranks
torch.distributed.broadcast(hf_next_token, 0)
torch.distributed.broadcast(hf_logits, 0)
torch.distributed.barrier()
print_rank_0("HF results broadcast complete.")

# Run Megatron model forward pass
print_rank_0("=== RUNNING MEGATRON MODEL (1-STEP) ===")
Expand Down Expand Up @@ -792,10 +784,7 @@ def compare_models_one_step(args) -> None:
megatron_logits = megatron_output[0, -1, :]
megatron_next_token = torch.argmax(megatron_logits, dim=-1)

if not torch.distributed.is_initialized() or (
parallel_state.get_tensor_model_parallel_rank() == 0
and parallel_state.get_expert_model_parallel_rank() == 0
):
if not torch.distributed.is_initialized() or (parallel_state.get_tensor_model_parallel_rank() == 0 and parallel_state.get_expert_model_parallel_rank() == 0):
print(f"Megatron output shape: {megatron_output.shape}")
print(f"Megatron logits stats - mean: {megatron_logits.mean():.4f}, std: {megatron_logits.std():.4f}")
print(
Expand All @@ -807,27 +796,26 @@ def compare_models_one_step(args) -> None:
top5_tokens = [tokenizer.decode([idx]) for idx in top5_ids]
print(f"Megatron Top 5: {list(zip(top5_tokens, top5_vals.tolist()))}")

# Compare outputs (only where we have valid Megatron results)
# Megatron may pad vocab_size for GPU kernel efficiency — truncate
# to the HF vocab size so logits are directly comparable.
hf_vocab_size = hf_logits.shape[0]
megatron_logits_cmp = megatron_logits[:hf_vocab_size]
megatron_next_token_cmp = torch.argmax(megatron_logits_cmp, dim=-1)

# Compare outputs
print("=== COMPARISON ===")
token_match = hf_next_token.item() == megatron_next_token.item()
token_match = hf_next_token.item() == megatron_next_token_cmp.item()
token_status_emoji = "✅" if token_match else "❌"
print(f"Token match: {token_match} {token_status_emoji}")

# Compare logits if shapes match
if hf_logits.shape == megatron_logits.shape:
diff = (hf_logits - megatron_logits).abs()
print(f"Logits diff - max: {diff.max():.6f}, mean: {diff.mean():.6f}")
cosine_sim = torch.cosine_similarity(hf_logits.unsqueeze(0), megatron_logits.unsqueeze(0))
cos_val = cosine_sim.item()
percent = cos_val * 100.0
status_emoji = "✅" if cos_val >= SIMILARITY_THRESHOLD else "❌"
tolerance_text = "within ±2%" if cos_val >= SIMILARITY_THRESHOLD else "outside ±2%"
print(
f"Cosine similarity: {cos_val:.6f} ({percent:.2f}%) {status_emoji} ({tolerance_text} tolerance)"
)
else:
print(f"Shape mismatch: HF {hf_logits.shape} vs Megatron {megatron_logits.shape}")
print("Cannot compare logits directly due to shape mismatch")
diff = (hf_logits - megatron_logits_cmp).abs()
print(f"Logits diff - max: {diff.max():.6f}, mean: {diff.mean():.6f}")
cosine_sim = torch.cosine_similarity(hf_logits.unsqueeze(0), megatron_logits_cmp.unsqueeze(0))
cos_val = cosine_sim.item()
percent = cos_val * 100.0
status_emoji = "✅" if cos_val >= SIMILARITY_THRESHOLD else "❌"
tolerance_text = "within ±2%" if cos_val >= SIMILARITY_THRESHOLD else "outside ±2%"
print(f"Cosine similarity: {cos_val:.6f} ({percent:.2f}%) {status_emoji} ({tolerance_text} tolerance)")
Comment on lines +799 to +818
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Add an explicit vocab-size compatibility guard before truncation.

At Line 802, truncation assumes Megatron logits are at least HF vocab length. If Megatron vocab is smaller, Line 811 still fails later with a shape mismatch; fail fast with a clear error.

Proposed fix
                 hf_vocab_size = hf_logits.shape[0]
+                if megatron_logits.shape[0] < hf_vocab_size:
+                    raise ValueError(
+                        "Incompatible vocab sizes: "
+                        f"Megatron logits ({megatron_logits.shape[0]}) < HF logits ({hf_vocab_size}). "
+                        "Ensure both models use the same tokenizer/vocab."
+                    )
                 megatron_logits_cmp = megatron_logits[:hf_vocab_size]
                 megatron_next_token_cmp = torch.argmax(megatron_logits_cmp, dim=-1)

Based on learnings: when a path is unsupported, raise an explicit, descriptive error instead of failing later with an implicit runtime mismatch.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/conversion/compare_hf_and_megatron/compare.py` around lines 799 -
818, Add an explicit vocab-size compatibility guard before truncating
megatron_logits: compute hf_vocab_size = hf_logits.shape[0] then check
megatron_logits.size(0) >= hf_vocab_size (using megatron_logits.size(0) or
.shape[0]) and if not raise a descriptive ValueError (e.g., "Megatron logits
vocab smaller than HF vocab: megatron_vocab=..., hf_vocab=...") to fail fast;
keep the existing truncation into megatron_logits_cmp and subsequent comparisons
(hf_next_token, megatron_next_token_cmp, diff, cosine_similarity,
SIMILARITY_THRESHOLD) unchanged when the check passes.


print("=== COMPARISON COMPLETE ===")
else:
Expand Down
Loading