From 0a6963e46c5f1d28d18f73acdeb3d1cba6c256c0 Mon Sep 17 00:00:00 2001 From: Anton Vorontsov Date: Tue, 17 Jun 2025 18:14:50 -0700 Subject: [PATCH 1/9] Megatron-LM: evo2: More efficient checkpoint loading This is needed to make Evo 2 40b work on A6000 Ada x2. --- megatron/core/transformer/mlp.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/megatron/core/transformer/mlp.py b/megatron/core/transformer/mlp.py index 7da5396142e..87279dea7b2 100644 --- a/megatron/core/transformer/mlp.py +++ b/megatron/core/transformer/mlp.py @@ -311,7 +311,18 @@ def sh_ten_build_fn( def sh_ten_merge_fn(sub_state_dict): with torch.no_grad(): - return torch.cat(sub_state_dict) + n = len(sub_state_dict) + m = sub_state_dict[0].shape[0] + k = sub_state_dict[0].shape[1] + + # Merge everything into 0th tensor. + sub_state_dict[0].resize_([n*m, k]) + + for i in range(1, n): + sub_state_dict[0][i*m:, :] = sub_state_dict[i] + sub_state_dict[i].resize_([0, 0]) + + return sub_state_dict[0] return ShardedTensorFactory( original_sh_ten.key, From 63ac0500cffbcb2f3dad711badce9a63c3ebdf69 Mon Sep 17 00:00:00 2001 From: Anton Vorontsov Date: Wed, 25 Jun 2025 18:38:42 -0700 Subject: [PATCH 2/9] Megatron-LM: SamplingParams: Add token_callback --- megatron/core/inference/sampling_params.py | 2 ++ .../text_generation_controllers/text_generation_controller.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/megatron/core/inference/sampling_params.py b/megatron/core/inference/sampling_params.py index 75e6adb0ef6..cc8e2e4344a 100644 --- a/megatron/core/inference/sampling_params.py +++ b/megatron/core/inference/sampling_params.py @@ -1,4 +1,5 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from collections.abc import Callable from dataclasses import dataclass @@ -23,6 +24,7 @@ class SamplingParams: num_tokens_to_generate: int = 30 top_n_logprobs: int = 0 return_prompt_top_n_logprobs: bool = False + token_callback: Callable[[int], None] | None = None def add_attributes(self, attribute_value_pair: dict): """Utility to add more attributes to sampling params diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index c8ce19b928f..d75adedd46e 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -845,6 +845,9 @@ def generate_all_output_tokens_static_batch( log_probs, 2, indices ).squeeze(2) + if sampling_params.token_callback: + sampling_params.token_callback(context_end_position) + context_start_position = context_end_position if sampling_params.num_tokens_to_generate > 0: From c94a6e16f4c2e7adb14b809e4ce4644d4336e163 Mon Sep 17 00:00:00 2001 From: Anton Vorontsov Date: Wed, 16 Jul 2025 17:47:04 -0700 Subject: [PATCH 3/9] Megatron-LM: AbstractModelInferenceWrapper: Add custom inference_params Needed to pass None as inference params when we do cache-less forward pass. --- .../abstract_model_inference_wrapper.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py index 5367d0be1bb..684220bafb8 100644 --- a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +++ b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py @@ -152,11 +152,12 @@ def _forward(self, inference_input): tokens = inference_input["tokens"] position_ids = inference_input["position_ids"] attention_mask = inference_input["attention_mask"] + inference_context = inference_input.get("inference_context", self.inference_context) return self.model( tokens, position_ids, attention_mask, - inference_context=self.inference_context, + inference_context=inference_context, runtime_gather_output=True, # Inference should always gather the logits ) From 76993757b83d56adf4bd79c10d7a4ae892fbb89f Mon Sep 17 00:00:00 2001 From: Anton Vorontsov Date: Tue, 22 Jul 2025 20:19:35 -0700 Subject: [PATCH 4/9] Megatron: Modelopt: Make Linear layer compatible with TELinear --- megatron/core/post_training/modelopt/layers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/megatron/core/post_training/modelopt/layers.py b/megatron/core/post_training/modelopt/layers.py index 0ca4a8e4070..d74ebf2a3eb 100644 --- a/megatron/core/post_training/modelopt/layers.py +++ b/megatron/core/post_training/modelopt/layers.py @@ -114,6 +114,7 @@ def __init__( tp_comm_buffer_name: str = None, # Not used disable_grad_reduce: bool = False, tp_group: Optional[torch.distributed.ProcessGroup] = None, + input_is_parallel = None, # To make compatible with TELinear ): self.config = config @@ -162,7 +163,7 @@ def forward(self, x): out = super().forward(x) if self._return_bias: - return out + return out, self.bias.detach() return out, None From 808e85ff648cb9e35b775c3f028bfd43c12a7116 Mon Sep 17 00:00:00 2001 From: Anton Vorontsov Date: Tue, 29 Jul 2025 14:48:02 -0700 Subject: [PATCH 5/9] Megatron-LM: Make flash_decode=True + inference_context=None work --- megatron/core/transformer/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 47ec1ff0626..f789c8228f1 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -672,7 +672,7 @@ def forward( # relative positional embedding (rotary embedding) # ================================================ nvtx_range_push(suffix="rotary_pos_emb") - if rotary_pos_emb is not None and not self.config.flash_decode: + if rotary_pos_emb is not None and (not self.config.flash_decode or inference_context is None): q_pos_emb, k_pos_emb = rotary_pos_emb if packed_seq_params is not None: From 0db044f277ab541cca17d15237e8fa4625e5dfd4 Mon Sep 17 00:00:00 2001 From: Anton Vorontsov Date: Tue, 29 Jul 2025 17:55:16 -0700 Subject: [PATCH 6/9] Megatron-LM: Add prompt_segmentation_threshold If prompt length exceeds this value, it will be split into segments. This feature allows to process very large prompts that normally would cause Out Of Memory (OOM) during forward pass. Here's how it works. When the input prompt length exceeds this threshold, the generation process is split into three phases: 1. One large forward pass of input tokens up to the threshold value. 2. The rest of the prompt that exceed the threshold are processed token-by-token without sampling. This operation executes at the token generation speed (throughput) as shown. 3. Regular generation, where after the input prompt is fully processed, normal token generation with sampling resumes. --- .../model_inference_wrappers/inference_wrapper_config.py | 6 ++++++ .../text_generation_controller.py | 5 ++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py b/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py index 2276549c025..352a1fad9cb 100644 --- a/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +++ b/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py @@ -31,6 +31,12 @@ class InferenceWrapperConfig: inference_max_seq_length: int = 2560 """ Maximum sequence length for inference (prefill & decode). Necessary for CUDA graphs. """ + prompt_segmentation_threshold: int | None = None + """If prompt length exceeds this value, it will be split into segments. This + feature allows to process very large prompts that normally would cause + Out Of Memory (OOM) during forward pass. If None, prompt segmentation is + disabled.""" + fp32_residual_connection: bool = False """Move residual connections to fp32. Obtained from arguments.py""" diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index d75adedd46e..f204dcd9773 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -577,6 +577,9 @@ def generate_all_output_tokens_static_batch( max_prompt_length_in_batch = max(prompt_lengths_in_batch) min_prompt_length_in_batch = min(prompt_lengths_in_batch) + pst = self.inference_wrapped_model.inference_wrapper_config.prompt_segmentation_threshold + pst = min(pst or min_prompt_length_in_batch, min_prompt_length_in_batch) + # For batch inference the sampling params are the same for all request sampling_params: SamplingParams = list(active_requests.values())[0].sampling_params @@ -714,7 +717,7 @@ def generate_all_output_tokens_static_batch( if sampling_params.num_tokens_to_generate == 0: context_end_position = max_prompt_length_in_batch else: - context_end_position = min_prompt_length_in_batch + context_end_position = pst # The initial iteration of this loop runs the prefill phase up to the shortest # prompt length in the batch. Then every subsequent iterations runs a decode step. From 391e174997db6d6b061fedbd5d855516140ceeee Mon Sep 17 00:00:00 2001 From: Anton Vorontsov Date: Tue, 29 Jul 2025 18:13:54 -0700 Subject: [PATCH 7/9] Megatron-LM: Add logits reporting to generate() API Logits reporting are required for Evo 2 NIM. --- megatron/core/inference/inference_request.py | 2 ++ megatron/core/inference/sampling_params.py | 1 + .../text_generation_controller.py | 34 +++++++++++++++++++ 3 files changed, 37 insertions(+) diff --git a/megatron/core/inference/inference_request.py b/megatron/core/inference/inference_request.py index 7111d11cb5a..473b26b81c3 100644 --- a/megatron/core/inference/inference_request.py +++ b/megatron/core/inference/inference_request.py @@ -32,6 +32,7 @@ class InferenceRequest: prompt: str sampling_params: Optional[SamplingParams] = None inference_parameters: Optional[SamplingParams] = None + prompt_logits: Optional[torch.Tensor] = None prompt_tokens: Optional[List[int]] = None arrival_time: Optional[float] = None status: Optional[Status] = None @@ -40,6 +41,7 @@ class InferenceRequest: segments: Optional[List[str]] = None generated_segments: Optional[List[str]] = None generated_sequence_lengths: Optional[List[int]] = None + generated_logits: Optional[torch.Tensor] = None generated_tokens: Optional[torch.Tensor] = None prompt_log_probs: Optional[torch.Tensor] = None generated_log_probs: Optional[torch.Tensor] = None diff --git a/megatron/core/inference/sampling_params.py b/megatron/core/inference/sampling_params.py index cc8e2e4344a..396a97b396f 100644 --- a/megatron/core/inference/sampling_params.py +++ b/megatron/core/inference/sampling_params.py @@ -20,6 +20,7 @@ class SamplingParams: top_k: int = 0 top_p: float = 0.0 return_log_probs: bool = False + return_logits: bool = False return_segments: bool = False # Whether to return individually detokenized tokens num_tokens_to_generate: int = 30 top_n_logprobs: int = 0 diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index f204dcd9773..94fb47f15e5 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -662,6 +662,15 @@ def generate_all_output_tokens_static_batch( # to nearest power of 2 vocab_size = self.inference_wrapped_model.inference_wrapper_config.padded_vocab_size + # Pre allocate logits tensor + output_logits = None + if sampling_params.return_logits: + output_logits = torch.empty( + (batch_size, max_sequence_length - 1, vocab_size), + dtype=torch.float32, + device=torch.cuda.current_device(), + ) + # Check whether early termination is enabled no_early_termination = getattr(sampling_params, "no_early_termination", False) termination_id = -1 if no_early_termination else self.tokenizer.eod @@ -848,6 +857,11 @@ def generate_all_output_tokens_static_batch( log_probs, 2, indices ).squeeze(2) + if sampling_params.return_logits: + # Store the raw logits for the current context window + assert output_logits is not None + output_logits[:, context_start_position:context_end_position] = logits + if sampling_params.token_callback: sampling_params.token_callback(context_end_position) @@ -917,6 +931,10 @@ def generate_all_output_tokens_static_batch( assert output_log_probs is not None output_log_probs = output_log_probs[:, :context_end_position] + if sampling_params.return_logits: + assert output_logits is not None + output_logits = output_logits[:, :context_end_position] + generated_sequence_lengths[ generated_sequence_lengths > sampling_params.num_tokens_to_generate ] = sampling_params.num_tokens_to_generate @@ -990,6 +1008,22 @@ def generate_all_output_tokens_static_batch( :required_sequence_length ] + request.prompt_logits = ( + None + if output_logits is None + else output_logits[idx, :input_prompt_length].cpu() + ) + + request.generated_logits = ( + None + if output_logits is None + else output_logits[ + idx, + input_prompt_length - 1 : (input_prompt_length + required_sequence_length - 1), + ] + .cpu() + ) + request.status = Status.COMPLETED text, segments = self.detokenize_generations( From 358fa0da91a4d97e27be06f11706e0ef79cc242e Mon Sep 17 00:00:00 2001 From: Anton Vorontsov Date: Fri, 1 Aug 2025 03:41:30 -0700 Subject: [PATCH 8/9] Megatron-LM: Reset inference context more safely --- .../abstract_model_inference_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py index 684220bafb8..4f0643fadf8 100644 --- a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +++ b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py @@ -117,7 +117,7 @@ def prep_model_for_inference(self, prompts_tokens: Optional[torch.Tensor] = None is_pipeline_first_stage(self.pp_group) and is_pipeline_last_stage(self.pp_group) ) - self.inference_context.reset() + self.inference_context = type(self.inference_context).from_config(self.inference_wrapper_config) @abc.abstractmethod def prep_inference_input(self, prompt_tokens) -> Dict[str, Any]: From a75c9571a6bca5fbd835a0c6291aa6a40b28752b Mon Sep 17 00:00:00 2001 From: Anton Vorontsov Date: Wed, 13 Aug 2025 15:34:13 -0700 Subject: [PATCH 9/9] Megatron-LM: Allow top_p sub-sampling within top_k This is to make it compatible with Vortex: https://github.com/Zymrael/vortex/blob/debd9d160476b2498494507ffec0a697d3075a2d/vortex/model/sample.py#L51 --- .../text_generation_controller.py | 4 +- .../test_simple_text_generation_controller.py | 69 ++++++++++++++++--- 2 files changed, 62 insertions(+), 11 deletions(-) diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index 94fb47f15e5..06c85797c24 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -206,7 +206,6 @@ def sample_from_logits( assert isinstance(top_p, float) assert isinstance(top_k, int) - assert not (top_k > 0 and top_p > 0.0), "Cannot have top-p and top-k both greater than zero" assert top_p <= 1.0, "top-p should be in (0,1]" def modify_logits_for_top_k_filtering(logits, top_k): @@ -293,8 +292,7 @@ def modify_logits_for_top_p_filtering(logits, top_p): if vocab_size: assert top_k < vocab_size, "top-k is larger than vocab size." modify_logits_for_top_k_filtering(last_token_logits, top_k) - - elif top_p > 0.0: + if top_p > 0.0: modify_logits_for_top_p_filtering(last_token_logits, top_p) # After filtering, we need to recalculate the distribution. diff --git a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py index 68f0062c8dd..90d380006ff 100644 --- a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py +++ b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py @@ -101,14 +101,6 @@ def teardown_method(self, method): def test_sample_from_logits(self): self.setup_model(torch.float32) - with pytest.raises(AssertionError) as aerror: - self.text_generation_controller.sample_from_logits( - last_token_logits=None, - sampling_params=SamplingParams(top_k=2, top_p=0.4), - vocab_size=self.vocab_size, - ) - assert str(aerror.value) == 'Cannot have top-p and top-k both greater than zero' - with pytest.raises(AssertionError) as aerror: self.text_generation_controller.sample_from_logits( last_token_logits=None, @@ -186,6 +178,67 @@ def detokenize(self, inp, skip_special_tokens=False): sampled_logits >= expected_min_value ), f"The sampled logits should all be greater than {expected_min_value} but its {sampled_logits}" + # Test case for both top_k and top_p being greater than zero + # This should apply top_k filtering first, then top_p filtering + + # First, a general test for top_k + top_p as before + top_k = 10 + top_p = 0.7 + temperature = 1.5 + + test_logits = ( + torch.arange(0, self.vocab_size).repeat(self.batch_size, 1).float().cuda() + ) + + sampled_logits = self.text_generation_controller.sample_from_logits( + test_logits, + SamplingParams(top_k=top_k, top_p=top_p, temperature=temperature), + self.vocab_size, + ) + + assert torch.all( + sampled_logits >= self.vocab_size - top_k + ), f"With top_k={top_k}, sampled tokens should be from top {top_k} tokens, but got {sampled_logits}" + + # Now, a specific test to ensure top_p is working after top_k + # We'll set up logits so that top_k=2, and top_p will only allow the highest logit + + # Create logits: last two tokens are much higher than the rest + # For batch_size=1 for simplicity + batch_size = 1 + vocab_size = 6 + logits = torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0, 10.0]], device="cuda") # shape (1, 6) + # After top_k=2, only indices 4 and 5 remain (values 1.0 and 10.0) + # After softmax with large temperature, probabilities will be close to uniform, but let's use a moderate temperature + # Let's compute softmax([1.0, 10.0]/T) for T=0.5 (sharper), T=10 (flatter) + # We'll use a large temperature to make the probabilities more uniform, so top_p can cut off the lower one + + top_k = 2 + temperature = 10.0 # Large temperature, so softmax([1.0, 10.0]/10) ~ softmax([0.1, 1.0]) ~ [0.289, 0.710] + # If we set top_p=0.7, only the highest logit (index 5) should remain after top_p filtering + + sampled_logits = self.text_generation_controller.sample_from_logits( + logits, + SamplingParams(top_k=top_k, top_p=0.7, temperature=temperature), + vocab_size, + ) + + # Only index 5 should be possible + assert torch.all( + sampled_logits == 5 + ), f"With top_k=2 and top_p=0.7, only the highest logit (index 5) should remain, but got {sampled_logits}" + + # If we set top_p=1.0, both tokens should be possible + sampled_indices = set() + for _ in range(20): + sampled = self.text_generation_controller.sample_from_logits( + logits, + SamplingParams(top_k=top_k, top_p=1.0, temperature=temperature), + vocab_size, + ) + sampled_indices.add(sampled.item()) + assert {4, 5}.issubset(sampled_indices), "With top_k=2 and top_p=1.0, both top tokens should be possible to sample" + @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize( "symmetric_ar_type",