Skip to content

Conversation

@xiuyuz
Copy link

@xiuyuz xiuyuz commented Dec 22, 2025

Fix Incorrect Hidden State Extraction with Right Padding

Summary

This PR fixes a bug in models.py where the model was incorrectly extracting hidden states from padding tokens when using right-padded batches.

Bug Description

In generate_latent_batch and generate_latent_batch_hidden_state, the code used [:, -1, :] to extract the hidden states of the last token in the sequence:

e_t = outputs.hidden_states[0][:, -1, :]
last_hidden = outputs.hidden_states[-1][:, -1, :]

When using batch generation with right padding (e.g., [token_A, token_B, PAD]), index -1 corresponds to the PAD token. As a result, the latent reasoning steps were being initialized with the hidden state of the padding token rather than the last actual token of the prompt.

Fix

The fix uses the attention_mask to determine the index of the last non-padding token for each sequence in the batch:

last_token_indices = attention_mask.sum(1) - 1
last_hidden = outputs.hidden_states[-1][torch.arange(batch_size), last_token_indices, :]

This ensures that last_hidden (and e_t) correctly corresponds to the last real token (e.g., token_B), regardless of the padding.

@xiuyuz
Copy link
Author

xiuyuz commented Dec 22, 2025

This might be relevant to #25 discovered by @wonjun-chung.

@jiaruzouu
Copy link
Member

jiaruzouu commented Dec 26, 2025

Hi @xiuyuz ,

Thanks for your great contribution to our LatentMAS work! We will shortly review the code and merge it properly :) We will also ensure to mention your extension to our work later in the README.

Thanks again!

Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes a bug where hidden states were incorrectly extracted from padding tokens when using right-padded batches in latent reasoning steps. The fix uses the attention mask to identify the last non-padding token for each sequence in the batch, ensuring correct initialization of latent reasoning steps.

Key Changes:

  • Modified generate_latent_batch to use attention mask-based indexing to find the last real token instead of always using [:, -1, :]
  • Modified generate_latent_batch_hidden_state with the same padding-aware logic
  • Added conditional logic to only apply the fix when past_key_values is None (initial call with potential padding)

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +398 to +402
last_token_indices = attention_mask.sum(1) - 1
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
else:
last_hidden = outputs.hidden_states[-1][:, -1, :]
Copy link

Copilot AI Dec 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The indentation inside this else block appears to use 5 spaces instead of the standard 4 spaces used throughout the rest of the codebase. Line 402 should be indented with 12 spaces (8 base + 4 for the else block) rather than 13 spaces.

Suggested change
last_token_indices = attention_mask.sum(1) - 1
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
else:
last_hidden = outputs.hidden_states[-1][:, -1, :]
last_token_indices = attention_mask.sum(1) - 1
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
else:
last_hidden = outputs.hidden_states[-1][:, -1, :]

Copilot uses AI. Check for mistakes.
# Identify last token index
# attention_mask (at this point, if past is None, it is just original mask)
last_token_indices = attention_mask.sum(1) - 1
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
Copy link

Copilot AI Dec 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For clarity and consistency with how attention_mask is handled, batch_indices should be created using self.device instead of input_ids.device. While input_ids.device should be the same as self.device (otherwise the model call would fail), using self.device explicitly makes the device management more clear and matches the pattern used for attention_mask at line 289.

Suggested change
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
batch_indices = torch.arange(input_ids.shape[0], device=self.device)

Copilot uses AI. Check for mistakes.

if past_key_values is None:
last_token_indices = attention_mask.sum(1) - 1
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
Copy link

Copilot AI Dec 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For clarity and consistency with how attention_mask is handled, batch_indices should be created using self.HF_device instead of input_ids.device. While input_ids.device should be the same as self.HF_device (otherwise the model call would fail), using self.HF_device explicitly makes the device management more clear and matches the pattern used for attention_mask at line 375.

Suggested change
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
batch_indices = torch.arange(input_ids.shape[0], device=self.HF_device)

Copilot uses AI. Check for mistakes.
if past_key_values is None:
# Identify last token index
# attention_mask (at this point, if past is None, it is just original mask)
last_token_indices = attention_mask.sum(1) - 1
Copy link

Copilot AI Dec 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The attention_mask sum could include padding tokens (zeros) in the mask. If a sequence is completely padded (all zeros in attention_mask), then attention_mask.sum(1) would be 0, resulting in last_token_indices being -1. While this is technically a valid Python index (referring to the last element), it would give incorrect behavior for completely padded sequences. Consider adding validation or handling for this edge case.

Suggested change
last_token_indices = attention_mask.sum(1) - 1
token_counts = attention_mask.sum(1)
# Clamp to avoid negative indices when a sequence is fully padded (sum == 0).
last_token_indices = torch.clamp(token_counts - 1, min=0)

Copilot uses AI. Check for mistakes.
Comment on lines +398 to +400
last_token_indices = attention_mask.sum(1) - 1
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
Copy link

Copilot AI Dec 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The attention_mask sum could include padding tokens (zeros) in the mask. If a sequence is completely padded (all zeros in attention_mask), then attention_mask.sum(1) would be 0, resulting in last_token_indices being -1. While this is technically a valid Python index (referring to the last element), it would give incorrect behavior for completely padded sequences. Consider adding validation or handling for this edge case.

Suggested change
last_token_indices = attention_mask.sum(1) - 1
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
seq_lengths = attention_mask.sum(1)
# Handle fully padded sequences (sum == 0) to avoid negative indices (-1)
if torch.any(seq_lengths == 0):
seq_lengths = seq_lengths.clone()
seq_lengths[seq_lengths == 0] = 1
last_token_indices = seq_lengths - 1
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]

Copilot uses AI. Check for mistakes.
Comment on lines +315 to +324
# Identify last token index
# attention_mask (at this point, if past is None, it is just original mask)
last_token_indices = attention_mask.sum(1) - 1
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
e_t = outputs.hidden_states[0][batch_indices, last_token_indices, :]
last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
else:
# Assume no padding in incremental decoding steps
e_t = outputs.hidden_states[0][:, -1, :]
last_hidden = outputs.hidden_states[-1][:, -1, :]
Copy link

Copilot AI Dec 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The indentation inside this if block appears to use 5 spaces instead of the standard 4 spaces used throughout the rest of the codebase. Lines 315-320 should be indented with 12 spaces (8 base + 4 for the if block) rather than 13 spaces.

Suggested change
# Identify last token index
# attention_mask (at this point, if past is None, it is just original mask)
last_token_indices = attention_mask.sum(1) - 1
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
e_t = outputs.hidden_states[0][batch_indices, last_token_indices, :]
last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
else:
# Assume no padding in incremental decoding steps
e_t = outputs.hidden_states[0][:, -1, :]
last_hidden = outputs.hidden_states[-1][:, -1, :]
# Identify last token index
# attention_mask (at this point, if past is None, it is just original mask)
last_token_indices = attention_mask.sum(1) - 1
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
e_t = outputs.hidden_states[0][batch_indices, last_token_indices, :]
last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
else:
# Assume no padding in incremental decoding steps
e_t = outputs.hidden_states[0][:, -1, :]
last_hidden = outputs.hidden_states[-1][:, -1, :]

Copilot uses AI. Check for mistakes.
Comment on lines +315 to +324
# Identify last token index
# attention_mask (at this point, if past is None, it is just original mask)
last_token_indices = attention_mask.sum(1) - 1
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
e_t = outputs.hidden_states[0][batch_indices, last_token_indices, :]
last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
else:
# Assume no padding in incremental decoding steps
e_t = outputs.hidden_states[0][:, -1, :]
last_hidden = outputs.hidden_states[-1][:, -1, :]
Copy link

Copilot AI Dec 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The indentation inside this if block appears to use 5 spaces instead of the standard 4 spaces used throughout the rest of the codebase. Lines 323-324 should be indented with 12 spaces (8 base + 4 for the else block) rather than 13 spaces.

Suggested change
# Identify last token index
# attention_mask (at this point, if past is None, it is just original mask)
last_token_indices = attention_mask.sum(1) - 1
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
e_t = outputs.hidden_states[0][batch_indices, last_token_indices, :]
last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
else:
# Assume no padding in incremental decoding steps
e_t = outputs.hidden_states[0][:, -1, :]
last_hidden = outputs.hidden_states[-1][:, -1, :]
# Identify last token index
# attention_mask (at this point, if past is None, it is just original mask)
last_token_indices = attention_mask.sum(1) - 1
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
e_t = outputs.hidden_states[0][batch_indices, last_token_indices, :]
last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
else:
# Assume no padding in incremental decoding steps
e_t = outputs.hidden_states[0][:, -1, :]
last_hidden = outputs.hidden_states[-1][:, -1, :]

Copilot uses AI. Check for mistakes.
Comment on lines +398 to +402
last_token_indices = attention_mask.sum(1) - 1
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
else:
last_hidden = outputs.hidden_states[-1][:, -1, :]
Copy link

Copilot AI Dec 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The indentation inside this if block appears to use 5 spaces instead of the standard 4 spaces used throughout the rest of the codebase. Lines 398-400 should be indented with 12 spaces (8 base + 4 for the if block) rather than 13 spaces.

Suggested change
last_token_indices = attention_mask.sum(1) - 1
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
else:
last_hidden = outputs.hidden_states[-1][:, -1, :]
last_token_indices = attention_mask.sum(1) - 1
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
else:
last_hidden = outputs.hidden_states[-1][:, -1, :]

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants