[plugin][sglang] register attn backend to sgl #5
gbyu-amd wants to merge 37 commits into zejun/plugin_for_atom_1223
Conversation
Signed-off-by: zejunchen-zejun <zejun.chen@amd.com>
atom/config.py
Outdated
```python
), f"kv_cache_block_size ({self.kv_cache_block_size}) must be a multiple of 16 or 1"
assert 1 <= self.tensor_parallel_size <= 8
self.hf_config = get_hf_config(self.model)
if is_plugin_mode():
```
Here we plan to follow the ATOM main branch code for loading hf_config from the model.
You can open a dedicated PR to ATOM main to make this code compatible with different transformers versions.
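One way such a compatibility shim could look, as a minimal sketch only: `pick_config_loader`, the `4.45` cutoff, and the returned labels are all hypothetical illustrations, not actual ATOM or transformers code.

```python
# Minimal sketch of version-gated config loading. pick_config_loader, the
# (4, 45) cutoff, and the returned labels are hypothetical examples only.
def pick_config_loader(transformers_version: str) -> str:
    """Choose a config-loading path based on the installed transformers version."""
    major, minor = (int(x) for x in transformers_version.split(".")[:2])
    if (major, minor) >= (4, 45):
        return "auto_config"   # newer API path
    return "legacy_config"     # fallback for older releases
```

A dedicated compatibility PR could then keep all version checks in one place instead of scattering them through config loading.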
atom/models/qwen3_moe.py
Outdated
```python
# Add qk-norm
q = self.q_norm(q)
k = self.k_norm(k)
if is_sglang():
```
It seems the self.forward_sgl_plugin_mode path can be folded into the RadixAttention forward method; otherwise, every additional model we enable will need its own sglang forward path.
Here is the RadixAttention forward method interface; it accepts kwargs, so we can pass the extra arguments through:
```python
def forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    positions: torch.Tensor = None,
    q_scale: Optional[torch.Tensor] = None,
    **kwargs,
):
```
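Following that interface, a rough sketch of what the suggestion could look like. The class body, the stub method names other than forward, and the tuple returns are illustrative only; plain values stand in for tensors so the sketch runs without torch.

```python
# Hypothetical sketch: fold the sglang plugin path into RadixAttention.forward
# so per-model code calls a single interface. Everything except the forward
# signature's shape is made up for illustration.
class RadixAttention:
    def __init__(self, use_sglang: bool):
        self.use_sglang = use_sglang  # stands in for an is_sglang() check

    def forward(self, query, key, value, positions=None, q_scale=None, **kwargs):
        if self.use_sglang:
            # forward_batch etc. arrive through **kwargs, so model code
            # does not need its own sglang branch.
            return self.forward_sgl_plugin_mode(query, key, value, **kwargs)
        return self.forward_native(query, key, value)

    def forward_sgl_plugin_mode(self, q, k, v, **kwargs):
        return ("sglang", kwargs.get("forward_batch"))  # stub plugin path

    def forward_native(self, q, k, v):
        return ("native", None)  # stub default path
```

With this shape, enabling a new model only requires calling forward; the plugin dispatch stays inside RadixAttention.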
atom/model_ops/radix_attention.py
Outdated
```python
key,
value,
forward_batch=forward_batch,
save_kv_cache=not self.use_aiter_rope_fused_qknorm,
```
If save_kv_cache is True, sglang will call the official kernel to save the KV cache, right?
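As I read the intent of that flag, a hedged sketch: when the aiter fused rope+qknorm kernel has already written the KV cache, passing save_kv_cache=False simply skips sglang's own save step. The attend function, both "kernel" calls, and the list-based cache are stand-ins, not real sglang or aiter APIs.

```python
# Illustrative sketch of the save_kv_cache handoff; nothing here is a real API.
def attend(k, v, use_aiter_rope_fused_qknorm: bool, kv_cache: list) -> bool:
    """Return whether sglang's official save path ran."""
    if use_aiter_rope_fused_qknorm:
        kv_cache.append(("fused", k, v))      # fused kernel saved KV already
    save_kv_cache = not use_aiter_rope_fused_qknorm
    if save_kv_cache:
        kv_cache.append(("official", k, v))   # sglang's official save kernel
    return save_kv_cache
```

Either way, exactly one path writes the cache, which is presumably why the flag is derived from use_aiter_rope_fused_qknorm.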
Hi @gbyu-amd, if the cuda graph capture issue cannot be easily solved, don't worry; we can mark it as a known limitation.
Force-pushed from a051118 to 28017a7
The destination branch was force-rebased, so the destination commit hash has changed.
There is still an issue with cuda graph mode. Looking into it.

Launch command

We test with Qwen3-235B-A22B-Instruct-2507-FP8 on MI355.

ATOM + Sglang
Sglang
```shell
model_path=/data/models/Qwen3-235B-A22B-Instruct-2507-FP8
python3 -m sglang.launch_server \
  --model-path $model_path \
  --host localhost \
  --port 8000 \
  --trust-remote-code \
  --tensor-parallel-size 8 \
  --expert-parallel-size 8 \
  --kv-cache-dtype fp8_e4m3 \
  --mem-fraction-static 0.8 \
  --cuda-graph-max-bs 128
```

Accuracy
ATOM + Sglang

local-completions ({'model': '/data/models/Qwen3-235B-A22B-Instruct-2507-FP8', 'base_url': 'http://localhost:8000/v1/completions', 'num_concurrent': 64, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: 3, batch_size: 1

| Tasks | Version | Filter | n-shot | Metric | | Value | | Stderr |
|-------|--------:|------------------|-------:|-------------|---|-------:|---|-------:|
| gsm8k | 3 | flexible-extract | 3 | exact_match | ↑ | 0.8916 | ± | 0.0086 |
| | | strict-match | 3 | exact_match | ↑ | 0.8749 | ± | 0.0091 |

Sglang

local-completions ({'model': '/data/models/Qwen3-235B-A22B-Instruct-2507-FP8', 'base_url': 'http://localhost:8000/v1/completions', 'num_concurrent': 64, 'max_retries': 3, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: 3, batch_size: 1

| Tasks | Version | Filter | n-shot | Metric | | Value | | Stderr |
|-------|--------:|------------------|-------:|-------------|---|-------:|---|-------:|
| gsm8k | 3 | flexible-extract | 3 | exact_match | ↑ | 0.8954 | ± | 0.0084 |
| | | strict-match | 3 | exact_match | ↑ | 0.8779 | ± | 0.0090 |

Performance
ATOM + Sglang (eager mode)
Sglang
Too slow to finish benchmarking: the gsm8k test takes almost 30 minutes, while ATOM + Sglang takes only 3 to 4 minutes.
