Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions csrc/models/llama/llama_mlp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,19 +71,34 @@ LlamaMLP::LlamaMLP(std::shared_ptr<infinilm::config::ModelConfig> model_config,
}

infinicore::Tensor LlamaMLP::forward(const infinicore::Tensor &hidden_states) const {
// 1. Project to gate and up
auto hidden_states_mutable = hidden_states;
auto [gate, up] = gate_up_proj_->forward_split(hidden_states_mutable);
infinicore::Device::Type dev_type = hidden_states->device().getType();
if(dev_type == infinicore::Device::Type::MOORE){
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我感觉Device相关的判断理应放到InfiniCore中,不应该放在推理框架层

// 1. Project to a single combined gate_up tensor
auto hidden_states_mutable = hidden_states;
auto gate_up = gate_up_proj_->forward(hidden_states_mutable);

// 2. Apply SwiGLU: silu(gate) * up
// Note: swiglu kernel expects (up, gate) and computes gate * sigmoid(gate) * up
// So we pass (up, gate) to get the correct result: gate * sigmoid(gate) * up
auto intermediate = infinicore::op::swiglu(up, gate);
// 2. Apply the fused silu_and_mul operator
// applies SiLU to the first half, and multiplies it by the second half.
// Mathematically equivalent to: result = SiLU(gate_up[..., :d]) * gate_up[..., d:]
auto intermediate = infinicore::op::silu_and_mul(gate_up);

// 3. Project down
auto output = down_proj_->forward(intermediate);
// 3. Project down
auto output = down_proj_->forward(intermediate);
return output;
} else{
// 1. Project to gate and up
auto hidden_states_mutable = hidden_states;
auto [gate, up] = gate_up_proj_->forward_split(hidden_states_mutable);

return output;
// 2. Apply SwiGLU: silu(gate) * up
// Note: swiglu kernel expects (up, gate) and computes gate * sigmoid(gate) * up
// So we pass (up, gate) to get the correct result: gate * sigmoid(gate) * up
auto intermediate = infinicore::op::swiglu(up, gate);

// 3. Project down
auto output = down_proj_->forward(intermediate);
return output;
}
}

} // namespace infinilm::models::llama
43 changes: 43 additions & 0 deletions examples/jiuge.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ def get_args():
default=1.0,
help="sampling temperature",
)
parser.add_argument(
"--warmup",
action="store_true",
help="Perform a warmup run before benchmarking/inference."
)

return parser.parse_args()

Expand Down Expand Up @@ -236,6 +241,44 @@ def test(

model.reset_cache(cache_config)

# ---------------------------------------------------------------------------- #
# Warmup
# ---------------------------------------------------------------------------- #
if args.warmup:
warmup_steps = 1

# Choose a length that approximates the real workload.
# It should be long enough to trigger the correct kernel paths,
# but not so long that warmup becomes unnecessarily expensive.
avg_prompt_len = min(64, max(len(ids) for ids in input_ids_list))

# Use truncated versions of real prompts for warmup
warmup_ids = [
ids[:avg_prompt_len] if len(ids) >= avg_prompt_len else ids
for ids in input_ids_list
]

input_ids_infini = infinicore.from_list(warmup_ids)

print("=================== warmup start ===================")

for _ in range(warmup_steps):
_ = model.generate(
input_ids_infini,
GenerationConfig(
max_new_tokens=2, # warmup decode kernel
temperature=temperature,
top_k=top_k,
top_p=top_p,
),
_measure_and_log_time=False,
)

print("=================== warmup done ====================")

# Reset KV cache
model.reset_cache(cache_config)

# ---------------------------------------------------------------------------- #
# Generate
# ---------------------------------------------------------------------------- #
Expand Down