[feat] Support extra_buffer in Mamba2-based models
#15829
base: main
Conversation
Code Review
This pull request introduces support for extra_buffer in Mamba2-based models, enabling the use of both radix cache and the overlap scheduler. The changes are extensive, touching model implementations, attention backends, scheduling logic, and adding new tests. The core idea is to generalize the handling of chunk sizes and enable tracking of intermediate states for Mamba2 models, similar to what was done for FLA-based models. The code is well-structured, and the addition of tests for the new functionality is commendable. I have a couple of suggestions for improving code clarity and maintainability.
if (
    intermediate_states is not None
    and forward_batch.mamba_track_mask is not None
    and forward_batch.mamba_track_mask.any()
):
if return_intermediate_states:
    if return_varlen_states:
        varlen_states = rest[0]
        if return_final_states:
            return states, final_states, varlen_states
        else:
            return states, varlen_states
    else:
        if return_final_states:
            return states, final_states
        else:
            return states
The nested if statements for handling return values based on different flags can be a bit hard to follow. Refactoring this block to a flatter structure would improve readability and maintainability.
if return_intermediate_states:
    if not return_final_states and not return_varlen_states:
        return states
    if not return_final_states and return_varlen_states:
        return states, rest[0]
    if return_final_states and not return_varlen_states:
        return states, final_states
    # return_final_states and return_varlen_states
    return states, final_states, rest[0]
lens_to_track = (
    forward_batch.mamba_track_seqlens - forward_batch.extend_prefix_lens
)
mamba_cache_chunk_size = get_global_server_args().mamba_cache_chunk_size
There are two variables:
- FLA_CHUNK_SIZE
- mamba_cache_chunk_size
IIUC, there's no reason to touch anything related to mamba_cache_chunk_size. You should only replace the usages of FLA_CHUNK_SIZE with the chunk size of Mamba2's backend (maybe under a clearer name like <backend>_CHUNK_SIZE or MAMBA2_CHUNK_SIZE) for Mamba2 models.
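For illustration, here is a minimal sketch of what that suggestion could look like; the constant values and the helper are assumptions, not the actual sglang code:

```python
# Hypothetical sketch: resolve the chunk size from the hybrid backend
# instead of hard-coding FLA_CHUNK_SIZE everywhere.
FLA_CHUNK_SIZE = 64       # chunk size assumed for the FLA (GDN) kernels
MAMBA2_CHUNK_SIZE = 256   # chunk size assumed for the Mamba2 chunk-scan kernels

def state_tracking_chunk_size(backend: str) -> int:
    """Return the chunk size used for intermediate-state tracking."""
    return MAMBA2_CHUNK_SIZE if backend == "mamba2" else FLA_CHUNK_SIZE
```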
    mamba_track_indices_cpu: List[int],
    mamba_track_seqlens_cpu: List[int],
):
    mamba_track_interval = get_global_server_args().mamba_track_interval
IIUC, you don't need to touch any logic related to mamba_track_interval?
See https://github.com/sgl-project/sglang/pull/15829/files#r2691660148
I suspect this only avoids crashing because mamba_track_interval happens to be 256?
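A small illustration of that concern (the chunk size value and helper are assumptions): tracked positions only land on chunk-scan boundaries when the track interval is a multiple of the kernel's chunk size, which happens to hold when both are 256.

```python
MAMBA2_CHUNK_SIZE = 256  # assumed Mamba2 kernel chunk size

def track_positions(seq_len: int, track_interval: int) -> list[int]:
    """Token positions at which intermediate states would be tracked."""
    return list(range(track_interval, seq_len + 1, track_interval))

# interval == 256: every tracked position is a chunk boundary
print(all(p % MAMBA2_CHUNK_SIZE == 0 for p in track_positions(2048, 256)))  # True
# interval == 192: some tracked positions fall inside chunks
print(all(p % MAMBA2_CHUNK_SIZE == 0 for p in track_positions(2048, 192)))  # False
```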
@hanming-lu The current implementation for
I'm still seeing some intermittent errors in some of the tests, with the scheduler raising an error that a memory leak is detected. Not sure if it's just my setup or if it will happen in the CI as well...
Motivation
Recent updates to Qwen3-Next models enabled running them with both the radix cache and the overlap scheduler. This PR does the same for Mamba2-based models.
Modifications
- Optionally return intermediate states in the mamba_chunk_scan_combined prefill kernel (a minimal illustrative sketch follows this list).
- Fix the writing locations of intermediate states in the selective_state_update decode kernel.
- Update MambaMixer2 to return prefill intermediate states, provide the SpecDec kernels with missing (and newly introduced) intermediate-state writing locations, and update the conv state for prefix caching in the extra_buffer code path.
- Add missing tracking tensors to Mamba2Metadata.
- Update MambaAttnBackendBase to work with non-FLA_CHUNK_SIZE chunk sizes too.
- Update Mamba2AttnBackend for extra_buffer tracking.
- Update ScheduleBatch and MambaRadixCache to work with non-FLA_CHUNK_SIZE chunk sizes too.
- Update NemotronH and FalconH1 models to pass the newly required forward batch to the attention backend.
- Update ServerArgs initialization to allow running Mamba2-based models with both radix cache and overlap scheduler enabled.
- Added tests with --mamba-scheduler-strategy extra_buffer for nvidia/NVIDIA-Nemotron-Nano-v2-9B, also in P/D disaggregation tests.
- Refactored and cleaned up Qwen3-Next tests for code re-use.
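For context, a minimal self-contained sketch of the extra_buffer idea for Mamba2 (shapes, names, and the buffer layout are illustrative assumptions, not the PR's actual code): the prefill kernel now optionally returns per-chunk intermediate SSM states, and the backend copies the states at tracked chunk boundaries into a side buffer so radix-cache hits can resume from them.

```python
import torch

num_chunks, nheads, headdim, dstate = 8, 4, 64, 128

# Per-chunk SSM states, as the prefill kernel would return them with
# return_intermediate_states=True (shape is an assumption).
intermediate_states = torch.randn(num_chunks, nheads, headdim, dstate)

# Which chunk boundaries should be cached for later prefix reuse.
track_mask = torch.zeros(num_chunks, dtype=torch.bool)
track_mask[[3, 7]] = True

# Pre-allocated side buffer ("extra buffer") with one slot per boundary.
extra_buffer = torch.empty(num_chunks, nheads, headdim, dstate)

if track_mask.any():
    idx = track_mask.nonzero(as_tuple=True)[0]
    extra_buffer[idx] = intermediate_states[idx]  # store tracked states for reuse
```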
Accuracy Tests
All existing tests pass without accuracy degradation.
Newly added tests pass with the same accuracy.
Benchmarking and Profiling
Checklist