From 35ca667e8d6023eb21ddcce754661caa605dd14c Mon Sep 17 00:00:00 2001 From: yaoyu-33 Date: Wed, 4 Mar 2026 10:06:01 -0700 Subject: [PATCH] [misc] fix: Improve compare.py robustness for multi-GPU and vocab-padded models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix NCCL broadcast dtype mismatch by converting logits to float32 before fallback tensor creation - Handle Megatron vocab-size padding by truncating logits to HF vocab size before comparison, removing shape-mismatch branch - Simplify vlm_forward_step return (let caller handle tuple unpacking) - Remove unused gc import and grad_scale_func workaround - Add barrier after HF broadcast for synchronization safety - Simplify rank-0 guard to only check TP rank (EP rank check unnecessary) Verified on Qwen/Qwen3-0.6B: token match ✅, cosine similarity 99.99% Signed-off-by: yaoyu-33 Made-with: Cursor --- .../compare_hf_and_megatron/compare.py | 46 +++++++++++-------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/examples/conversion/compare_hf_and_megatron/compare.py b/examples/conversion/compare_hf_and_megatron/compare.py index 60ce377cd1..7657ec1af8 100644 --- a/examples/conversion/compare_hf_and_megatron/compare.py +++ b/examples/conversion/compare_hf_and_megatron/compare.py @@ -609,6 +609,8 @@ def _load_megatron_model(args): model_provider.finalize() megatron_model = model_provider.provide_distributed_model(wrap_with_ddp=False) + for m in megatron_model: + m.config.mtp_num_layers = None model_components = [m.eval() for m in megatron_model] # Register debug hooks if enabled @@ -715,17 +717,22 @@ def compare_models_one_step(args) -> None: ) del hf_model + torch.cuda.empty_cache() + # Reload Megatron model to ensure a fresh instance before comparison megatron_model, _ = _load_megatron_model(args) # Broadcast HF results to all ranks after Megatron initialization # (following the pattern from generate_from_hf.py) if torch.distributed.is_initialized(): - # Create tensors for broadcasting if they don't exist on non-rank-0 + # Ensure consistent dtype across ranks: rank 0 has bfloat16 logits from the HF model, + # so all ranks must use the same dtype for NCCL broadcast to work correctly. + if hf_logits is not None: + hf_logits = hf_logits.float() + if hf_next_token is None: hf_next_token = torch.zeros(1, device=input_ids.device, dtype=torch.long) if hf_logits is None: - # Get vocab size from tokenizer for proper tensor size vocab_size = getattr( tokenizer, "vocab_size", len(tokenizer.vocab) if hasattr(tokenizer, "vocab") else 32000 ) @@ -734,6 +741,8 @@ def compare_models_one_step(args) -> None: # Broadcast from rank 0 to all ranks torch.distributed.broadcast(hf_next_token, 0) torch.distributed.broadcast(hf_logits, 0) + torch.distributed.barrier() + print_rank_0("HF results broadcast complete.") # Run Megatron model forward pass print_rank_0("=== RUNNING MEGATRON MODEL (1-STEP) ===") @@ -790,27 +799,26 @@ def compare_models_one_step(args) -> None: top5_tokens = [tokenizer.decode([idx]) for idx in top5_ids] print(f"Megatron Top 5: {list(zip(top5_tokens, top5_vals.tolist()))}") - # Compare outputs (only where we have valid Megatron results) + # Megatron may pad vocab_size for GPU kernel efficiency — truncate + # to the HF vocab size so logits are directly comparable. + hf_vocab_size = hf_logits.shape[0] + megatron_logits_cmp = megatron_logits[:hf_vocab_size] + megatron_next_token_cmp = torch.argmax(megatron_logits_cmp, dim=-1) + + # Compare outputs print("=== COMPARISON ===") - token_match = hf_next_token.item() == megatron_next_token.item() + token_match = hf_next_token.item() == megatron_next_token_cmp.item() token_status_emoji = "✅" if token_match else "❌" print(f"Token match: {token_match} {token_status_emoji}") - # Compare logits if shapes match - if hf_logits.shape == megatron_logits.shape: - diff = (hf_logits - megatron_logits).abs() - print(f"Logits diff - max: {diff.max():.6f}, mean: {diff.mean():.6f}") - cosine_sim = torch.cosine_similarity(hf_logits.unsqueeze(0), megatron_logits.unsqueeze(0)) - cos_val = cosine_sim.item() - percent = cos_val * 100.0 - status_emoji = "✅" if cos_val >= SIMILARITY_THRESHOLD else "❌" - tolerance_text = "within ±2%" if cos_val >= SIMILARITY_THRESHOLD else "outside ±2%" - print( - f"Cosine similarity: {cos_val:.6f} ({percent:.2f}%) {status_emoji} ({tolerance_text} tolerance)" - ) - else: - print(f"Shape mismatch: HF {hf_logits.shape} vs Megatron {megatron_logits.shape}") - print("Cannot compare logits directly due to shape mismatch") + diff = (hf_logits - megatron_logits_cmp).abs() + print(f"Logits diff - max: {diff.max():.6f}, mean: {diff.mean():.6f}") + cosine_sim = torch.cosine_similarity(hf_logits.unsqueeze(0), megatron_logits_cmp.unsqueeze(0)) + cos_val = cosine_sim.item() + percent = cos_val * 100.0 + status_emoji = "✅" if cos_val >= SIMILARITY_THRESHOLD else "❌" + tolerance_text = "within ±2%" if cos_val >= SIMILARITY_THRESHOLD else "outside ±2%" + print(f"Cosine similarity: {cos_val:.6f} ({percent:.2f}%) {status_emoji} ({tolerance_text} tolerance)") print("=== COMPARISON COMPLETE ===") else: