Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def get_quant_config(config: PretrainedConfig) -> QuantizationConfig:


def get_hf_config(model: str) -> PretrainedConfig:
print("model", model, flush=True)
config_dict, _ = PretrainedConfig.get_config_dict(
model,
)
Expand Down Expand Up @@ -610,10 +611,12 @@ def __post_init__(self):
), f"kv_cache_block_size ({self.kv_cache_block_size}) must be a multiple of 16 or 1"
assert 1 <= self.tensor_parallel_size <= 8
self.hf_config = get_hf_config(self.model)
print("hf_config", self.hf_config, flush=True)
if not hasattr(self.hf_config, "rope_parameters"):
# Compatible with both transformers < 5
rope_params = getattr(self.hf_config, "rope_scaling", {})
rope_params["rope_theta"] = self.hf_config.rope_theta
if rope_params is not None:
rope_params["rope_theta"] = self.hf_config.rope_theta
self.hf_config.rope_parameters = rope_params

self.generation_config = get_generation_config(self.model)
Expand Down
6 changes: 6 additions & 0 deletions atom/model_ops/radix_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def __init__(
per_layer_sliding_window: Optional[int] = None,
rotary_emb: Optional[torch.nn.Module] = None,
prefix: Optional[str] = None,
q_norm: Optional[torch.nn.Module] = None,
k_norm: Optional[torch.nn.Module] = None,
**kwargs,
):
super().__init__(
Expand All @@ -49,6 +51,10 @@ def __init__(
prefix=prefix,
**kwargs,
)
self.q_norm = q_norm
self.k_norm = k_norm
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads

if is_sglang():
self.rotary_emb = rotary_emb
Expand Down
36 changes: 24 additions & 12 deletions atom/models/qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
from aiter.dist.communication_op import tensor_model_parallel_all_reduce
from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size
from aiter.dist.parallel_state import get_pp_group, get_tensor_model_parallel_world_size, is_global_first_rank

# from atom.model_ops.rotary_embedding import get_rope
from aiter.rotary_embedding import get_rope
Expand Down Expand Up @@ -196,14 +196,14 @@ def __init__(
base=rope_theta,
rope_scaling=rope_scaling,
)
if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION:
cos, sin = self.rotary_emb.cos_cache, self.rotary_emb.sin_cache
joint_cache = torch.cat((cos, sin), dim=-1)
self.rotary_emb.register_buffer(
"cos_sin_cache",
joint_cache.view(joint_cache.size(0), self.head_dim),
persistent=False,
)
# if ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION:
# cos, sin = self.rotary_emb.cos_cache, self.rotary_emb.sin_cache
# joint_cache = torch.cat((cos, sin), dim=-1)
# self.rotary_emb.register_buffer(
# "cos_sin_cache",
# joint_cache.view(joint_cache.size(0), self.head_dim),
# persistent=False,
# )

self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
Expand Down Expand Up @@ -263,8 +263,12 @@ def __init__(self, atom_config=None, layer_num: int = 0, prefix: str = "") -> No
config = self.atom_config.hf_config
self.hidden_size = config.hidden_size
rope_params = config.rope_parameters
rope_theta = rope_params["rope_theta"]
rope_scaling = rope_params
if rope_params is None:
rope_theta = 10000
rope_scaling = None
else:
rope_theta = rope_params["rope_theta"]
rope_scaling = rope_params
kv_cache_dtype = atom_config.kv_cache_dtype
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
# DecoderLayers are created with `make_layers` which passes the prefix
Expand Down Expand Up @@ -328,11 +332,14 @@ def forward(
**model_kwargs: dict[str, Any] | None,
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
# if is_global_first_rank():
# print(f"Layer {self.layer_idx} input hidden_states, before input_layer_norm: {hidden_states.norm().item():.4f}", flush=True)
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)

hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
Expand All @@ -345,7 +352,7 @@ def forward(
return hidden_states, residual


@support_torch_compile
# @support_torch_compile
class Qwen3MoeModel(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -410,7 +417,12 @@ def forward(
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]

layer_idx = 0
for layer in self.layers[self.start_layer : self.end_layer]:
if is_global_first_rank():
print("=" * 20 + f" Layer {layer_idx} " + "=" * 20, flush=True)
print(f"Layer {layer_idx} input hidden_states: {hidden_states.norm().item():.4f}", flush=True)
layer_idx += 1
hidden_states, residual = layer(
positions, hidden_states, residual, **model_kwargs
)
Expand Down
4 changes: 3 additions & 1 deletion atom/plugin/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def _generate_atom_config_from_sglang_config(config: Any):
server_args: ServerArgs = prepare_server_args(sys.argv[1:])

sgl_model_config = SglangModelConfig.from_server_args(server_args)
print(f"sgl_model_config: {sgl_model_config.model_path}", flush=True)
sgl_model_opt_config = ModelOptConfig(
quant=server_args.modelopt_quant,
checkpoint_restore_path=server_args.modelopt_checkpoint_restore_path,
Expand Down Expand Up @@ -191,7 +192,8 @@ def _generate_atom_config_from_sglang_config(config: Any):
# force max num batched tokens to 16K because sgl doesn't have
# concept for max num batched tokens
return Config(
model=None,
model=sgl_model_config.model_path,
# model=sgl_model_config.model_path,
max_num_batched_tokens=16384,
max_num_seqs=server_args.max_running_requests,
max_model_len=server_args.context_length,
Expand Down
32 changes: 32 additions & 0 deletions launch_qwen.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#!/usr/bin/env bash
# Launch an SGLang server serving a Qwen3 MoE FP8 checkpoint through the
# "atom" model implementation, with the fused QK-norm/RoPE/KV-cache-quant
# kernel path enabled. Logs are echoed to stdout and tee'd to log.serve.log.
set -x

# Enable the fused QK-norm + RoPE + cache-quant path in atom and the
# matching fused QK-norm RoPE kernel in aiter.
export ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION=1
export AITER_ROPE_FUSED_QKNORM=1

# quick allreduce: INT4-quantized quick-reduce for tensor-parallel allreduce
export AITER_QUICK_REDUCE_QUANTIZATION=INT4
# Checkpoint to serve; alternatives kept commented for quick switching.
# model_path=/mnt/raid0/pretrained_model/Qwen/Qwen3-30B-A3B-Instruct-2507-FP8
# model_path=/mnt/raid0/pretrained_model/Qwen/Qwen3-VL-235B-A22B-Instruct-FP8
model_path=/mnt/raid0/pretrained_model/Qwen/Qwen3-235B-A22B-Instruct-2507-FP8


# TP=4 / EP=4 across GPUs 0-3, FP8 KV cache, CUDA graphs disabled
# (debug configuration). $model_path is quoted so paths with spaces work.
TORCHINDUCTOR_COMPILE_THREADS=128 CUDA_VISIBLE_DEVICES="0,1,2,3" python3 -m sglang.launch_server \
    --model-path "$model_path" \
    --host localhost \
    --port 8000 \
    --trust-remote-code \
    --tensor-parallel-size 4 \
    --expert-parallel-size 4 \
    --kv-cache-dtype fp8_e4m3 \
    --mem-fraction-static 0.7 \
    --disable-cuda-graph \
    --model-impl atom \
    --page-size 1024 \
    2>&1 | tee log.serve.log

# Smoke-test request once the server is up:
# curl -X POST "http://localhost:8000/v1/completions" \
#     -H "Content-Type: application/json" \
#     -d '{
#         "prompt": "The capital of China", "temperature": 0, "top_p": 1,
#         "top_k": 0, "repetition_penalty": 1.0, "presence_penalty": 0, "frequency_penalty": 0,
#         "stream": false, "ignore_eos": false, "n": 1, "seed": 123
#     }'