7 changes: 7 additions & 0 deletions CMakeLists.txt
@@ -332,6 +332,13 @@ install(
   ARCHIVE DESTINATION lib
   RUNTIME DESTINATION bin)
 
+install(
+  TARGETS flatbuffers
+  EXPORT MllmTargets
+  LIBRARY DESTINATION lib
+  ARCHIVE DESTINATION lib
+  RUNTIME DESTINATION bin)
+
 if(MLLM_BUILD_SDK_C_BINDING)
   install(
     TARGETS MllmSdkC
38 changes: 21 additions & 17 deletions examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp
@@ -15,14 +15,6 @@

 namespace mllm::models::qwen3 {
 
-Tensor rotateHalf(Tensor x) { // NOLINT
-  // X is [x, x, x, D]
-  auto D = x.size(-1);
-  auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true);
-  auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true);
-  return nn::functional::concat({-x2, x1}, -1);
-}
-
 namespace ptq {
 
 Tensor QDQ(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) {
@@ -112,6 +104,14 @@ Tensor QDQ_ROPE(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch

 } // namespace ptq
 
+Tensor rotateHalf(Tensor x, nn::Module* m, const std::string& qdq_name_in_pytorch) { // NOLINT
+  // X is [x, x, x, D]
+  auto D = x.size(-1);
+  auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true);
+  auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true);
+  return nn::functional::concat({ptq::QDQ(m, -x2, qdq_name_in_pytorch), x1}, -1);
+}
+
 using vi32 = std::vector<int32_t>;
 #define CONV2D_PROPERTY vi32{1, 1}, vi32{1, 1}, vi32{0, 0}, vi32{1, 1}, false, aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32
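`rotateHalf` is the standard RoPE helper: splitting the last dimension as [x1 | x2], it returns [-x2 | x1], and the attention code below combines it as x_embed = x * cos + rotateHalf(x) * sin. Moving the function below `namespace ptq` lets the negated half `-x2` pass through its own named fake-quant node. A minimal plain-float sketch of the underlying math (not mllm's Tensor API; names here are illustrative):

```cpp
#include <cstddef>
#include <vector>

// rotate_half([x1 | x2]) = [-x2 | x1] over a last dimension of size D.
std::vector<float> rotate_half(const std::vector<float>& x) {
  const std::size_t half = x.size() / 2;
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < half; ++i) {
    out[i] = -x[half + i];  // first half <- negated second half
    out[half + i] = x[i];   // second half <- original first half
  }
  return out;
}

// RoPE application: x_embed = x * cos + rotate_half(x) * sin, element-wise.
std::vector<float> apply_rope(const std::vector<float>& x,
                              const std::vector<float>& cosv,
                              const std::vector<float>& sinv) {
  const auto rot = rotate_half(x);
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    out[i] = x[i] * cosv[i] + rot[i] * sinv[i];
  }
  return out;
}
```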

@@ -232,14 +232,16 @@ class Qwen3Attention final : public nn::Module {
     // [B, H, S, D]
     auto cos = llm_embedding_cos.unsqueeze(1);
     auto sin = llm_embedding_sin.unsqueeze(1);
-    query_states = ptq::QDQ(this,
-                            ptq::QDQ(this, query_states * cos, "q_rope_mul_0_output_qdq")
-                                + ptq::QDQ(this, rotateHalf(query_states) * sin, "q_rope_mul_1_output_qdq"),
-                            "q_rope_add_0_output_qdq");
-    key_states = ptq::QDQ(this,
-                          ptq::QDQ(this, key_states * cos, "k_rope_mul_0_output_qdq")
-                              + ptq::QDQ(this, rotateHalf(key_states) * sin, "k_rope_mul_1_output_qdq"),
-                          "k_rope_add_0_output_qdq");
+    query_states =
+        ptq::QDQ(this,
+                 ptq::QDQ(this, query_states * cos, "q_rope_mul_0_output_qdq")
+                     + ptq::QDQ(this, rotateHalf(query_states, this, "q_rope_neg_half_qdq") * sin, "q_rope_mul_1_output_qdq"),
+                 "q_rope_add_0_output_qdq");
+    key_states =
+        ptq::QDQ(this,
+                 ptq::QDQ(this, key_states * cos, "k_rope_mul_0_output_qdq")
+                     + ptq::QDQ(this, rotateHalf(key_states, this, "k_rope_neg_half_qdq") * sin, "k_rope_mul_1_output_qdq"),
+                 "k_rope_add_0_output_qdq");
 
     // De-quantization and quantization again
     key_states = key_states.to(kFloat32);
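Each `ptq::QDQ` call above inserts a quantize–dequantize (fake-quant) node whose scale and zero-point are looked up by the name of the matching node in the PyTorch PTQ graph, which is why each new intermediate (here the negated RoPE half) needs its own named QDQ. Numerically, a QDQ step just snaps a float onto its quantization grid. A hypothetical scalar sketch, assuming int16 activations (the `a16` in the op type above); the function name and signature are illustrative, not mllm's actual API:

```cpp
#include <algorithm>
#include <cmath>

// Illustrative fake-quant step: quantize to the int16 grid, then dequantize.
// The real ptq::QDQ presumably resolves scale/zero_point from calibration
// data keyed by the PyTorch node name (e.g. "q_rope_neg_half_qdq").
float fake_quant_int16(float x, float scale, int zero_point = 0) {
  const long q = std::lround(x / scale) + zero_point;
  const long clamped = std::clamp(q, -32768L, 32767L);  // int16 range
  return static_cast<float>(clamped - zero_point) * scale;
}
```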
@@ -272,7 +274,9 @@ class Qwen3Attention final : public nn::Module {
     auto attn_min = ptq::QDQ(this, attn.min(-1, true), "reduce_min_output_qdq");
     auto minus_value = Tensor::constant(-20, kFloat32);
     minus_value = ptq::QDQ(this, minus_value, "neg_20_qdq");
-    attn = nn::functional::where(causal_mask.equal(0.f), attn, attn_min.addConstant(minus_value));
+    auto attn_vv = ptq::QDQ(this, attn_min.addConstant(minus_value), "minus_0_output_qdq");
+    attn = nn::functional::where(causal_mask.equal(0.f), attn, attn_vv);
+    attn = ptq::QDQ(this, attn, "where_attn_qdq");
     attn = ptq::QDQ(this, nn::functional::softmax(attn, -1), "softmax_output_qdq");
     auto y = ptq::QDQ(this, nn::functional::matmul(attn, vh), "attn_value_matmul_output_qdq");
     y = y.transpose(1, 2).view({1, 1, -1, num_attention_heads_ * head_dim_}, /*ssa=*/true);
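The `where` change keeps masked attention positions finite: instead of filling with `-inf`, masked logits get `row_min - 20`, and this hunk additionally routes that fill value (and the `where` output) through named QDQ nodes so both branches of the select sit on a quantization grid before the quantized softmax. A plain-float sketch of the fill rule, assuming `causal_mask == 0` marks positions to keep (names are illustrative):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Masked positions get (row_min - 20) rather than -infinity: a large but
// finite penalty that stays representable on a fixed quantization grid,
// while softmax still drives their probability toward zero.
void mask_attention_row(std::vector<float>& logits, const std::vector<bool>& keep) {
  const float row_min = *std::min_element(logits.begin(), logits.end());
  const float fill = row_min - 20.0f;  // mirrors attn_min.addConstant(-20)
  for (std::size_t i = 0; i < logits.size(); ++i) {
    if (!keep[i]) logits[i] = fill;
  }
}
```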
4 changes: 4 additions & 0 deletions mllm/CMakeLists.txt
@@ -56,6 +56,10 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "App
   endif()
 endif()
 
+# FIXME: @oreomaker Need to remove comma features in slice!
+# Suppress comma-subscript warnings (deprecated C++ feature that will be removed in C++26)
+target_compile_options(MllmRT PUBLIC -Wno-comma-subscript)
+
 # ONLY APPLE CAN DO !
 # Processing OpenMP
 if(MLLM_KERNEL_USE_THREADS AND MLLM_KERNEL_THREADS_VENDOR_OPENMP)
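For reference, what `-Wcomma-subscript` flags: inside `[]`, a top-level comma has historically been a comma *expression* (evaluate the left operand, discard it, index with the right one). C++20 deprecates that use because C++23 repurposes the syntax for multidimensional `operator[]`. A minimal reproducer, unrelated to mllm's slice API:

```cpp
#include <cstdio>

int main() {
  int a[4] = {10, 11, 12, 13};
  int i = 0, j = 2;
  // Deprecated comma expression in a subscript: evaluates i, discards it,
  // and indexes with j. Compiles under -std=c++20 (with a warning); C++23
  // instead treats this as a multidimensional subscript and rejects it here.
  int v = a[i, j];
  std::printf("%d\n", v);  // prints 12 (== a[2])
  return 0;
}
```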