[Bug] Fix TRTLLM Block FP8 MoE Monolithic #36296
robertgshaw2-redhat merged 3 commits into vllm-project:main
Conversation
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Code Review
This PR fixes a dtype mismatch bug for e_score_correction_bias in the TRT-LLM FP8 MoE monolithic path. The change in _apply_per_block is correct. I've identified a similar issue in _apply_per_tensor that is not covered by this PR and have left a comment with a suggested fix. Addressing this will ensure consistency and prevent potential runtime errors.
Can you try
@ProExpertProg I am experiencing an OOM issue on top of tree. I can roll back to an older commit and test, but it would be best if we could test on top-of-tree.
Tested your commit on top of my branch in #36551, that's good enough for me:
We need to backport this to 0.17.1.
LGTM. Can we add a note in the PR or a comment in the code saying what determines the e_score_correction_bias dtype? It's typically based on the checkpoint and the kernel, so hardcoding it to the dtype of hidden_states or router_logits could work for some models while not working for others.
For example, DeepSeek uses:
model.layers.3.mlp.gate.e_score_correction_bias | [256] | F32
Talked to Pavani offline. I think we could merge this patch for now, but it may be worth checking whether we should instead make a change on the kernel side to support the built-in dtype.
This PR on B200:
do we know which PR introduced this issue? the monolithic refactor did not change the behavior here AFAICT
ah I see, it's because now the hidden states are quantized. I don't understand why this is not caught by the tests; we run this for Qwen3 FP8 Block on B200
oh, it's because it does not have e_score_correction_bias
```python
router_logits = router_logits.to(torch.float32)

if e_score_correction_bias is not None:
    e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)
```
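As a standalone illustration of the guard in the diff above, here is a minimal sketch of the dtype-normalization logic. The `FakeTensor` stand-in and `normalize_routing_bias` name are hypothetical, used only so the snippet runs without torch; in vLLM the objects are real `torch.Tensor`s and the cast is `tensor.to(router_logits.dtype)`.

```python
# Sketch of the routing-bias dtype fix (assumed names, not vLLM's API).
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Tiny stand-in for a torch.Tensor that only tracks dtype."""

    dtype: str  # e.g. "float32", "bfloat16"

    def to(self, dtype: str) -> "FakeTensor":
        # Casting returns a tensor with the requested dtype.
        return FakeTensor(dtype=dtype)


def normalize_routing_bias(router_logits, e_score_correction_bias):
    """Cast the correction bias (if present) to the router logits dtype,
    mirroring the guard in the PR diff, so the two never mismatch."""
    router_logits = router_logits.to("float32")
    if e_score_correction_bias is not None:
        e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)
    return router_logits, e_score_correction_bias


logits, bias = normalize_routing_bias(FakeTensor("bfloat16"), FakeTensor("float32"))
print(logits.dtype, bias.dtype)  # both "float32"
```

This is the "hack" variant discussed below: it always aligns the bias dtype with the logits rather than matching what the kernel actually requires, which is why the thread later removes the cast once the kernel accepts fp32 directly.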
this works, but is a hack. Do we know what dtype is required by the kernel here? Is it just bf16?
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Head branch was pushed to by a user without write access
The reason for this regression, per @robertgshaw2-redhat's investigation, is that the MoE refactor changed where the input quantization happens. Based on discussion, we remove the casting for the routing bias.

We discovered this cast is not needed because the trtllm kernel now supports fp32 inputs. We originally added this cast in ef28354, but it is not needed anymore.
LGTM, thanks for the fix

Verified same accuracy with FP32 and BF16.
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
(cherry picked from commit 84e436e)
Purpose
Fix #36295
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
Update supported_models.md and examples for a new model.