
[plugin][sglang] register attn backend to sgl#5

Closed
gbyu-amd wants to merge 37 commits into zejun/plugin_for_atom_1223 from guanbao/sgl_attn_backend

Conversation

@gbyu-amd (Collaborator) commented Feb 11, 2026

There is still an issue with cuda graph mode. Looking into it.

Launch command

We test with Qwen3-235B-A22B-Instruct-2507-FP8 on MI355.

ATOM + Sglang

export AITER_ROPE_FUSED_QKNORM=1

model_path=/data/models/Qwen3-235B-A22B-Instruct-2507-FP8
python3 -m sglang.launch_server \
    --model-path $model_path \
    --host localhost \
    --port 8000 \
    --trust-remote-code \
    --tensor-parallel-size 8 \
    --expert-parallel-size 8 \
    --kv-cache-dtype fp8_e4m3 \
    --mem-fraction-static 0.8 \
    --page-size 1024 \
    --disable-cuda-graph \
    --cuda-graph-max-bs 128 \
    --model-impl atom

Sglang

model_path=/data/models/Qwen3-235B-A22B-Instruct-2507-FP8
python3 -m sglang.launch_server \
    --model-path $model_path \
    --host localhost \
    --port 8000 \
    --trust-remote-code \
    --tensor-parallel-size 8 \
    --expert-parallel-size 8 \
    --kv-cache-dtype fp8_e4m3 \
    --mem-fraction-static 0.8 \
    --cuda-graph-max-bs 128
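Once either server above is up, a quick smoke test against its OpenAI-compatible completions endpoint could look like the following (the prompt and max_tokens are illustrative, not from the PR):

```shell
# Illustrative request against the server launched above; the model field
# mirrors the served model path and the prompt is arbitrary.
curl -s http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "/data/models/Qwen3-235B-A22B-Instruct-2507-FP8",
        "prompt": "The capital of France is",
        "max_tokens": 16
      }'
```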

Accuracy

ATOM + Sglang

local-completions ({'model': '/data/models/Qwen3-235B-A22B-Instruct-2507-FP8', 'base_url': 'http://localhost:8000/v1/completions', 'num_concurrent': 64, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: 3, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     3|exact_match||0.8916|±  |0.0086|
|     |       |strict-match    |     3|exact_match||0.8749|±  |0.0091|

Sglang

local-completions ({'model': '/data/models/Qwen3-235B-A22B-Instruct-2507-FP8', 'base_url': 'http://localhost:8000/v1/completions', 'num_concurrent': 64, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: 3, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     3|exact_match||0.8954|±  |0.0084|
|     |       |strict-match    |     3|exact_match||0.8779|±  |0.0090|
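The dumps above are lm-evaluation-harness output. A command that produces a report in this shape, reconstructed from the printed config (a sketch, not copied from the PR), would look roughly like:

```shell
# Sketch reconstructed from the local-completions config dump above;
# verify flag names against your installed lm-eval version.
lm_eval --model local-completions \
  --model_args model=/data/models/Qwen3-235B-A22B-Instruct-2507-FP8,base_url=http://localhost:8000/v1/completions,num_concurrent=64,max_retries=3,tokenized_requests=False \
  --tasks gsm8k \
  --num_fewshot 3 \
  --batch_size 1
```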

Performance

ATOM + Sglang (eager mode)

| Input_Tokens | Output_Tokens | Max_Concurrency | Num_Prompts | Mean_TTFT_ms | Mean_TPOT_ms | OutToken_Throughput | TotToken_Throughput |
|---:|---:|---:|---:|---:|---:|---:|---:|
| 1024 | 1024 | 4 | 40 | 232.22 | 49.18 | 81.04 | 162.08 |
| 1024 | 1024 | 8 | 80 | 322.96 | 49.53 | 160.64 | 321.29 |
| 1024 | 1024 | 16 | 160 | 460.02 | 54.01 | 294.05 | 588.09 |
| 1024 | 1024 | 32 | 320 | 756.8 | 55.15 | 573.03 | 1146.06 |
| 1024 | 1024 | 64 | 640 | 1209.78 | 42.88 | 1453.59 | 2907.17 |
| 1024 | 1024 | 128 | 1280 | 1998.49 | 44.43 | 2760.51 | 5521.02 |
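A sweep like the table above can be produced with SGLang's serving benchmark. A hedged sketch for the 64-concurrency row follows; the flag names match recent sglang.bench_serving versions but are worth verifying locally:

```shell
# Sketch of one row of the concurrency sweep (64 concurrent requests,
# 1024-token random inputs and outputs); flags assumed, verify locally.
python3 -m sglang.bench_serving \
  --backend sglang \
  --host localhost --port 8000 \
  --dataset-name random \
  --random-input-len 1024 \
  --random-output-len 1024 \
  --max-concurrency 64 \
  --num-prompts 640
```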

Sglang

Too slow to finish benchmarking. The gsm8k test takes almost 30 minutes, while ATOM + Sglang takes only 3 to 4 minutes.

zejunchen-zejun and others added 30 commits February 9, 2026 21:57
framework

Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
zejunchen-zejun and others added 3 commits February 10, 2026 18:20
Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
atom/config.py Outdated
), f"kv_cache_block_size ({self.kv_cache_block_size}) must be a multiple of 16 or 1"
assert 1 <= self.tensor_parallel_size <= 8
self.hf_config = get_hf_config(self.model)
if is_plugin_mode():
Owner:

Here we plan to follow the ATOM main branch code for loading hf_config from the model.
You can open a dedicated PR against ATOM main to make this code compatible with different transformers versions.

# Add qk-norm
q = self.q_norm(q)
k = self.k_norm(k)
if is_sglang():
@zejunchen-zejun (Owner) commented Feb 12, 2026

It seems the self.forward_sgl_plugin_mode path could be folded into the RadixAttention forward method; otherwise, every model we enable will need to add its own sglang forward path.

Here is the RadixAttention forward method interface. It accepts kwargs, so we can pass arguments through it:

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        positions: torch.Tensor = None,
        q_scale: Optional[torch.Tensor] = None,
        **kwargs,
    ):
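The dispatch suggested above can be sketched as follows. This is a hypothetical illustration, not the actual ATOM/SGLang code: the method names forward_sgl_plugin_mode and forward_native, the plugin_mode flag, and the placeholder return values are all assumptions made for the sketch.

```python
# Hypothetical sketch of routing the plugin path inside a
# RadixAttention-style forward via **kwargs, so that individual model
# attention modules do not each need an sglang-specific branch.
from typing import Optional


class RadixAttentionLike:
    def __init__(self, plugin_mode: bool = False):
        # Assumption: a flag set once at construction decides the path.
        self.plugin_mode = plugin_mode

    def forward(self, query, key, value, positions=None,
                q_scale: Optional[float] = None, **kwargs):
        # Dispatch once here instead of in every model's attention layer;
        # plugin-only arguments (forward_batch, save_kv_cache, ...) travel
        # through **kwargs untouched.
        if self.plugin_mode:
            return self.forward_sgl_plugin_mode(
                query, key, value, positions=positions, **kwargs
            )
        return self.forward_native(query, key, value)

    def forward_sgl_plugin_mode(self, query, key, value, positions=None,
                                forward_batch=None, save_kv_cache=True,
                                **kwargs):
        # Plugin kernels would run here; save_kv_cache=False would skip the
        # official KV-save when a fused kernel already wrote the cache.
        return ("plugin", save_kv_cache)

    def forward_native(self, query, key, value):
        # Default (non-plugin) attention path placeholder.
        return ("native", True)
```

With this shape, a model's attention layer always calls forward(...) and never checks is_sglang() itself.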

key,
value,
forward_batch=forward_batch,
save_kv_cache=not self.use_aiter_rope_fused_qknorm,
Owner:

If save_kv_cache is True, sglang will call the official kernel to save the KV cache, right?

@zejunchen-zejun (Owner) commented Feb 12, 2026

Hi @gbyu-amd,
Yajie said they have upstreamed some optimizations for Qwen-series models into SGLang, so the SGLang performance baseline may not be that bad. Please double-check with Yajie.

If the cuda graph capture issue cannot be easily solved, don't worry; we can mark it as a known limitation.

Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
@zejunchen-zejun force-pushed the zejun/plugin_for_atom_1223 branch from a051118 to 28017a7 on February 27, 2026 01:52
@zejunchen-zejun (Owner):

The destination branch was force-rebased, so the destination commit hash has changed.
We need to close this PR and open a new one to resolve the resulting conflicts.
new one: #6
