Skip to content
2 changes: 2 additions & 0 deletions megatron/core/inference/inference_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class InferenceRequest:
prompt: str
sampling_params: Optional[SamplingParams] = None
inference_parameters: Optional[SamplingParams] = None
prompt_logits: Optional[torch.Tensor] = None
prompt_tokens: Optional[List[int]] = None
arrival_time: Optional[float] = None
status: Optional[Status] = None
Expand All @@ -40,6 +41,7 @@ class InferenceRequest:
segments: Optional[List[str]] = None
generated_segments: Optional[List[str]] = None
generated_sequence_lengths: Optional[List[int]] = None
generated_logits: Optional[torch.Tensor] = None
generated_tokens: Optional[torch.Tensor] = None
prompt_log_probs: Optional[torch.Tensor] = None
generated_log_probs: Optional[torch.Tensor] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def prep_model_for_inference(self, prompts_tokens: Optional[torch.Tensor] = None
is_pipeline_first_stage(self.pp_group) and is_pipeline_last_stage(self.pp_group)
)

self.inference_context.reset()
self.inference_context = type(self.inference_context).from_config(self.inference_wrapper_config)

@abc.abstractmethod
def prep_inference_input(self, prompt_tokens) -> Dict[str, Any]:
Expand Down Expand Up @@ -152,11 +152,12 @@ def _forward(self, inference_input):
tokens = inference_input["tokens"]
position_ids = inference_input["position_ids"]
attention_mask = inference_input["attention_mask"]
inference_context = inference_input.get("inference_context", self.inference_context)
return self.model(
tokens,
position_ids,
attention_mask,
inference_context=self.inference_context,
inference_context=inference_context,
runtime_gather_output=True, # Inference should always gather the logits
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ class InferenceWrapperConfig:
inference_max_seq_length: int = 2560
""" Maximum sequence length for inference (prefill & decode). Necessary for CUDA graphs. """

prompt_segmentation_threshold: int | None = None
"""If prompt length exceeds this value, it will be split into segments. This
feature allows to process very large prompts that normally would cause
Out Of Memory (OOM) during forward pass. If None, prompt segmentation is
disabled."""

fp32_residual_connection: bool = False
"""Move residual connections to fp32. Obtained from arguments.py"""

Expand Down
3 changes: 3 additions & 0 deletions megatron/core/inference/sampling_params.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from collections.abc import Callable
from dataclasses import dataclass


Expand All @@ -19,10 +20,12 @@ class SamplingParams:
top_k: int = 0
top_p: float = 0.0
return_log_probs: bool = False
return_logits: bool = False
return_segments: bool = False # Whether to return individually detokenized tokens
num_tokens_to_generate: int = 30
top_n_logprobs: int = 0
return_prompt_top_n_logprobs: bool = False
token_callback: Callable[[int], None] | None = None

def add_attributes(self, attribute_value_pair: dict):
"""Utility to add more attributes to sampling params
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ def sample_from_logits(

assert isinstance(top_p, float)
assert isinstance(top_k, int)
assert not (top_k > 0 and top_p > 0.0), "Cannot have top-p and top-k both greater than zero"
assert top_p <= 1.0, "top-p should be in (0,1]"

def modify_logits_for_top_k_filtering(logits, top_k):
Expand Down Expand Up @@ -293,8 +292,7 @@ def modify_logits_for_top_p_filtering(logits, top_p):
if vocab_size:
assert top_k < vocab_size, "top-k is larger than vocab size."
modify_logits_for_top_k_filtering(last_token_logits, top_k)

elif top_p > 0.0:
if top_p > 0.0:
modify_logits_for_top_p_filtering(last_token_logits, top_p)

# After filtering, we need to recalculate the distribution.
Expand Down Expand Up @@ -577,6 +575,9 @@ def generate_all_output_tokens_static_batch(
max_prompt_length_in_batch = max(prompt_lengths_in_batch)
min_prompt_length_in_batch = min(prompt_lengths_in_batch)

pst = self.inference_wrapped_model.inference_wrapper_config.prompt_segmentation_threshold
pst = min(pst or min_prompt_length_in_batch, min_prompt_length_in_batch)

# For batch inference the sampling params are the same for all requests
sampling_params: SamplingParams = list(active_requests.values())[0].sampling_params

Expand Down Expand Up @@ -659,6 +660,15 @@ def generate_all_output_tokens_static_batch(
# to nearest power of 2
vocab_size = self.inference_wrapped_model.inference_wrapper_config.padded_vocab_size

# Pre allocate logits tensor
output_logits = None
if sampling_params.return_logits:
output_logits = torch.empty(
(batch_size, max_sequence_length - 1, vocab_size),
dtype=torch.float32,
device=torch.cuda.current_device(),
)

# Check whether early termination is enabled
no_early_termination = getattr(sampling_params, "no_early_termination", False)
termination_id = -1 if no_early_termination else self.tokenizer.eod
Expand Down Expand Up @@ -714,7 +724,7 @@ def generate_all_output_tokens_static_batch(
if sampling_params.num_tokens_to_generate == 0:
context_end_position = max_prompt_length_in_batch
else:
context_end_position = min_prompt_length_in_batch
context_end_position = pst

# The initial iteration of this loop runs the prefill phase up to the shortest
# prompt length in the batch. Then every subsequent iteration runs a decode step.
Expand Down Expand Up @@ -845,6 +855,14 @@ def generate_all_output_tokens_static_batch(
log_probs, 2, indices
).squeeze(2)

if sampling_params.return_logits:
# Store the raw logits for the current context window
assert output_logits is not None
output_logits[:, context_start_position:context_end_position] = logits

if sampling_params.token_callback:
sampling_params.token_callback(context_end_position)

context_start_position = context_end_position

if sampling_params.num_tokens_to_generate > 0:
Expand Down Expand Up @@ -911,6 +929,10 @@ def generate_all_output_tokens_static_batch(
assert output_log_probs is not None
output_log_probs = output_log_probs[:, :context_end_position]

if sampling_params.return_logits:
assert output_logits is not None
output_logits = output_logits[:, :context_end_position]

generated_sequence_lengths[
generated_sequence_lengths > sampling_params.num_tokens_to_generate
] = sampling_params.num_tokens_to_generate
Expand Down Expand Up @@ -984,6 +1006,22 @@ def generate_all_output_tokens_static_batch(
:required_sequence_length
]

request.prompt_logits = (
None
if output_logits is None
else output_logits[idx, :input_prompt_length].cpu()
)

request.generated_logits = (
None
if output_logits is None
else output_logits[
idx,
input_prompt_length - 1 : (input_prompt_length + required_sequence_length - 1),
]
.cpu()
)

request.status = Status.COMPLETED

text, segments = self.detokenize_generations(
Expand Down
3 changes: 2 additions & 1 deletion megatron/core/post_training/modelopt/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(
tp_comm_buffer_name: str = None, # Not used
disable_grad_reduce: bool = False,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
input_is_parallel = None, # To make compatible with TELinear
):
self.config = config

Expand Down Expand Up @@ -162,7 +163,7 @@ def forward(self, x):
out = super().forward(x)

if self._return_bias:
return out
return out, self.bias.detach()
return out, None


Expand Down
2 changes: 1 addition & 1 deletion megatron/core/transformer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ def forward(
# relative positional embedding (rotary embedding)
# ================================================
nvtx_range_push(suffix="rotary_pos_emb")
if rotary_pos_emb is not None and not self.config.flash_decode:
if rotary_pos_emb is not None and (not self.config.flash_decode or inference_context is None):
q_pos_emb, k_pos_emb = rotary_pos_emb

if packed_seq_params is not None:
Expand Down
13 changes: 12 additions & 1 deletion megatron/core/transformer/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,18 @@ def sh_ten_build_fn(

def sh_ten_merge_fn(sub_state_dict):
    """Merge sharded sub-tensors along dim 0 into a single tensor.

    Instead of ``torch.cat`` (which transiently holds both the input
    shards and the merged output in memory), the 0th shard is resized
    in place to the full merged shape and the remaining shards are
    copied into it, each being freed (resized to empty) right after it
    is copied. This roughly halves the peak memory needed for the merge.

    NOTE(review): assumes all shards are 2-D, contiguous, and of equal
    shape — true for tensor-parallel shards of a weight matrix, but
    narrower than the ``torch.cat`` it replaces; confirm at call sites.
    """
    with torch.no_grad():
        n = len(sub_state_dict)
        m, k = sub_state_dict[0].shape

        # Grow the 0th shard to hold the full merged tensor. For a
        # contiguous tensor, resize_ preserves the existing leading
        # storage, so rows [0, m) keep the 0th shard's data.
        sub_state_dict[0].resize_([n * m, k])

        for i in range(1, n):
            # Copy shard i into exactly its own row range. (The previous
            # slice `[i*m:, :]` targeted every remaining row, which only
            # shape-checks when i == n - 1, i.e. the code worked solely
            # because it was always called with n == 2.)
            sub_state_dict[0][i * m : (i + 1) * m, :] = sub_state_dict[i]
            # Release the shard's storage as soon as it has been copied.
            sub_state_dict[i].resize_([0, 0])

        return sub_state_dict[0]

return ShardedTensorFactory(
original_sh_ten.key,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,6 @@ def teardown_method(self, method):
def test_sample_from_logits(self):
self.setup_model(torch.float32)

with pytest.raises(AssertionError) as aerror:
self.text_generation_controller.sample_from_logits(
last_token_logits=None,
sampling_params=SamplingParams(top_k=2, top_p=0.4),
vocab_size=self.vocab_size,
)
assert str(aerror.value) == 'Cannot have top-p and top-k both greater than zero'

with pytest.raises(AssertionError) as aerror:
self.text_generation_controller.sample_from_logits(
last_token_logits=None,
Expand Down Expand Up @@ -186,6 +178,67 @@ def detokenize(self, inp, skip_special_tokens=False):
sampled_logits >= expected_min_value
), f"The sampled logits should all be greater than {expected_min_value} but its {sampled_logits}"

# Test case for both top_k and top_p being greater than zero
# This should apply top_k filtering first, then top_p filtering

# First, a general test for top_k + top_p as before
top_k = 10
top_p = 0.7
temperature = 1.5

test_logits = (
torch.arange(0, self.vocab_size).repeat(self.batch_size, 1).float().cuda()
)

sampled_logits = self.text_generation_controller.sample_from_logits(
test_logits,
SamplingParams(top_k=top_k, top_p=top_p, temperature=temperature),
self.vocab_size,
)

assert torch.all(
sampled_logits >= self.vocab_size - top_k
), f"With top_k={top_k}, sampled tokens should be from top {top_k} tokens, but got {sampled_logits}"

# Now, a specific test to ensure top_p is working after top_k
# We'll set up logits so that top_k=2, and top_p will only allow the highest logit

# Create logits: last two tokens are much higher than the rest
# For batch_size=1 for simplicity
batch_size = 1
vocab_size = 6
logits = torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0, 10.0]], device="cuda") # shape (1, 6)
# After top_k=2, only indices 4 and 5 remain (values 1.0 and 10.0)
# After top_k=2 filtering, softmax is taken over [1.0, 10.0] scaled by the temperature.
# A higher temperature flattens the resulting distribution (T=0.5 sharpens it, T=10 flattens it).
# We use a large temperature so the two probabilities are close enough to uniform that
# top_p can cut off the lower-probability token.

top_k = 2
temperature = 10.0 # Large temperature, so softmax([1.0, 10.0]/10) ~ softmax([0.1, 1.0]) ~ [0.289, 0.710]
# If we set top_p=0.7, only the highest logit (index 5) should remain after top_p filtering

sampled_logits = self.text_generation_controller.sample_from_logits(
logits,
SamplingParams(top_k=top_k, top_p=0.7, temperature=temperature),
vocab_size,
)

# Only index 5 should be possible
assert torch.all(
sampled_logits == 5
), f"With top_k=2 and top_p=0.7, only the highest logit (index 5) should remain, but got {sampled_logits}"

# If we set top_p=1.0, both tokens should be possible
sampled_indices = set()
for _ in range(20):
sampled = self.text_generation_controller.sample_from_logits(
logits,
SamplingParams(top_k=top_k, top_p=1.0, temperature=temperature),
vocab_size,
)
sampled_indices.add(sampled.item())
assert {4, 5}.issubset(sampled_indices), "With top_k=2 and top_p=1.0, both top tokens should be possible to sample"

@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize(
"symmetric_ar_type",
Expand Down