diff --git a/csrc/models/llama/llama_mlp.cpp b/csrc/models/llama/llama_mlp.cpp
index a3ab7859..282e2eca 100644
--- a/csrc/models/llama/llama_mlp.cpp
+++ b/csrc/models/llama/llama_mlp.cpp
@@ -71,19 +71,34 @@ LlamaMLP::LlamaMLP(std::shared_ptr model_config,
 }
 
 infinicore::Tensor LlamaMLP::forward(const infinicore::Tensor &hidden_states) const {
-    // 1. Project to gate and up
-    auto hidden_states_mutable = hidden_states;
-    auto [gate, up] = gate_up_proj_->forward_split(hidden_states_mutable);
+    infinicore::Device::Type dev_type = hidden_states->device().getType();
+    if (dev_type == infinicore::Device::Type::MOORE) {
+        // 1. Project to a single combined gate_up tensor
+        auto hidden_states_mutable = hidden_states;
+        auto gate_up = gate_up_proj_->forward(hidden_states_mutable);
 
-    // 2. Apply SwiGLU: silu(gate) * up
-    // Note: swiglu kernel expects (up, gate) and computes gate * sigmoid(gate) * up
-    // So we pass (up, gate) to get the correct result: gate * sigmoid(gate) * up
-    auto intermediate = infinicore::op::swiglu(up, gate);
+        // 2. Apply the fused silu_and_mul operator:
+        // it applies SiLU to the first half and multiplies the result by the second half.
+        // Mathematically equivalent to: result = SiLU(gate_up[..., :d]) * gate_up[..., d:]
+        auto intermediate = infinicore::op::silu_and_mul(gate_up);
 
-    // 3. Project down
-    auto output = down_proj_->forward(intermediate);
+        // 3. Project down
+        auto output = down_proj_->forward(intermediate);
+        return output;
+    } else {
+        // 1. Project to gate and up
+        auto hidden_states_mutable = hidden_states;
+        auto [gate, up] = gate_up_proj_->forward_split(hidden_states_mutable);
 
-    return output;
+        // 2. Apply SwiGLU: silu(gate) * up
+        // Note: swiglu kernel expects (up, gate) and computes gate * sigmoid(gate) * up
+        // So we pass (up, gate) to get the correct result: gate * sigmoid(gate) * up
+        auto intermediate = infinicore::op::swiglu(up, gate);
+
+        // 3. Project down
+        auto output = down_proj_->forward(intermediate);
+        return output;
+    }
 }
 
 } // namespace infinilm::models::llama
diff --git a/examples/jiuge.py b/examples/jiuge.py
index 8612fc26..3b515d78 100644
--- a/examples/jiuge.py
+++ b/examples/jiuge.py
@@ -131,6 +131,11 @@ def get_args():
         default=1.0,
         help="sampling temperature",
     )
+    parser.add_argument(
+        "--warmup",
+        action="store_true",
+        help="Perform a warmup run before benchmarking/inference.",
+    )
     return parser.parse_args()
 
 
@@ -236,6 +241,44 @@ def test(
 
     model.reset_cache(cache_config)
 
+    # ---------------------------------------------------------------------------- #
+    # Warmup
+    # ---------------------------------------------------------------------------- #
+    if args.warmup:
+        warmup_steps = 1
+
+        # Choose a length that approximates the real workload.
+        # It should be long enough to trigger the correct kernel paths,
+        # but not so long that warmup becomes unnecessarily expensive.
+        avg_prompt_len = min(64, max(len(ids) for ids in input_ids_list))
+
+        # Use truncated versions of real prompts for warmup
+        warmup_ids = [
+            ids[:avg_prompt_len] if len(ids) >= avg_prompt_len else ids
+            for ids in input_ids_list
+        ]
+
+        input_ids_infini = infinicore.from_list(warmup_ids)
+
+        print("=================== warmup start ===================")
+
+        for _ in range(warmup_steps):
+            _ = model.generate(
+                input_ids_infini,
+                GenerationConfig(
+                    max_new_tokens=2,  # warmup decode kernel
+                    temperature=temperature,
+                    top_k=top_k,
+                    top_p=top_p,
+                ),
+                _measure_and_log_time=False,
+            )
+
+        print("=================== warmup done ====================")
+
+        # Reset KV cache
+        model.reset_cache(cache_config)
+
     # ---------------------------------------------------------------------------- #
     # Generate
     # ---------------------------------------------------------------------------- #
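
Note on the MLP change: the MOORE branch assumes that gate_up_proj_->forward(x) returns the gate and up halves concatenated along the last dimension, i.e. exactly what forward_split(x) would split, so the fused silu_and_mul should match the existing swiglu(up, gate) path. Below is a minimal NumPy sketch of that equivalence; silu_and_mul_reference and swiglu_reference are illustrative helpers only, not part of the infinicore API.

import numpy as np

def silu(x):
    # SiLU(x) = x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def silu_and_mul_reference(gate_up):
    # Fused path: split [gate | up] along the last dim, then SiLU(gate) * up.
    d = gate_up.shape[-1] // 2
    return silu(gate_up[..., :d]) * gate_up[..., d:]

def swiglu_reference(up, gate):
    # Generic path: gate * sigmoid(gate) * up on separate gate/up tensors.
    return silu(gate) * up

gate_up = np.random.randn(2, 8).astype(np.float32)  # stand-in for the gate_up projection output
gate, up = gate_up[..., :4], gate_up[..., 4:]
assert np.allclose(silu_and_mul_reference(gate_up), swiglu_reference(up, gate))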
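
Note on the example change: the warmup pass is opt-in and would be enabled from the command line, e.g. `python examples/jiuge.py --warmup [other arguments unchanged]`; any other flags the script requires are omitted here since they are not part of this diff.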