diff --git a/megatron/core/inference/inference_request.py b/megatron/core/inference/inference_request.py index 7111d11cb5a..473b26b81c3 100644 --- a/megatron/core/inference/inference_request.py +++ b/megatron/core/inference/inference_request.py @@ -32,6 +32,7 @@ class InferenceRequest: prompt: str sampling_params: Optional[SamplingParams] = None inference_parameters: Optional[SamplingParams] = None + prompt_logits: Optional[torch.Tensor] = None prompt_tokens: Optional[List[int]] = None arrival_time: Optional[float] = None status: Optional[Status] = None @@ -40,6 +41,7 @@ class InferenceRequest: segments: Optional[List[str]] = None generated_segments: Optional[List[str]] = None generated_sequence_lengths: Optional[List[int]] = None + generated_logits: Optional[torch.Tensor] = None generated_tokens: Optional[torch.Tensor] = None prompt_log_probs: Optional[torch.Tensor] = None generated_log_probs: Optional[torch.Tensor] = None diff --git a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py index 5367d0be1bb..4f0643fadf8 100644 --- a/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py +++ b/megatron/core/inference/model_inference_wrappers/abstract_model_inference_wrapper.py @@ -117,7 +117,7 @@ def prep_model_for_inference(self, prompts_tokens: Optional[torch.Tensor] = None is_pipeline_first_stage(self.pp_group) and is_pipeline_last_stage(self.pp_group) ) - self.inference_context.reset() + self.inference_context = type(self.inference_context).from_config(self.inference_wrapper_config) @abc.abstractmethod def prep_inference_input(self, prompt_tokens) -> Dict[str, Any]: @@ -152,11 +152,12 @@ def _forward(self, inference_input): tokens = inference_input["tokens"] position_ids = inference_input["position_ids"] attention_mask = inference_input["attention_mask"] + inference_context = inference_input.get("inference_context", self.inference_context) return self.model( tokens, position_ids, attention_mask, - inference_context=self.inference_context, + inference_context=inference_context, runtime_gather_output=True, # Inference should always gather the logits ) diff --git a/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py b/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py index 2276549c025..352a1fad9cb 100644 --- a/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py +++ b/megatron/core/inference/model_inference_wrappers/inference_wrapper_config.py @@ -31,6 +31,12 @@ class InferenceWrapperConfig: inference_max_seq_length: int = 2560 """ Maximum sequence length for inference (prefill & decode). Necessary for CUDA graphs. """ + prompt_segmentation_threshold: int | None = None + """If prompt length exceeds this value, it will be split into segments. This + feature allows to process very large prompts that normally would cause + Out Of Memory (OOM) during forward pass. If None, prompt segmentation is + disabled.""" + fp32_residual_connection: bool = False """Move residual connections to fp32. Obtained from arguments.py""" diff --git a/megatron/core/inference/sampling_params.py b/megatron/core/inference/sampling_params.py index 75e6adb0ef6..396a97b396f 100644 --- a/megatron/core/inference/sampling_params.py +++ b/megatron/core/inference/sampling_params.py @@ -1,4 +1,5 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from collections.abc import Callable from dataclasses import dataclass @@ -19,10 +20,12 @@ class SamplingParams: top_k: int = 0 top_p: float = 0.0 return_log_probs: bool = False + return_logits: bool = False return_segments: bool = False # Whether to return individually detokenized tokens num_tokens_to_generate: int = 30 top_n_logprobs: int = 0 return_prompt_top_n_logprobs: bool = False + token_callback: Callable[[int], None] | None = None def add_attributes(self, attribute_value_pair: dict): """Utility to add more attributes to sampling params diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py index c8ce19b928f..06c85797c24 100644 --- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py +++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py @@ -206,7 +206,6 @@ def sample_from_logits( assert isinstance(top_p, float) assert isinstance(top_k, int) - assert not (top_k > 0 and top_p > 0.0), "Cannot have top-p and top-k both greater than zero" assert top_p <= 1.0, "top-p should be in (0,1]" def modify_logits_for_top_k_filtering(logits, top_k): @@ -293,8 +292,7 @@ def modify_logits_for_top_p_filtering(logits, top_p): if vocab_size: assert top_k < vocab_size, "top-k is larger than vocab size." modify_logits_for_top_k_filtering(last_token_logits, top_k) - - elif top_p > 0.0: + if top_p > 0.0: modify_logits_for_top_p_filtering(last_token_logits, top_p) # After filtering, we need to recalculate the distribution. @@ -577,6 +575,9 @@ def generate_all_output_tokens_static_batch( max_prompt_length_in_batch = max(prompt_lengths_in_batch) min_prompt_length_in_batch = min(prompt_lengths_in_batch) + pst = self.inference_wrapped_model.inference_wrapper_config.prompt_segmentation_threshold + pst = min(pst or min_prompt_length_in_batch, min_prompt_length_in_batch) + # For batch inference the sampling params are the same for all request sampling_params: SamplingParams = list(active_requests.values())[0].sampling_params @@ -659,6 +660,15 @@ def generate_all_output_tokens_static_batch( # to nearest power of 2 vocab_size = self.inference_wrapped_model.inference_wrapper_config.padded_vocab_size + # Pre allocate logits tensor + output_logits = None + if sampling_params.return_logits: + output_logits = torch.empty( + (batch_size, max_sequence_length - 1, vocab_size), + dtype=torch.float32, + device=torch.cuda.current_device(), + ) + # Check whether early termination is enabled no_early_termination = getattr(sampling_params, "no_early_termination", False) termination_id = -1 if no_early_termination else self.tokenizer.eod @@ -714,7 +724,7 @@ def generate_all_output_tokens_static_batch( if sampling_params.num_tokens_to_generate == 0: context_end_position = max_prompt_length_in_batch else: - context_end_position = min_prompt_length_in_batch + context_end_position = pst # The initial iteration of this loop runs the prefill phase up to the shortest # prompt length in the batch. Then every subsequent iterations runs a decode step. @@ -845,6 +855,14 @@ def generate_all_output_tokens_static_batch( log_probs, 2, indices ).squeeze(2) + if sampling_params.return_logits: + # Store the raw logits for the current context window + assert output_logits is not None + output_logits[:, context_start_position:context_end_position] = logits + + if sampling_params.token_callback: + sampling_params.token_callback(context_end_position) + context_start_position = context_end_position if sampling_params.num_tokens_to_generate > 0: @@ -911,6 +929,10 @@ def generate_all_output_tokens_static_batch( assert output_log_probs is not None output_log_probs = output_log_probs[:, :context_end_position] + if sampling_params.return_logits: + assert output_logits is not None + output_logits = output_logits[:, :context_end_position] + generated_sequence_lengths[ generated_sequence_lengths > sampling_params.num_tokens_to_generate ] = sampling_params.num_tokens_to_generate @@ -984,6 +1006,22 @@ def generate_all_output_tokens_static_batch( :required_sequence_length ] + request.prompt_logits = ( + None + if output_logits is None + else output_logits[idx, :input_prompt_length].cpu() + ) + + request.generated_logits = ( + None + if output_logits is None + else output_logits[ + idx, + input_prompt_length - 1 : (input_prompt_length + required_sequence_length - 1), + ] + .cpu() + ) + request.status = Status.COMPLETED text, segments = self.detokenize_generations( diff --git a/megatron/core/post_training/modelopt/layers.py b/megatron/core/post_training/modelopt/layers.py index 0ca4a8e4070..d74ebf2a3eb 100644 --- a/megatron/core/post_training/modelopt/layers.py +++ b/megatron/core/post_training/modelopt/layers.py @@ -114,6 +114,7 @@ def __init__( tp_comm_buffer_name: str = None, # Not used disable_grad_reduce: bool = False, tp_group: Optional[torch.distributed.ProcessGroup] = None, + input_is_parallel = None, # To make compatible with TELinear ): self.config = config @@ -162,7 +163,7 @@ def forward(self, x): out = super().forward(x) if self._return_bias: - return out + return out, self.bias.detach() return out, None diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 47ec1ff0626..f789c8228f1 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -672,7 +672,7 @@ def forward( # relative positional embedding (rotary embedding) # ================================================ nvtx_range_push(suffix="rotary_pos_emb") - if rotary_pos_emb is not None and not self.config.flash_decode: + if rotary_pos_emb is not None and (not self.config.flash_decode or inference_context is None): q_pos_emb, k_pos_emb = rotary_pos_emb if packed_seq_params is not None: diff --git a/megatron/core/transformer/mlp.py b/megatron/core/transformer/mlp.py index 7da5396142e..87279dea7b2 100644 --- a/megatron/core/transformer/mlp.py +++ b/megatron/core/transformer/mlp.py @@ -311,7 +311,18 @@ def sh_ten_build_fn( def sh_ten_merge_fn(sub_state_dict): with torch.no_grad(): - return torch.cat(sub_state_dict) + n = len(sub_state_dict) + m = sub_state_dict[0].shape[0] + k = sub_state_dict[0].shape[1] + + # Merge everything into 0th tensor. + sub_state_dict[0].resize_([n*m, k]) + + for i in range(1, n): + sub_state_dict[0][i*m:, :] = sub_state_dict[i] + sub_state_dict[i].resize_([0, 0]) + + return sub_state_dict[0] return ShardedTensorFactory( original_sh_ten.key, diff --git a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py index 68f0062c8dd..90d380006ff 100644 --- a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py +++ b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py @@ -101,14 +101,6 @@ def teardown_method(self, method): def test_sample_from_logits(self): self.setup_model(torch.float32) - with pytest.raises(AssertionError) as aerror: - self.text_generation_controller.sample_from_logits( - last_token_logits=None, - sampling_params=SamplingParams(top_k=2, top_p=0.4), - vocab_size=self.vocab_size, - ) - assert str(aerror.value) == 'Cannot have top-p and top-k both greater than zero' - with pytest.raises(AssertionError) as aerror: self.text_generation_controller.sample_from_logits( last_token_logits=None, @@ -186,6 +178,67 @@ def detokenize(self, inp, skip_special_tokens=False): sampled_logits >= expected_min_value ), f"The sampled logits should all be greater than {expected_min_value} but its {sampled_logits}" + # Test case for both top_k and top_p being greater than zero + # This should apply top_k filtering first, then top_p filtering + + # First, a general test for top_k + top_p as before + top_k = 10 + top_p = 0.7 + temperature = 1.5 + + test_logits = ( + torch.arange(0, self.vocab_size).repeat(self.batch_size, 1).float().cuda() + ) + + sampled_logits = self.text_generation_controller.sample_from_logits( + test_logits, + SamplingParams(top_k=top_k, top_p=top_p, temperature=temperature), + self.vocab_size, + ) + + assert torch.all( + sampled_logits >= self.vocab_size - top_k + ), f"With top_k={top_k}, sampled tokens should be from top {top_k} tokens, but got {sampled_logits}" + + # Now, a specific test to ensure top_p is working after top_k + # We'll set up logits so that top_k=2, and top_p will only allow the highest logit + + # Create logits: last two tokens are much higher than the rest + # For batch_size=1 for simplicity + batch_size = 1 + vocab_size = 6 + logits = torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0, 10.0]], device="cuda") # shape (1, 6) + # After top_k=2, only indices 4 and 5 remain (values 1.0 and 10.0) + # After softmax with large temperature, probabilities will be close to uniform, but let's use a moderate temperature + # Let's compute softmax([1.0, 10.0]/T) for T=0.5 (sharper), T=10 (flatter) + # We'll use a large temperature to make the probabilities more uniform, so top_p can cut off the lower one + + top_k = 2 + temperature = 10.0 # Large temperature, so softmax([1.0, 10.0]/10) ~ softmax([0.1, 1.0]) ~ [0.289, 0.710] + # If we set top_p=0.7, only the highest logit (index 5) should remain after top_p filtering + + sampled_logits = self.text_generation_controller.sample_from_logits( + logits, + SamplingParams(top_k=top_k, top_p=0.7, temperature=temperature), + vocab_size, + ) + + # Only index 5 should be possible + assert torch.all( + sampled_logits == 5 + ), f"With top_k=2 and top_p=0.7, only the highest logit (index 5) should remain, but got {sampled_logits}" + + # If we set top_p=1.0, both tokens should be possible + sampled_indices = set() + for _ in range(20): + sampled = self.text_generation_controller.sample_from_logits( + logits, + SamplingParams(top_k=top_k, top_p=1.0, temperature=temperature), + vocab_size, + ) + sampled_indices.add(sampled.item()) + assert {4, 5}.issubset(sampled_indices), "With top_k=2 and top_p=1.0, both top tokens should be possible to sample" + @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize( "symmetric_ar_type",