-
Notifications
You must be signed in to change notification settings - Fork 199
[misc] fix: Improve compare.py robustness for multi-GPU and vocab-padded models #2646
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -91,7 +91,6 @@ | |
| """ | ||
|
|
||
| import argparse | ||
| import gc | ||
| import importlib | ||
| import os | ||
| import sys | ||
|
|
@@ -319,13 +318,7 @@ def vlm_forward_step(data_iterator, model, **kwargs) -> torch.Tensor: | |
| def loss_func(x, **kwargs): | ||
| return x | ||
|
|
||
| model_output = model(**forward_args) | ||
| if isinstance(model_output, tuple): | ||
| output_tensor, _ = model_output | ||
| else: | ||
| output_tensor = model_output | ||
|
|
||
| return output_tensor, loss_func | ||
| return model(**forward_args), loss_func | ||
|
|
||
|
|
||
| def load_image(image_path: str) -> Image.Image: | ||
|
|
@@ -616,11 +609,8 @@ def _load_megatron_model(args): | |
| model_provider.finalize() | ||
| megatron_model = model_provider.provide_distributed_model(wrap_with_ddp=False) | ||
|
|
||
| # Workaround: disable MTP for inference (causes hangs on NCCL collectives) | ||
| for m in megatron_model: | ||
| m.config.mtp_num_layers = None | ||
| m.config.grad_scale_func = None | ||
|
|
||
| model_components = [m.eval() for m in megatron_model] | ||
|
|
||
| # Register debug hooks if enabled | ||
|
|
@@ -727,27 +717,29 @@ def compare_models_one_step(args) -> None: | |
| ) | ||
|
|
||
| del hf_model | ||
| gc.collect() | ||
| torch.cuda.empty_cache() | ||
|
|
||
| # Broadcast HF results to all ranks | ||
| # Broadcast HF results to all ranks after Megatron initialization | ||
| # (following the pattern from generate_from_hf.py) | ||
| if torch.distributed.is_initialized(): | ||
| # Create tensors for broadcasting if they don't exist on non-rank-0 | ||
| # Ensure consistent dtype across ranks: rank 0 has bfloat16 logits from the HF model, | ||
| # so all ranks must use the same dtype for NCCL broadcast to work correctly. | ||
| if hf_logits is not None: | ||
| hf_logits = hf_logits.float() | ||
|
|
||
| if hf_next_token is None: | ||
| hf_next_token = torch.zeros(1, device=input_ids.device, dtype=torch.long) | ||
| if hf_logits is None: | ||
| # Get vocab size from tokenizer for proper tensor size | ||
| vocab_size = getattr( | ||
| tokenizer, "vocab_size", len(tokenizer.vocab) if hasattr(tokenizer, "vocab") else 32000 | ||
| ) | ||
| hf_logits = torch.zeros(vocab_size, device=input_ids.device, dtype=torch.float32) | ||
|
|
||
| # Ensure consistent dtype across ranks before broadcast | ||
| hf_logits = hf_logits.float() | ||
|
|
||
| # Broadcast from rank 0 to all ranks | ||
| torch.distributed.broadcast(hf_next_token, 0) | ||
| torch.distributed.broadcast(hf_logits, 0) | ||
| torch.distributed.barrier() | ||
| print_rank_0("HF results broadcast complete.") | ||
|
|
||
| # Run Megatron model forward pass | ||
| print_rank_0("=== RUNNING MEGATRON MODEL (1-STEP) ===") | ||
|
|
@@ -792,10 +784,7 @@ def compare_models_one_step(args) -> None: | |
| megatron_logits = megatron_output[0, -1, :] | ||
| megatron_next_token = torch.argmax(megatron_logits, dim=-1) | ||
|
|
||
| if not torch.distributed.is_initialized() or ( | ||
| parallel_state.get_tensor_model_parallel_rank() == 0 | ||
| and parallel_state.get_expert_model_parallel_rank() == 0 | ||
| ): | ||
| if not torch.distributed.is_initialized() or (parallel_state.get_tensor_model_parallel_rank() == 0 and parallel_state.get_expert_model_parallel_rank() == 0): | ||
| print(f"Megatron output shape: {megatron_output.shape}") | ||
| print(f"Megatron logits stats - mean: {megatron_logits.mean():.4f}, std: {megatron_logits.std():.4f}") | ||
| print( | ||
|
|
@@ -807,27 +796,26 @@ def compare_models_one_step(args) -> None: | |
| top5_tokens = [tokenizer.decode([idx]) for idx in top5_ids] | ||
| print(f"Megatron Top 5: {list(zip(top5_tokens, top5_vals.tolist()))}") | ||
|
|
||
| # Compare outputs (only where we have valid Megatron results) | ||
| # Megatron may pad vocab_size for GPU kernel efficiency — truncate | ||
| # to the HF vocab size so logits are directly comparable. | ||
| hf_vocab_size = hf_logits.shape[0] | ||
| megatron_logits_cmp = megatron_logits[:hf_vocab_size] | ||
| megatron_next_token_cmp = torch.argmax(megatron_logits_cmp, dim=-1) | ||
|
|
||
| # Compare outputs | ||
| print("=== COMPARISON ===") | ||
| token_match = hf_next_token.item() == megatron_next_token.item() | ||
| token_match = hf_next_token.item() == megatron_next_token_cmp.item() | ||
| token_status_emoji = "✅" if token_match else "❌" | ||
| print(f"Token match: {token_match} {token_status_emoji}") | ||
|
|
||
| # Compare logits if shapes match | ||
| if hf_logits.shape == megatron_logits.shape: | ||
| diff = (hf_logits - megatron_logits).abs() | ||
| print(f"Logits diff - max: {diff.max():.6f}, mean: {diff.mean():.6f}") | ||
| cosine_sim = torch.cosine_similarity(hf_logits.unsqueeze(0), megatron_logits.unsqueeze(0)) | ||
| cos_val = cosine_sim.item() | ||
| percent = cos_val * 100.0 | ||
| status_emoji = "✅" if cos_val >= SIMILARITY_THRESHOLD else "❌" | ||
| tolerance_text = "within ±2%" if cos_val >= SIMILARITY_THRESHOLD else "outside ±2%" | ||
| print( | ||
| f"Cosine similarity: {cos_val:.6f} ({percent:.2f}%) {status_emoji} ({tolerance_text} tolerance)" | ||
| ) | ||
| else: | ||
| print(f"Shape mismatch: HF {hf_logits.shape} vs Megatron {megatron_logits.shape}") | ||
| print("Cannot compare logits directly due to shape mismatch") | ||
| diff = (hf_logits - megatron_logits_cmp).abs() | ||
| print(f"Logits diff - max: {diff.max():.6f}, mean: {diff.mean():.6f}") | ||
| cosine_sim = torch.cosine_similarity(hf_logits.unsqueeze(0), megatron_logits_cmp.unsqueeze(0)) | ||
| cos_val = cosine_sim.item() | ||
| percent = cos_val * 100.0 | ||
| status_emoji = "✅" if cos_val >= SIMILARITY_THRESHOLD else "❌" | ||
| tolerance_text = "within ±2%" if cos_val >= SIMILARITY_THRESHOLD else "outside ±2%" | ||
| print(f"Cosine similarity: {cos_val:.6f} ({percent:.2f}%) {status_emoji} ({tolerance_text} tolerance)") | ||
|
Comment on lines
+799
to
+818
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add an explicit vocab-size compatibility guard before truncation. At line 802, truncation assumes the Megatron logits are at least as long as the HF vocab. If the Megatron vocab is smaller, line 811 still fails later with a shape mismatch; fail fast with a clear error instead. Proposed fix: hf_vocab_size = hf_logits.shape[0]
+ if megatron_logits.shape[0] < hf_vocab_size:
+ raise ValueError(
+ "Incompatible vocab sizes: "
+ f"Megatron logits ({megatron_logits.shape[0]}) < HF logits ({hf_vocab_size}). "
+ "Ensure both models use the same tokenizer/vocab."
+ )
megatron_logits_cmp = megatron_logits[:hf_vocab_size]
megatron_next_token_cmp = torch.argmax(megatron_logits_cmp, dim=-1)

Based on learnings: when a path is unsupported, raise an explicit, descriptive error instead of failing later with an implicit runtime mismatch. 🤖 Prompt for AI Agents |
||
|
|
||
| print("=== COMPARISON COMPLETE ===") | ||
| else: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Handle tuple model outputs before returning from `vlm_forward_step`. At line 321, returning raw `model(**forward_args)` can propagate a tuple output (e.g., `(output_tensor, loss_mask)`), while downstream code assumes a tensor and will fail on tensor ops/indexing.

Proposed fix:
🤖 Prompt for AI Agents