Conversation

@uygnef (Collaborator) commented Jan 27, 2026

Move hidden_state sharding into dataset with mmap loading to cut memory and I/O


@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @uygnef, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant optimization to the offline dataloader, particularly for training with sequence parallelism. By integrating memory-mapped loading and moving the hidden state sharding logic into the dataset itself, the system can now handle larger sequence lengths more efficiently, drastically reducing memory footprint and improving data loading performance during training. These changes ensure that each processing unit only handles the necessary data segments, streamlining the overall training pipeline.

Highlights

  • Memory-Mapped Data Loading: The OfflineEagle3Dataset now utilizes memory-mapped loading (mmap=True) for hidden states, significantly reducing initial RAM consumption and I/O overhead by only loading data into memory when accessed.
  • Integrated Sequence Parallelism (SP) Sharding: Hidden state sharding for Sequence Parallelism (SP) is now performed directly within the OfflineEagle3Dataset during data loading. This ensures each GPU only loads its specific chunk of hidden states, optimizing VRAM and RAM usage.
  • Dynamic Position ID Generation for USP: The eagle3.py core logic has been updated to dynamically generate position IDs for the USP (Unified Sequence Parallelism) attention backend, ensuring correct positional embeddings for sharded sequences.
  • Data Collator Adaptation: The DataCollatorForEagle3 has been modified to handle pre-sharded hidden states when SP is active, concatenating them directly without additional padding, as sharding is now managed by the dataset.
  • USP Configuration Validation: New validation checks have been added to train_eagle3.py to enforce batch_size=1 and sp_ring_size * sp_ulysses_size > 1 when using the USP attention backend, preventing misconfigurations.



@gemini-code-assist bot left a comment


Code Review

This pull request optimizes the offline data loader for sequence parallelism (SP) by moving hidden-state sharding into the dataset and loading hidden states via memory-mapped files, which reduces memory usage and I/O overhead. The changes include argument validation, modifications to the Eagle3 model's forward pass, and significant updates to the data preprocessing pipeline to support sharded hidden states. The suggestions below address a batch-size constraint and opportunities to improve code clarity.

Comment on lines +535 to +536
if "aux_hidden_state" not in data or data["aux_hidden_state"] is None:
    raise KeyError("aux_hidden_state is required for OfflineEagle3Dataset")

Severity: high

Raising KeyError when aux_hidden_state is missing is appropriate, since the dataset cannot be processed without this key. Consider performing this validation earlier in the pipeline, and include more context (such as the offending file or sample index) in the error message to make the failure easier to diagnose.

Comment on lines +346 to +349
if args.attention_backend == "usp" and args.batch_size != 1:
    raise ValueError(
        f"USP only supports batch_size=1, got batch_size={args.batch_size}"
    )

Severity: medium

This check correctly ensures the batch size is compatible with USP, but it might be better to fall back to a non-USP backend (with a warning) when batch_size != 1, rather than raising an error. That would give the user a more graceful way to handle the misconfiguration.
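A graceful-fallback variant of this check could look like the following sketch. The function name and the "default" backend identifier are hypothetical, not part of the PR:

```python
import warnings

def resolve_attention_backend(attention_backend: str, batch_size: int) -> str:
    # Hypothetical alternative to the hard ValueError: warn and fall back
    # to an assumed "default" backend when batch_size is incompatible with USP.
    if attention_backend == "usp" and batch_size != 1:
        warnings.warn(
            f"USP requires batch_size=1 but got {batch_size}; "
            "falling back to the default attention backend."
        )
        return "default"
    return attention_backend
```

The trade-off is silent behavior changes versus fail-fast configuration errors; a warning keeps the run alive while still surfacing the mismatch in the logs.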

Comment on lines 224 to +228
for idx in range(self.length):
    target_p = target_p_padded[:, idx : idx + seq_length, :]
    if self.attention_backend == "usp":
        target_slice_len = global_seq_length
    else:
        target_slice_len = seq_length

Severity: medium

It seems that target_slice_len is only used for slicing target_p, and target_p is not used when self.attention_backend == "usp". If that is the case, the distinction between the two branches is redundant; consider verifying whether the else branch is still needed and simplifying the code if it is not.

Comment on lines +120 to +124
if self.sp_degree > 1:
    batch["hidden_state"] = torch.cat(
        [item["hidden_state"] for item in features]
    )
    batch["target"] = torch.cat([item["target"] for item in features])

Severity: medium

This condition checks self.sp_degree > 1 to decide whether to concatenate hidden states and targets directly or to pad them first. When sequence parallelism is enabled, the dataset has already sharded every item to the same length, so direct concatenation is correct and no padding is needed on this path. Consider documenting that invariant (or asserting equal lengths) so the distinction between the two paths stays clear.
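The invariant described here — SP shards arrive pre-sized so plain concatenation suffices, while non-SP items may still need padding — can be sketched as follows. The function name and dict keys are illustrative, not the PR's actual DataCollatorForEagle3:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_hidden_states(features: list, sp_degree: int) -> torch.Tensor:
    states = [item["hidden_state"] for item in features]
    if sp_degree > 1:
        # Pre-sharded by the dataset: every item has the same local
        # sequence length, so concatenation along dim 0 is enough.
        return torch.cat(states)
    # No SP: items may vary in length, so pad to the longest first.
    return pad_sequence(states, batch_first=True)
```

Skipping the padding pass on the SP path avoids both an extra allocation and a copy per batch.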

Comment on lines 1051 to +1060
if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding):
-    position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[
-        self.ring_rank
-    ].clone()
+    if position_ids.shape[2] != q_len:
+        position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[
+            self.ring_rank
+        ].clone()
else:
-    position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[
-        self.ring_rank
-    ].clone()
+    if position_ids.shape[1] != q_len:
+        position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[
+            self.ring_rank
+        ].clone()

Severity: medium

The chunking is now guarded by checking whether position_ids.shape[2] (or shape[1]) differs from q_len, so the shard is skipped when position_ids already matches the local query length. It is not obvious from the code when position_ids arrives pre-sharded versus global; consider adding a comment explaining that distinction to confirm the guard is the intended behavior.
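A commented, standalone version of that guard might read as follows. shard_position_ids is a hypothetical helper; it assumes position_ids holds global ids whenever its size along dim exceeds the local q_len:

```python
import torch

def shard_position_ids(position_ids: torch.Tensor, q_len: int,
                       ring_rank: int, ring_degree: int, dim: int) -> torch.Tensor:
    # Chunk only when position_ids still spans the full global sequence.
    # If its size along `dim` already equals the local q_len, it was
    # sharded upstream, and chunking again would double-shard it.
    if position_ids.shape[dim] != q_len:
        position_ids = position_ids.chunk(ring_degree, dim=dim)[ring_rank].clone()
    return position_ids
```

Making the guard idempotent this way lets callers pass either global or already-local position ids without tracking which one they hold.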
