
[Bug] Fix TRTLLM Block FP8 MoE Monolithic #36296

Merged
robertgshaw2-redhat merged 3 commits into vllm-project:main from wzhao18:wzhao/fix-fs-r1-trtllm-moe
Mar 11, 2026

Conversation

@wzhao18 (Contributor) commented Mar 7, 2026

Purpose

Fix #36295

Test Plan

```shell
vllm serve deepseek-ai/DeepSeek-R1 -tp 8
lm_eval --model local-completions --model_args "base_url=http://0.0.0.0:8000/v1/completions,max_length=8192,tokenized_requests=False,tokenizer_backend=None,num_concurrent=32" --tasks gsm8k --num_fewshot 5
```

Test Result

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.953|±  |0.0058|
|     |       |strict-match    |     5|exact_match|↑  |0.953|±  |0.0058|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
@mergify mergify bot added nvidia bug Something isn't working labels Mar 7, 2026
@gemini-code-assist (bot) left a comment

Code Review

This PR fixes a dtype mismatch bug for e_score_correction_bias in the TRT-LLM FP8 MoE monolithic path. The change in _apply_per_block is correct. I've identified a similar issue in _apply_per_tensor that is not covered by this PR and have left a comment with a suggested fix. Addressing this will ensure consistency and prevent potential runtime errors.
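The fix the bot describes can be sketched with stand-in types (a minimal illustration under assumed names — `FakeTensor` and the function body are not vLLM code; only `router_logits` and `e_score_correction_bias` come from the PR discussion):

```python
# Illustrative sketch of the dtype-alignment fix: cast the routing bias to
# match router_logits (fp32), not hidden_states (which may already be fp8).
from dataclasses import dataclass

@dataclass
class FakeTensor:
    """Toy stand-in for a tensor: tracks only a dtype string."""
    dtype: str

    def to(self, dtype: str) -> "FakeTensor":
        return FakeTensor(dtype)

def route(hidden_states, router_logits, e_score_correction_bias):
    # Routing math runs in fp32 regardless of the activation dtype.
    router_logits = router_logits.to("float32")
    if e_score_correction_bias is not None:
        # The buggy path aligned the bias with hidden_states.dtype, which
        # after the MoE refactor is already fp8; aligning it with the fp32
        # router logits keeps the bias exact.
        e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)
    return router_logits, e_score_correction_bias

x = FakeTensor("float8_e4m3fn")  # quantized activations entering the MoE
logits = FakeTensor("bfloat16")
bias = FakeTensor("float32")
_, fixed_bias = route(x, logits, bias)
print(fixed_bias.dtype)          # float32, not float8
```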

@ProExpertProg (Collaborator)

Can you try `vllm serve deepseek-ai/DeepSeek-V3 -tp=8` as well?

@ProExpertProg ProExpertProg added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 10, 2026
@wzhao18 (Contributor, Author) commented Mar 10, 2026

@ProExpertProg I am hitting an OOM issue on top-of-tree. I can roll back to an older commit to test, but it would be best if we could verify on top-of-tree.

@ProExpertProg (Collaborator)

Tested your commit on top of my branch in #36551, that's good enough for me:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|   | 0.97|±  |0.0171|
|     |       |strict-match    |     5|exact_match|   | 0.97|±  |0.0171|

@robertgshaw2-redhat (Collaborator)

we need to backport this to 0.17.1

@pavanimajety (Collaborator) left a comment

LGTM. Can we add a note in the PR, or a comment in the code, explaining what determines the e_score_correction_bias dtype? It is typically determined by the checkpoint and the kernel, so hardcoding it to match hidden_states or router_logits could work for some models but not for others.
For example, DeepSeek stores it as:

model.layers.3.mlp.gate.e_score_correction_bias | [256] | F32
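The point above can be sketched as a lookup that trusts the checkpoint metadata rather than an activation dtype (illustrative only; the metadata dict and the `bias_dtype` helper are assumptions, not vLLM code):

```python
# Hypothetical sketch: decide the routing-bias dtype from checkpoint
# metadata instead of hardcoding it to hidden_states/router_logits dtype.
checkpoint_dtypes = {
    # e.g. DeepSeek stores the correction bias in fp32 ("F32"), per the
    # parameter listing quoted in the review comment:
    "model.layers.3.mlp.gate.e_score_correction_bias": "F32",
}

def bias_dtype(param_name: str, default: str = "BF16") -> str:
    """Return the dtype recorded in the checkpoint, falling back to a default."""
    return checkpoint_dtypes.get(param_name, default)

print(bias_dtype("model.layers.3.mlp.gate.e_score_correction_bias"))  # F32
```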

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 10, 2026
@wzhao18 (Contributor, Author) commented Mar 10, 2026

Talked to Pavani offline. I think we can merge this patch for now, but it may be worth checking whether we should instead change the kernel side to support the dtype of e_score_correction_bias as stored in the model checkpoint.

@robertgshaw2-redhat robertgshaw2-redhat changed the title [Bug] Fix TRTLLM FP8 MoE Monolithic [Bug] Fix TRTLLM Block FP8 MoE Monolithic Mar 10, 2026
@ProExpertProg ProExpertProg enabled auto-merge (squash) March 10, 2026 17:58
@ProExpertProg (Collaborator) commented Mar 10, 2026

This PR on B200:

local-completions (pretrained=deepseek-ai/DeepSeek-V3,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=50,max_retries=3), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|   | 0.95|±  | 0.006|
|     |       |strict-match    |     5|exact_match|   | 0.95|±  | 0.006|

@robertgshaw2-redhat (Collaborator)

do we know which PR introduced this issue?

the monolithic refactor did not change the behavior here AFAICT

@robertgshaw2-redhat (Collaborator) commented Mar 10, 2026

> do we know which PR introduced this issue?
> the monolithic refactor did not change the behavior here AFAICT

ah I see, it's because the hidden states are now quantized. I don't understand why this was not caught by the tests; we run Qwen3 FP8 Block on B200.

@robertgshaw2-redhat (Collaborator)

oh, it's because that model does not have e_score_correction_bias

```python
router_logits = router_logits.to(torch.float32)

if e_score_correction_bias is not None:
    e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)
```
A Collaborator left a review comment on this snippet:

this works, but it is a hack. Do we know what dtype the kernel requires here? Is it just bf16?

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
auto-merge was automatically disabled March 10, 2026 20:58

Head branch was pushed to by a user without write access

@wzhao18 wzhao18 requested review from bnellnm and pavanimajety March 10, 2026 20:58
@wzhao18 (Contributor, Author) commented Mar 10, 2026

The reason for this regression, per @robertgshaw2-redhat's investigation, is that the MoE refactor changed where input quantization happens:

```python
# prior: the bias is cast before the input is quantized,
# so x.dtype is still the high-precision activation dtype
bias = bias.to(x.dtype)
x_q, x_scale = quantize(x)
kernel(...)

# after the refactor: x arrives at this function already quantized,
# so the same cast now demotes the bias to fp8
bias = bias.to(x.dtype)  # oops, cast to fp8
kernel(...)
```

Based on the discussion, we removed the cast of the routing bias.
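The precision loss described above can be reproduced with a toy low-precision cast (a pure-Python sketch; `to_low_precision` and the bias value are made up, standing in for the real fp8 cast):

```python
# Illustrative sketch (not vLLM code): why the order of the cast matters.
# We simulate an fp8-like dtype with a coarse rounding function.
from math import frexp, ldexp

def to_low_precision(v: float) -> float:
    """Crude stand-in for an fp8 cast: keep roughly 2 bits of mantissa."""
    m, e = frexp(v)                    # v = m * 2**e, with 0.5 <= |m| < 1
    return ldexp(round(m * 4) / 4, e)  # round the mantissa to quarters

bias = 0.3271484375  # an fp32 routing-bias value (made up for the demo)

# Prior layout: the bias is cast to the dtype of the *unquantized* input
# (high precision), so it survives intact.
bias_before_refactor = bias

# After the refactor: x is already quantized, so the same line of code
# casts the bias to the fp8-like dtype and destroys its precision.
bias_after_refactor = to_low_precision(bias)

print(bias_before_refactor)  # 0.3271484375
print(bias_after_refactor)   # 0.375
```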

@robertgshaw2-redhat (Collaborator)

we discovered this cast is no longer needed because the TRT-LLM kernel now supports fp32 inputs. We originally added the cast in ef28354, but it is not needed anymore.

@ProExpertProg (Collaborator)

local-completions (pretrained=deepseek-ai/DeepSeek-V3,base_url=http://0.0.0.0:8000/v1/completions,num_concurrent=50,max_retries=3), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   | Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|   |0.9462|±  |0.0062|
|     |       |strict-match    |     5|exact_match|   |0.9462|±  |0.0062|

@robertgshaw2-redhat (Collaborator)

LGTM, thanks for the fix

@robertgshaw2-redhat (Collaborator)

verified same accuracy with FP32 and BF16

@robertgshaw2-redhat robertgshaw2-redhat merged commit 84e436e into vllm-project:main Mar 11, 2026
54 of 60 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Mar 11, 2026
khluu pushed a commit that referenced this pull request Mar 11, 2026
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
(cherry picked from commit 84e436e)

Labels

bug Something isn't working nvidia ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[Bug]: Deepseek-v3 fails on 8xB200 in v0.17.0 (including eager)
[Bug]: Deepseek R1 TRTLLM FP8 MoE produces garbage output

5 participants