diff --git a/CMakeLists.txt b/CMakeLists.txt index 298b412c0..fca470ee5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -332,6 +332,13 @@ install( ARCHIVE DESTINATION lib RUNTIME DESTINATION bin) +install( + TARGETS flatbuffers + EXPORT MllmTargets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin) + if(MLLM_BUILD_SDK_C_BINDING) install( TARGETS MllmSdkC diff --git a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp index 9eed37267..a2d054bad 100644 --- a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp +++ b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp @@ -15,14 +15,6 @@ namespace mllm::models::qwen3 { -Tensor rotateHalf(Tensor x) { // NOLINT - // X is [x, x, x, D] - auto D = x.size(-1); - auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true); - auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true); - return nn::functional::concat({-x2, x1}, -1); -} - namespace ptq { Tensor QDQ(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { @@ -112,6 +104,14 @@ Tensor QDQ_ROPE(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch } // namespace ptq +Tensor rotateHalf(Tensor x, nn::Module* m, const std::string& qdq_name_in_pytorch) { // NOLINT + // X is [x, x, x, D] + auto D = x.size(-1); + auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true); + auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true); + return nn::functional::concat({ptq::QDQ(m, -x2, qdq_name_in_pytorch), x1}, -1); +} + using vi32 = std::vector; #define CONV2D_PROPERTY vi32{1, 1}, vi32{1, 1}, vi32{0, 0}, vi32{1, 1}, false, aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32 @@ -232,14 +232,16 @@ class Qwen3Attention final : public nn::Module { // [B, H, S, D] auto cos = llm_embedding_cos.unsqueeze(1); auto sin = llm_embedding_sin.unsqueeze(1); - query_states = ptq::QDQ(this, - ptq::QDQ(this, query_states * cos, "q_rope_mul_0_output_qdq") - + ptq::QDQ(this, rotateHalf(query_states) * sin, "q_rope_mul_1_output_qdq"), - "q_rope_add_0_output_qdq"); - key_states = ptq::QDQ(this, - ptq::QDQ(this, key_states * cos, "k_rope_mul_0_output_qdq") - + ptq::QDQ(this, rotateHalf(key_states) * sin, "k_rope_mul_1_output_qdq"), - "k_rope_add_0_output_qdq"); + query_states = + ptq::QDQ(this, + ptq::QDQ(this, query_states * cos, "q_rope_mul_0_output_qdq") + + ptq::QDQ(this, rotateHalf(query_states, this, "q_rope_neg_half_qdq") * sin, "q_rope_mul_1_output_qdq"), + "q_rope_add_0_output_qdq"); + key_states = + ptq::QDQ(this, + ptq::QDQ(this, key_states * cos, "k_rope_mul_0_output_qdq") + + ptq::QDQ(this, rotateHalf(key_states, this, "k_rope_neg_half_qdq") * sin, "k_rope_mul_1_output_qdq"), + "k_rope_add_0_output_qdq"); // De-quantization and quantization again key_states = key_states.to(kFloat32); @@ -272,7 +274,9 @@ class Qwen3Attention final : public nn::Module { auto attn_min = ptq::QDQ(this, attn.min(-1, true), "reduce_min_output_qdq"); auto minus_value = Tensor::constant(-20, kFloat32); minus_value = ptq::QDQ(this, minus_value, "neg_20_qdq"); - attn = nn::functional::where(causal_mask.equal(0.f), attn, attn_min.addConstant(minus_value)); + auto attn_vv = ptq::QDQ(this, attn_min.addConstant(minus_value), "minus_0_output_qdq"); + attn = nn::functional::where(causal_mask.equal(0.f), attn, attn_vv); + attn = ptq::QDQ(this, attn, "where_attn_qdq"); attn = ptq::QDQ(this, nn::functional::softmax(attn, -1), "softmax_output_qdq"); auto y = ptq::QDQ(this, nn::functional::matmul(attn, vh), 
"attn_value_matmul_output_qdq"); y = y.transpose(1, 2).view({1, 1, -1, num_attention_heads_ * head_dim_}, /*ssa=*/true); diff --git a/mllm/CMakeLists.txt b/mllm/CMakeLists.txt index 9df6b7741..fd796f95a 100644 --- a/mllm/CMakeLists.txt +++ b/mllm/CMakeLists.txt @@ -56,6 +56,10 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "App endif() endif() +# FIXME: @oreomaker Need to remove comma features in slice! +# Suppress comma-subscript warnings (deprecated C++ feature that will be removed in C++26) +target_compile_options(MllmRT PUBLIC -Wno-comma-subscript) + # ONLY APPLE CAN DO ! # Processing OpenMP if(MLLM_KERNEL_USE_THREADS AND MLLM_KERNEL_THREADS_VENDOR_OPENMP) diff --git a/mllm/backends/cpu/kernels/common/fill-inl.hpp b/mllm/backends/cpu/kernels/common/fill-inl.hpp new file mode 100644 index 000000000..4c799daf6 --- /dev/null +++ b/mllm/backends/cpu/kernels/common/fill-inl.hpp @@ -0,0 +1,363 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +// NOTE: Do NOT use #pragma once here! +// Highway's foreach_target.h mechanism requires -inl.hpp files to be included +// multiple times, once for each target architecture (AVX3_DL, AVX10_2, etc.). + +#include +#include +#include "mllm/core/DataTypes.hpp" + +HWY_BEFORE_NAMESPACE(); +namespace mllm::cpu::common { // NOLINT +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +//===----------------------------------------------------------------------===// +// Fill Zeros +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_zeros_impl(T* HWY_RESTRICT dst, size_t count) { + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + const hn::Vec zero = hn::Zero(d); + size_t idx = 0; + + for (; idx + N <= count; idx += N) { hn::StoreU(zero, d, dst + idx); } + + if (idx < count) { hn::StoreN(zero, d, dst + idx, count - idx); } +} + +// Specialization for types not supported by Highway SIMD, use memset +template +HWY_INLINE void fill_zeros_scalar(T* HWY_RESTRICT dst, size_t count) { + if constexpr (std::is_trivial_v) { + std::memset(dst, 0, count * sizeof(T)); + } else { + T zero_val{}; + for (size_t i = 0; i < count; ++i) { dst[i] = zero_val; } + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_fp64(mllm_fp64_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void 
fill_zeros_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +//===----------------------------------------------------------------------===// +// Fill Ones +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_ones_impl(T* HWY_RESTRICT dst, size_t count) { + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + const hn::Vec one = hn::Set(d, static_cast(1)); + size_t idx = 0; + + for (; idx + N <= count; idx += N) { hn::StoreU(one, d, dst + idx); } + + if (idx < count) { hn::StoreN(one, d, dst + idx, count - idx); } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_fp64(mllm_fp64_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +//===----------------------------------------------------------------------===// +// Fill Specific Value +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_value_impl(T* HWY_RESTRICT dst, size_t count, T value) { + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + const hn::Vec v = hn::Set(d, value); + size_t idx = 0; + + for (; idx + N <= count; idx += N) { hn::StoreU(v, d, dst + idx); } + + if (idx < count) { hn::StoreN(v, d, dst + idx, count - idx); } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_fp64(mllm_fp64_t* HWY_RESTRICT dst, size_t size, mllm_fp64_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size, mllm_int32_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size, mllm_uint32_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size, mllm_int64_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size, mllm_uint64_t value) { + fill_value_impl(dst, size, value); +} + 
+static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size, mllm_int16_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size, mllm_uint16_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size, mllm_int8_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size, mllm_uint8_t value) { + fill_value_impl(dst, size, value); +} + +//===----------------------------------------------------------------------===// +// Fill Arange (start, end, step) +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_arange_impl(T* HWY_RESTRICT dst, size_t count, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + if (step == 0) { + fill_value_impl(dst, count, static_cast(start)); + return; + } + + // Calculate the actual number of elements to fill + size_t n = 0; + if ((step > 0 && start < end) || (step < 0 && start > end)) { + mllm_fp32_t n_float = (end - start) / step; + if (n_float > 0) { + n = static_cast(std::ceil(n_float)); + if (step > 0) { + if (start + (n - 1) * step >= end) --n; + } else { + if (start + (n - 1) * step <= end) --n; + } + n = std::min(n, count); + } + } + + // Use SIMD for float types where we can vectorize the computation + if constexpr (std::is_same_v) { + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + + // Create increment vector: [0, 1, 2, 3, ...] * step + const hn::Vec step_vec = hn::Set(d, step); + const hn::Vec n_step_vec = hn::Set(d, step * static_cast(N)); + + // Create base offsets [0, 1, 2, 3, ...] 
+ hn::Vec base = hn::Iota(d, 0); + base = hn::Mul(base, step_vec); + hn::Vec current_start = hn::Add(hn::Set(d, start), base); + + size_t idx = 0; + for (; idx + N <= n; idx += N) { + hn::StoreU(current_start, d, dst + idx); + current_start = hn::Add(current_start, n_step_vec); + } + + // Handle remaining elements + for (; idx < n; ++idx) { dst[idx] = static_cast(start + idx * step); } + } else { + // Scalar fallback for other types + for (size_t i = 0; i < n; ++i) { dst[i] = static_cast(start + i * step); } + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +//===----------------------------------------------------------------------===// +// Fill Random (using LCG random number generator) +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_random_impl(T* HWY_RESTRICT dst, size_t count, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + const uint64_t multiplier = 1103515245ULL; + const uint64_t increment = 12345ULL; + const uint64_t modulus = 1ULL << 31; // 2^31 + const mllm_fp32_t range = end - start; + + if (range == 0) { + fill_value_impl(dst, count, static_cast(start)); + return; + } + + uint64_t state = seed; + state = (multiplier * state + increment) % modulus; + + for (size_t i = 0; i < count; ++i) { + state = (multiplier * state + increment) % modulus; + const mllm_fp32_t random_value = static_cast(state) / static_cast(modulus - 1); + dst[i] = static_cast(start + random_value * range); + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { 
+ fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +} // namespace HWY_NAMESPACE +} // namespace mllm::cpu::common +HWY_AFTER_NAMESPACE(); diff --git a/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp b/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp index 1ad3cee93..7e81adfdf 100644 --- a/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp +++ b/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp @@ -17,6 +17,7 @@ // Include all inline implementations here #include "mllm/backends/cpu/kernels/common/elewise-inl.hpp" +#include "mllm/backends/cpu/kernels/common/fill-inl.hpp" #if HWY_ONCE namespace mllm::cpu::common { @@ -69,11 +70,188 @@ HWY_DLLEXPORT void call_elewise_div_scalar_fp32(mllm_fp32_t* out, const mllm_fp3 // GELU //===----------------------------------------------------------------------===// // HWY_EXPORT(gelu_fp32); -// +// // HWY_DLLEXPORT void call_gelu_fp32(mllm_fp32_t* out, const mllm_fp32_t* in, size_t n) { // HWY_DYNAMIC_DISPATCH(gelu_fp32)(out, in, n); // } +//===----------------------------------------------------------------------===// +// Fill Zeros +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_zeros_fp32); +HWY_EXPORT(fill_zeros_fp64); +HWY_EXPORT(fill_zeros_i32); +HWY_EXPORT(fill_zeros_u32); +HWY_EXPORT(fill_zeros_i64); +HWY_EXPORT(fill_zeros_u64); +HWY_EXPORT(fill_zeros_i16); +HWY_EXPORT(fill_zeros_u16); +HWY_EXPORT(fill_zeros_i8); +HWY_EXPORT(fill_zeros_u8); + +HWY_DLLEXPORT void call_fill_zeros_fp32(mllm_fp32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_fp32)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_fp64(mllm_fp64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_fp64)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_i32(mllm_int32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_i32)(dst, n); } +HWY_DLLEXPORT 
void call_fill_zeros_u32(mllm_uint32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_u32)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_i64(mllm_int64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_i64)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_u64(mllm_uint64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_u64)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_i16(mllm_int16_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_i16)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_u16(mllm_uint16_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_u16)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_i8(mllm_int8_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_i8)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_u8(mllm_uint8_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_u8)(dst, n); } + +//===----------------------------------------------------------------------===// +// Fill Ones +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_ones_fp32); +HWY_EXPORT(fill_ones_fp64); +HWY_EXPORT(fill_ones_i32); +HWY_EXPORT(fill_ones_u32); +HWY_EXPORT(fill_ones_i64); +HWY_EXPORT(fill_ones_u64); +HWY_EXPORT(fill_ones_i16); +HWY_EXPORT(fill_ones_u16); +HWY_EXPORT(fill_ones_i8); +HWY_EXPORT(fill_ones_u8); + +HWY_DLLEXPORT void call_fill_ones_fp32(mllm_fp32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_fp32)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_fp64(mllm_fp64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_fp64)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_i32(mllm_int32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_i32)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_u32(mllm_uint32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_u32)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_i64(mllm_int64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_i64)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_u64(mllm_uint64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_u64)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_i16(mllm_int16_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_i16)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_u16(mllm_uint16_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_u16)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_i8(mllm_int8_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_i8)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_u8(mllm_uint8_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_u8)(dst, n); } + +//===----------------------------------------------------------------------===// +// Fill Specific Value +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_value_fp32); +HWY_EXPORT(fill_value_fp64); +HWY_EXPORT(fill_value_i32); +HWY_EXPORT(fill_value_u32); +HWY_EXPORT(fill_value_i64); +HWY_EXPORT(fill_value_u64); +HWY_EXPORT(fill_value_i16); +HWY_EXPORT(fill_value_u16); +HWY_EXPORT(fill_value_i8); +HWY_EXPORT(fill_value_u8); + +HWY_DLLEXPORT void call_fill_value_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_fp32)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_fp64(mllm_fp64_t* dst, size_t n, mllm_fp64_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_fp64)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_i32(mllm_int32_t* dst, size_t n, mllm_int32_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_i32)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_u32(mllm_uint32_t* dst, size_t n, mllm_uint32_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_u32)(dst, n, value); +} +HWY_DLLEXPORT void 
call_fill_value_i64(mllm_int64_t* dst, size_t n, mllm_int64_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_i64)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_u64(mllm_uint64_t* dst, size_t n, mllm_uint64_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_u64)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_i16(mllm_int16_t* dst, size_t n, mllm_int16_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_i16)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_u16(mllm_uint16_t* dst, size_t n, mllm_uint16_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_u16)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_i8(mllm_int8_t* dst, size_t n, mllm_int8_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_i8)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_u8(mllm_uint8_t* dst, size_t n, mllm_uint8_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_u8)(dst, n, value); +} + +//===----------------------------------------------------------------------===// +// Fill Arange +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_arange_fp32); +HWY_EXPORT(fill_arange_i32); +HWY_EXPORT(fill_arange_u32); +HWY_EXPORT(fill_arange_i64); +HWY_EXPORT(fill_arange_u64); +HWY_EXPORT(fill_arange_i16); +HWY_EXPORT(fill_arange_u16); +HWY_EXPORT(fill_arange_i8); +HWY_EXPORT(fill_arange_u8); + +HWY_DLLEXPORT void call_fill_arange_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_fp32)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_i32(mllm_int32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_i32)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_u32(mllm_uint32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_u32)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_i64(mllm_int64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_i64)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_u64(mllm_uint64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_u64)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_i16(mllm_int16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_i16)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_u16(mllm_uint16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_u16)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_i8(mllm_int8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_i8)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_u8(mllm_uint8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_u8)(dst, n, start, end, step); +} + +//===----------------------------------------------------------------------===// +// Fill Random +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_random_fp32); +HWY_EXPORT(fill_random_i32); +HWY_EXPORT(fill_random_u32); +HWY_EXPORT(fill_random_i64); +HWY_EXPORT(fill_random_u64); +HWY_EXPORT(fill_random_i16); +HWY_EXPORT(fill_random_u16); +HWY_EXPORT(fill_random_i8); +HWY_EXPORT(fill_random_u8); + 
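Every export/wrapper pair in this file follows the same Highway dynamic-dispatch recipe: foreach_target.h recompiles fill-inl.hpp once per SIMD target, HWY_EXPORT collects the resulting per-target symbols into a dispatch table, and the call_* wrappers (compiled exactly once, under HWY_ONCE) pick the best entry at run time with HWY_DYNAMIC_DISPATCH. Below is a minimal sketch of that pattern for a single hypothetical kernel; the names (negate_fp32, call_negate_fp32) are invented for illustration and assume the usual Highway setup used elsewhere in this directory.

    // negate-inl.hpp -- compiled once per target via foreach_target.h
    HWY_BEFORE_NAMESPACE();
    namespace mllm::cpu::common {
    namespace HWY_NAMESPACE {
    namespace hn = hwy::HWY_NAMESPACE;

    // Hypothetical kernel: negate a buffer in place.
    static HWY_NOINLINE HWY_MAYBE_UNUSED void negate_fp32(mllm_fp32_t* HWY_RESTRICT p, size_t n) {
      const hn::ScalableTag<mllm_fp32_t> d;
      const size_t N = hn::Lanes(d);
      size_t i = 0;
      for (; i + N <= n; i += N) { hn::StoreU(hn::Neg(hn::LoadU(d, p + i)), d, p + i); }
      for (; i < n; ++i) { p[i] = -p[i]; }  // scalar tail
    }

    }  // namespace HWY_NAMESPACE
    }  // namespace mllm::cpu::common
    HWY_AFTER_NAMESPACE();

    // kernel_dispatch.cpp -- emitted once, guarded by HWY_ONCE
    #if HWY_ONCE
    namespace mllm::cpu::common {
    HWY_EXPORT(negate_fp32);  // registers all per-target variants
    HWY_DLLEXPORT void call_negate_fp32(mllm_fp32_t* p, size_t n) { HWY_DYNAMIC_DISPATCH(negate_fp32)(p, n); }
    }  // namespace mllm::cpu::common
    #endif  // HWY_ONCE
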
+HWY_DLLEXPORT void call_fill_random_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_fp32)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_i32(mllm_int32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_i32)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_u32(mllm_uint32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_u32)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_i64(mllm_int64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_i64)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_u64(mllm_uint64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_u64)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_i16(mllm_int16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_i16)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_u16(mllm_uint16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_u16)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_i8(mllm_int8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_i8)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_u8(mllm_uint8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_u8)(dst, n, start, end, seed); +} + } // namespace mllm::cpu::common #endif // HWY_ONCE diff --git a/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp b/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp index eb100ac43..4df34db0e 100644 --- a/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp +++ b/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp @@ -7,6 +7,7 @@ #include "mllm/utils/CPUArchHelper.hpp" #if !(defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM)) +#include #include "mllm/core/DataTypes.hpp" // Platform-specific definitions used for declaring an interface, independent of @@ -30,6 +31,222 @@ HWY_DLLEXPORT void call_elewise_sub_scalar_fp32(mllm_fp32_t* out, const mllm_fp3 HWY_DLLEXPORT void call_elewise_mul_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); HWY_DLLEXPORT void call_elewise_div_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); +//===----------------------------------------------------------------------===// +// Fill Zeros +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_zeros_fp32(mllm_fp32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_fp64(mllm_fp64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_i32(mllm_int32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_u32(mllm_uint32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_i64(mllm_int64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_u64(mllm_uint64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_i16(mllm_int16_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_u16(mllm_uint16_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_i8(mllm_int8_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_u8(mllm_uint8_t* dst, size_t n); + 
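Callers are not expected to pick one of these width-specific entry points by hand; the fill_*_anytype templates declared later in this header map the element type to the right call_* overload with if constexpr. A small usage sketch, with a hypothetical caller and assuming this header is on the include path:

    #include <vector>
    #include "mllm/backends/cpu/kernels/common/kernel_dispatch.hpp"

    void reset_buffer(std::vector<mllm_int32_t>& buf) {
      using namespace mllm::cpu::common;
      // Resolves to call_fill_zeros_i32 at compile time; Highway then selects the
      // best SIMD target at run time inside the dispatch wrapper.
      fill_zeros_anytype(buf.data(), buf.size());
      // Same idea for arange: start/end/step are always passed as mllm_fp32_t.
      fill_arange_anytype(buf.data(), buf.size(), 0.0f, static_cast<mllm_fp32_t>(buf.size()), 1.0f);
    }
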
+//===----------------------------------------------------------------------===// +// Fill Ones +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_ones_fp32(mllm_fp32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_fp64(mllm_fp64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_i32(mllm_int32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_u32(mllm_uint32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_i64(mllm_int64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_u64(mllm_uint64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_i16(mllm_int16_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_u16(mllm_uint16_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_i8(mllm_int8_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_u8(mllm_uint8_t* dst, size_t n); + +//===----------------------------------------------------------------------===// +// Fill Specific Value +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_value_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t value); +HWY_DLLEXPORT void call_fill_value_fp64(mllm_fp64_t* dst, size_t n, mllm_fp64_t value); +HWY_DLLEXPORT void call_fill_value_i32(mllm_int32_t* dst, size_t n, mllm_int32_t value); +HWY_DLLEXPORT void call_fill_value_u32(mllm_uint32_t* dst, size_t n, mllm_uint32_t value); +HWY_DLLEXPORT void call_fill_value_i64(mllm_int64_t* dst, size_t n, mllm_int64_t value); +HWY_DLLEXPORT void call_fill_value_u64(mllm_uint64_t* dst, size_t n, mllm_uint64_t value); +HWY_DLLEXPORT void call_fill_value_i16(mllm_int16_t* dst, size_t n, mllm_int16_t value); +HWY_DLLEXPORT void call_fill_value_u16(mllm_uint16_t* dst, size_t n, mllm_uint16_t value); +HWY_DLLEXPORT void call_fill_value_i8(mllm_int8_t* dst, size_t n, mllm_int8_t value); +HWY_DLLEXPORT void call_fill_value_u8(mllm_uint8_t* dst, size_t n, mllm_uint8_t value); + +//===----------------------------------------------------------------------===// +// Fill Arange +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_arange_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_i32(mllm_int32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_u32(mllm_uint32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_i64(mllm_int64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_u64(mllm_uint64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_i16(mllm_int16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_u16(mllm_uint16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_i8(mllm_int8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_u8(mllm_uint8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); + +//===----------------------------------------------------------------------===// +// Fill Random +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_random_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, 
                                         uint64_t seed);
+HWY_DLLEXPORT void call_fill_random_i32(mllm_int32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed);
+HWY_DLLEXPORT void call_fill_random_u32(mllm_uint32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed);
+HWY_DLLEXPORT void call_fill_random_i64(mllm_int64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed);
+HWY_DLLEXPORT void call_fill_random_u64(mllm_uint64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed);
+HWY_DLLEXPORT void call_fill_random_i16(mllm_int16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed);
+HWY_DLLEXPORT void call_fill_random_u16(mllm_uint16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed);
+HWY_DLLEXPORT void call_fill_random_i8(mllm_int8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed);
+HWY_DLLEXPORT void call_fill_random_u8(mllm_uint8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed);
+
+//===----------------------------------------------------------------------===//
+// Template wrapper for generic fill operations
+//===----------------------------------------------------------------------===//
+template <typename T>
+inline void fill_zeros_anytype(T* dst, size_t n) {
+  if constexpr (std::is_same_v<T, mllm_fp32_t>) {
+    call_fill_zeros_fp32(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_fp64_t>) {
+    call_fill_zeros_fp64(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_int32_t>) {
+    call_fill_zeros_i32(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_uint32_t>) {
+    call_fill_zeros_u32(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_int64_t>) {
+    call_fill_zeros_i64(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_uint64_t>) {
+    call_fill_zeros_u64(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_int16_t>) {
+    call_fill_zeros_i16(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_uint16_t>) {
+    call_fill_zeros_u16(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_int8_t>) {
+    call_fill_zeros_i8(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_uint8_t>) {
+    call_fill_zeros_u8(dst, n);
+  } else {
+    // Fallback for unsupported types
+    std::memset(dst, 0, n * sizeof(T));
+  }
+}
+
+template <typename T>
+inline void fill_ones_anytype(T* dst, size_t n) {
+  if constexpr (std::is_same_v<T, mllm_fp32_t>) {
+    call_fill_ones_fp32(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_fp64_t>) {
+    call_fill_ones_fp64(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_int32_t>) {
+    call_fill_ones_i32(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_uint32_t>) {
+    call_fill_ones_u32(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_int64_t>) {
+    call_fill_ones_i64(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_uint64_t>) {
+    call_fill_ones_u64(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_int16_t>) {
+    call_fill_ones_i16(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_uint16_t>) {
+    call_fill_ones_u16(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_int8_t>) {
+    call_fill_ones_i8(dst, n);
+  } else if constexpr (std::is_same_v<T, mllm_uint8_t>) {
+    call_fill_ones_u8(dst, n);
+  } else {
+    // Fallback
+    for (size_t i = 0; i < n; ++i) { dst[i] = static_cast<T>(1); }
+  }
+}
+
+template <typename T>
+inline void fill_value_anytype(T* dst, size_t n, mllm_fp32_t value) {
+  if constexpr (std::is_same_v<T, mllm_fp32_t>) {
+    call_fill_value_fp32(dst, n, value);
+  } else if constexpr (std::is_same_v<T, mllm_fp64_t>) {
+    call_fill_value_fp64(dst, n, static_cast<mllm_fp64_t>(value));
+  } else if constexpr (std::is_same_v<T, mllm_int32_t>) {
+    call_fill_value_i32(dst, n, static_cast<mllm_int32_t>(value));
+  } else if constexpr (std::is_same_v<T, mllm_uint32_t>) {
+    call_fill_value_u32(dst, n, static_cast<mllm_uint32_t>(value));
+  } else if constexpr (std::is_same_v<T, mllm_int64_t>) {
+    call_fill_value_i64(dst, n, static_cast<mllm_int64_t>(value));
+  } else if constexpr (std::is_same_v<T, mllm_uint64_t>) {
+    call_fill_value_u64(dst, n, static_cast<mllm_uint64_t>(value));
+  } else if constexpr (std::is_same_v<T, mllm_int16_t>) {
+    call_fill_value_i16(dst, n, static_cast<mllm_int16_t>(value));
+  } else if constexpr (std::is_same_v<T, mllm_uint16_t>) {
+    call_fill_value_u16(dst, n, static_cast<mllm_uint16_t>(value));
+  } else if constexpr (std::is_same_v<T, mllm_int8_t>) {
+    call_fill_value_i8(dst, n, static_cast<mllm_int8_t>(value));
+  } else if constexpr (std::is_same_v<T, mllm_uint8_t>) {
+    call_fill_value_u8(dst, n, static_cast<mllm_uint8_t>(value));
+  } else {
+    // Fallback
+    for (size_t i = 0; i < n; ++i) { dst[i] = static_cast<T>(value); }
+  }
+}
+
+template <typename T>
+inline void fill_arange_anytype(T* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) {
+  if constexpr (std::is_same_v<T, mllm_fp32_t>) {
+    call_fill_arange_fp32(dst, n, start, end, step);
+  } else if constexpr (std::is_same_v<T, mllm_int32_t>) {
+    call_fill_arange_i32(dst, n, start, end, step);
+  } else if constexpr (std::is_same_v<T, mllm_uint32_t>) {
+    call_fill_arange_u32(dst, n, start, end, step);
+  } else if constexpr (std::is_same_v<T, mllm_int64_t>) {
+    call_fill_arange_i64(dst, n, start, end, step);
+  } else if constexpr (std::is_same_v<T, mllm_uint64_t>) {
+    call_fill_arange_u64(dst, n, start, end, step);
+  } else if constexpr (std::is_same_v<T, mllm_int16_t>) {
+    call_fill_arange_i16(dst, n, start, end, step);
+  } else if constexpr (std::is_same_v<T, mllm_uint16_t>) {
+    call_fill_arange_u16(dst, n, start, end, step);
+  } else if constexpr (std::is_same_v<T, mllm_int8_t>) {
+    call_fill_arange_i8(dst, n, start, end, step);
+  } else if constexpr (std::is_same_v<T, mllm_uint8_t>) {
+    call_fill_arange_u8(dst, n, start, end, step);
+  } else {
+    // Fallback
+    for (size_t i = 0; i < n; ++i) { dst[i] = static_cast<T>(start + i * step); }
+  }
+}
+
+template <typename T>
+inline void fill_random_anytype(T* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) {
+  if constexpr (std::is_same_v<T, mllm_fp32_t>) {
+    call_fill_random_fp32(dst, n, start, end, seed);
+  } else if constexpr (std::is_same_v<T, mllm_int32_t>) {
+    call_fill_random_i32(dst, n, start, end, seed);
+  } else if constexpr (std::is_same_v<T, mllm_uint32_t>) {
+    call_fill_random_u32(dst, n, start, end, seed);
+  } else if constexpr (std::is_same_v<T, mllm_int64_t>) {
+    call_fill_random_i64(dst, n, start, end, seed);
+  } else if constexpr (std::is_same_v<T, mllm_uint64_t>) {
+    call_fill_random_u64(dst, n, start, end, seed);
+  } else if constexpr (std::is_same_v<T, mllm_int16_t>) {
+    call_fill_random_i16(dst, n, start, end, seed);
+  } else if constexpr (std::is_same_v<T, mllm_uint16_t>) {
+    call_fill_random_u16(dst, n, start, end, seed);
+  } else if constexpr (std::is_same_v<T, mllm_int8_t>) {
+    call_fill_random_i8(dst, n, start, end, seed);
+  } else if constexpr (std::is_same_v<T, mllm_uint8_t>) {
+    call_fill_random_u8(dst, n, start, end, seed);
+  } else {
+    // Fallback using LCG
+    const uint64_t multiplier = 1103515245ULL;
+    const uint64_t increment = 12345ULL;
+    const uint64_t modulus = 1ULL << 31;
+    const mllm_fp32_t range = end - start;
+    uint64_t state = seed;
+    for (size_t i = 0; i < n; ++i) {
+      state = (multiplier * state + increment) % modulus;
+      const mllm_fp32_t random_value = static_cast<mllm_fp32_t>(state) / static_cast<mllm_fp32_t>(modulus - 1);
+      dst[i] = static_cast<T>(start + random_value * range);
+    }
+  }
+}
+
 } // namespace mllm::cpu::common
 #endif
diff --git a/mllm/backends/cpu/ops/FillOp.cpp b/mllm/backends/cpu/ops/FillOp.cpp
index e4d935f51..cf5cee47e 100644
--- a/mllm/backends/cpu/ops/FillOp.cpp
+++ b/mllm/backends/cpu/ops/FillOp.cpp
@@ -21,7 +21,7 @@ void CPUFillOp::forward(const std::vector<Tensor>& inputs, std::vector<Tensor>&
   switch (dst.dtype()) {
     case kFloat32: {
 #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86)
-      x86::fill_zeros(dst.ptr<mllm_fp32_t>(), dst.numel(), threads);
+      common::fill_zeros_anytype(dst.ptr<mllm_fp32_t>(), dst.numel());
 #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM)
       arm::fill_zeros(dst.ptr<mllm_fp32_t>(), dst.numel(), threads);
 #endif
@@ -29,7 +29,8
@@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + std::memset(dst.ptr(), 0, dst.numel() * sizeof(mllm_fp16_t)); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_fp16(dst.ptr(), dst.numel(), threads); #endif @@ -37,7 +38,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -45,7 +46,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -53,7 +54,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -61,7 +62,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -69,7 +70,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -77,7 +78,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -85,7 +86,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -93,7 +94,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -110,7 +111,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - x86::fill_ones(dst.ptr(), dst.numel(), 
threads); + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones(dst.ptr(), dst.numel(), threads); #endif @@ -118,7 +119,9 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + auto ptr = dst.ptr(); + for (size_t i = 0; i < dst.numel(); ++i) { ptr[i] = static_cast(1.0f); } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_fp16(dst.ptr(), dst.numel(), threads); #endif @@ -126,7 +129,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -134,7 +137,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -142,7 +145,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -150,7 +153,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -158,7 +161,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -166,7 +169,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -174,7 +177,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -182,7 +185,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif 
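The arange cases that follow route through the same common:: path. The element count computed in fill_arange_impl mirrors numpy-style arange semantics: it is ceil((end - start) / step) with the end bound exclusive, clamped to the tensor's numel(); the extra "start + (n - 1) * step >= end" check only trims an element when floating-point rounding makes the ceil overshoot. A few worked values, chosen purely for illustration:

    // start = 0, end = 5, step = 2   ->  ceil(5 / 2) = 3 elements: 0, 2, 4
    // start = 0, end = 6, step = 2   ->  ceil(6 / 2) = 3 elements: 0, 2, 4 (end itself is excluded)
    // start = 5, end = 0, step = -2  ->  ceil(-5 / -2) = 3 elements: 5, 3, 1
    // step = 0                       ->  the whole buffer is filled with start (fill_value_impl path)
    // If rounding makes the division return e.g. 3.0000002, the >= end guard trims the extra element.
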
@@ -199,7 +202,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - x86::fill_arange(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); #endif @@ -207,7 +210,9 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + auto ptr = dst.ptr(); + for (size_t i = 0; i < dst.numel(); ++i) { ptr[i] = static_cast(options_.start + i * options_.step); } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_fp16(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); #endif @@ -215,7 +220,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -224,7 +229,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -233,7 +238,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -242,7 +247,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -251,7 +256,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -260,7 +265,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + 
common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -269,7 +274,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -278,7 +283,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -295,7 +300,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - x86::fill_random(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -303,7 +308,18 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + const uint64_t multiplier = 1103515245ULL; + const uint64_t increment = 12345ULL; + const uint64_t modulus = 1ULL << 31; + const mllm_fp32_t range = options_.end - options_.start; + uint64_t state = options_.seed; + auto ptr = dst.ptr(); + for (size_t i = 0; i < dst.numel(); ++i) { + state = (multiplier * state + increment) % modulus; + const mllm_fp32_t random_value = static_cast(state) / static_cast(modulus - 1); + ptr[i] = static_cast(options_.start + random_value * range); + } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_fp16(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -311,7 +327,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -319,7 +335,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ 
-327,7 +343,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -335,7 +351,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -343,7 +359,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -351,7 +367,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -359,7 +375,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -367,7 +383,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -383,7 +399,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - x86::fill_specific_value(dst.ptr(), dst.numel(), options_.value, threads); + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -391,7 +407,9 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + auto ptr = dst.ptr(); + for (size_t i = 0; i < dst.numel(); ++i) { ptr[i] = 
static_cast(options_.value); } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_fp16(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -399,7 +417,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -407,7 +425,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -415,7 +433,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -423,7 +441,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -431,7 +449,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -439,7 +457,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -447,7 +465,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -455,7 +473,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif diff --git a/mllm/backends/qnn/CMakeLists.txt b/mllm/backends/qnn/CMakeLists.txt index 0ad833792..83b4a43f9 100644 --- 
a/mllm/backends/qnn/CMakeLists.txt +++ b/mllm/backends/qnn/CMakeLists.txt @@ -44,3 +44,10 @@ get_property(current_includes DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INC message(STATUS "MLLM_QNN INCLUDES: ${current_includes}") #print include directories target_link_libraries(MllmQNNBackend PUBLIC MllmRT) + +install( + TARGETS MllmQNNBackend + EXPORT MllmTargets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin) diff --git a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp index 90ee4ad72..957fdf321 100644 --- a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp +++ b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp @@ -369,8 +369,7 @@ bool LLMQuantRecipeNegPattern::isMatch(const mllm::ir::op_ptr_t& op) { } bool LLMQuantRecipeNegPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& node) { - return shareQuantSpecSingleInputToSingleOutputAndSetOpQuantAnnoAttr(writer.getContext(), - node->cast_()); + return noSharingSingleInAndSingleOutQuantAnnoAttr(writer.getContext(), node->cast_()); } //===----------------------------------------------------------------------===// @@ -651,8 +650,15 @@ bool LLMQuantRecipeConcatPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr return false; } - MLLM_RETURN_FALSE_IF_NOT(i_0->getAttr("quant_recipe")); - MLLM_RETURN_FALSE_IF_NOT(i_1->getAttr("quant_recipe")); + // Create quant_recipe if not present + if (!i_0->getAttr("quant_recipe")) { + auto i_0_spec = genSimpleQuantizationSpecAttr(writer.getContext(), i_0->cast_()); + i_0->setAttr("quant_recipe", i_0_spec); + } + if (!i_1->getAttr("quant_recipe")) { + auto i_1_spec = genSimpleQuantizationSpecAttr(writer.getContext(), i_1->cast_()); + i_1->setAttr("quant_recipe", i_1_spec); + } o_0->setAttr("quant_recipe", i_0->getAttr("quant_recipe")); @@ -795,7 +801,8 @@ bool LLMQuantRecipeWherePattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_ MLLM_RETURN_FALSE_IF_NOT(i_1->getAttr("quant_recipe")); MLLM_RETURN_FALSE_IF_NOT(i_2->getAttr("quant_recipe")); - o_0->setAttr("quant_recipe", i_2->getAttr("quant_recipe")); + auto o_0_spec = genSimpleQuantizationSpecAttr(writer.getContext(), o_0->cast_()); + o_0->setAttr("quant_recipe", o_0_spec); auto annotation_attr = writer.create(); annotation_attr->annotation_.inputs.emplace_back( diff --git a/mllm/backends/qnn/aot/passes/PTQPass.cpp b/mllm/backends/qnn/aot/passes/PTQPass.cpp index 1d42d58d3..0d34a51b2 100644 --- a/mllm/backends/qnn/aot/passes/PTQPass.cpp +++ b/mllm/backends/qnn/aot/passes/PTQPass.cpp @@ -300,6 +300,135 @@ void recursiveSolveNormal(const std::shared_ptr& ir_ctx, const ir }); } +void recursiveCheckUnsolved(const std::shared_ptr& ir_ctx, const ir::graph::SubGraphOp::ptr_t& call_op) { + auto wow = ir::IRWriter(ir_ctx, call_op->getTopRegion()); + wow.walk([&](ir::IRWriter& w, const ir::Op::ptr_t& op) -> ir::IRWriter::WalkResult { + if (op->isa_()) { + auto linalg_op = op->cast_(); + std::string op_name = linalg_op->getAOp()->getName(); + + auto inputs = op->inputs(); + auto outputs = op->outputs(); + + for (auto iii : inputs) { + if (!iii->isa_()) continue; + auto tv = iii->cast_(); + if (!tv->getAttr("quant_recipe")) continue; + auto f_spec = tv->getAttr("quant_recipe")->cast_(); + if (!f_spec->spec_->solved) { + MLLM_WARN("PTQPass: TensorValue '{}' is not solved, used by Op: '{}'", tv->name(), op_name); + } + } + + for (auto ooo : outputs) { + if (!ooo->isa_()) continue; + auto tv = ooo->cast_(); + if (!tv->getAttr("quant_recipe")) continue; + auto f_spec 
= tv->getAttr("quant_recipe")->cast_(); + if (!f_spec->spec_->solved) { + MLLM_WARN("PTQPass: TensorValue '{}' is not solved, produced by Op: '{}'", tv->name(), op_name); + } + } + } + + if (op->isa_()) { + auto ns = op->cast_()->getSymbolAttr()->str(); + recursiveCheckUnsolved(w.getContext(), w.getContext()->lookupSymbolTable(ns)->cast_()); + } + return ir::IRWriter::WALK_CONTINUE; + }); +} + +void recursiveCheckConcatInputs(const std::shared_ptr& ir_ctx, const ir::graph::SubGraphOp::ptr_t& call_op) { + auto wow = ir::IRWriter(ir_ctx, call_op->getTopRegion()); + wow.walk([&](ir::IRWriter& w, const ir::Op::ptr_t& op) -> ir::IRWriter::WalkResult { + if (op->isa_()) { + auto concat_op = op->cast_(); + std::string op_name = concat_op->getAOp()->getName(); + + auto inputs = op->inputs(); + if (inputs.empty()) { return ir::IRWriter::WALK_CONTINUE; } + + // Get first input's scale and zero_point as reference + Tensor ref_scale; + Tensor ref_zero_point; + bool has_ref = false; + std::string ref_input_name; + + for (auto iii : inputs) { + if (!iii->isa_()) continue; + auto tv = iii->cast_(); + if (!tv->getAttr("quant_recipe")) continue; + auto f_spec = tv->getAttr("quant_recipe")->cast_(); + + if (f_spec->spec_->type == ir::linalg::QuantizationSpecType::kAsymPerTensor) { + auto this_spec = std::static_pointer_cast(f_spec->spec_); + if (!this_spec->solved) continue; + + if (!has_ref) { + ref_scale = this_spec->scale; + ref_zero_point = this_spec->zero_point; + ref_input_name = tv->name(); + has_ref = true; + } else { + // Check if scale and zero_point match + auto cur_scale = this_spec->scale; + auto cur_zero_point = this_spec->zero_point; + + MLLM_RT_ASSERT_EQ(ref_scale.numel(), 1); + MLLM_RT_ASSERT_EQ(cur_scale.numel(), 1); + MLLM_RT_ASSERT_EQ(ref_zero_point.numel(), 1); + MLLM_RT_ASSERT_EQ(cur_zero_point.numel(), 1); + + auto ref_scale_v = ref_scale.item(); + auto cur_scale_v = cur_scale.item(); + auto ref_zp_v = ref_zero_point.item(); + auto cur_zp_v = cur_zero_point.item(); + + if (std::abs(ref_scale_v - cur_scale_v) > 1e-6 || ref_zp_v != cur_zp_v) { + MLLM_ERROR("PTQPass: ConcatOp '{}' has mismatched scale/zp between inputs. " + "Input '{}': scale={}, zp={}, scale_name={}, zp_name={}; Input '{}': scale={}, zp={}, scale_name={}, " + "zp_name={}", + op_name, ref_input_name, ref_scale_v, ref_zp_v, ref_scale.name(), ref_zero_point.name(), tv->name(), + cur_scale_v, cur_zp_v, cur_scale.name(), cur_zero_point.name()); + } + } + } else if (f_spec->spec_->type == ir::linalg::QuantizationSpecType::kSymPerTensor) { + auto this_spec = std::static_pointer_cast(f_spec->spec_); + if (!this_spec->solved) continue; + + if (!has_ref) { + ref_scale = this_spec->scale; + ref_input_name = tv->name(); + has_ref = true; + } else { + // Check if scale matches + auto cur_scale = this_spec->scale; + + MLLM_RT_ASSERT_EQ(ref_scale.numel(), 1); + MLLM_RT_ASSERT_EQ(cur_scale.numel(), 1); + + auto ref_scale_v = ref_scale.item(); + auto cur_scale_v = cur_scale.item(); + + if (std::abs(ref_scale_v - cur_scale_v) > 1e-6) { + MLLM_ERROR("PTQPass: ConcatOp '{}' has mismatched scale between inputs. 
" + "Input '{}': scale={}; Input '{}': scale={}", + op_name, ref_input_name, ref_scale_v, tv->name(), cur_scale_v); + } + } + } + } + } + + if (op->isa_()) { + auto ns = op->cast_()->getSymbolAttr()->str(); + recursiveCheckConcatInputs(w.getContext(), w.getContext()->lookupSymbolTable(ns)->cast_()); + } + return ir::IRWriter::WALK_CONTINUE; + }); +} + } // namespace uint8_t PTQPass::run(const ir::node_ptr_t& op) { @@ -330,6 +459,16 @@ uint8_t PTQPass::run(const ir::node_ptr_t& op) { getCtx()->lookupSymbolTable(call_main_graph_op->getSymbolAttr()->str())->cast_(), pf); + // Check for unsolved tensorValues and warn + recursiveCheckUnsolved( + writer.getContext(), + getCtx()->lookupSymbolTable(call_main_graph_op->getSymbolAttr()->str())->cast_()); + + // Check Concat inputs have consistent scale and zero_point + recursiveCheckConcatInputs( + writer.getContext(), + getCtx()->lookupSymbolTable(call_main_graph_op->getSymbolAttr()->str())->cast_()); + return ir::PASS_RET_SUCCESS; } diff --git a/mllm/backends/qnn/aot/visitor/RMSNorm.cpp b/mllm/backends/qnn/aot/visitor/RMSNorm.cpp index 27f72e2e2..351e2562a 100644 --- a/mllm/backends/qnn/aot/visitor/RMSNorm.cpp +++ b/mllm/backends/qnn/aot/visitor/RMSNorm.cpp @@ -47,9 +47,12 @@ bool QnnAOTRMSNormPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) auto bias_tensor = mllm::Tensor::zeros(weight->tensor_.shape(), weight->tensor_.dtype()); auto bias_node = ir::tensor::TensorValue::build(writer.getContext().get(), bias_tensor); bias_node->tensor_.setName(a->getName() + "_runtime_bias"); + bias_node->name() = a->getName() + "_runtime_bias"; // fake bias quant recipe - auto quant_spec = mllm::ir::linalg::QuantizationSpecSymPerTensor::create(0, 0, kInt32, kFloat32, Tensor::ones({1})); + auto bias_scale = Tensor::ones({1}); + bias_scale.at({0}) = 1.0 / 32767; + auto quant_spec = mllm::ir::linalg::QuantizationSpecSymPerTensor::create(-32768, 32767, kInt16, kFloat32, bias_scale); auto quant_attr = mllm::ir::linalg::LinalgIRQuantizatonSpecAttr::build(writer.getContext().get(), quant_spec); bias_node->setAttr("quant_recipe", quant_attr); diff --git a/mllm/ffi/Extension.cc b/mllm/ffi/Extension.cc index 22449f883..cb999191d 100644 --- a/mllm/ffi/Extension.cc +++ b/mllm/ffi/Extension.cc @@ -53,9 +53,25 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("mllm.cpu_", []() -> mllm::ffi::Device { return mllm::ffi::Device(::mllm::DeviceTypes::kCPU); }); refl::GlobalDef().def("mllm.cuda_", []() -> mllm::ffi::Device { return mllm::ffi::Device(::mllm::DeviceTypes::kCUDA); }); refl::GlobalDef().def("mllm.qnn_", []() -> mllm::ffi::Device { return mllm::ffi::Device(::mllm::DeviceTypes::kQNN); }); + // Floating point types refl::GlobalDef().def("mllm.float32_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kFloat32); }); refl::GlobalDef().def("mllm.float16_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kFloat16); }); refl::GlobalDef().def("mllm.bfloat16_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kBFloat16); }); + + // Signed integer types + refl::GlobalDef().def("mllm.int8_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kInt8); }); + refl::GlobalDef().def("mllm.int16_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kInt16); }); + refl::GlobalDef().def("mllm.int32_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kInt32); }); + refl::GlobalDef().def("mllm.int64_", []() -> mllm::ffi::DType { return 
mllm::ffi::DType(::mllm::DataTypes::kInt64); }); + + // Unsigned integer types + refl::GlobalDef().def("mllm.uint8_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt8); }); + refl::GlobalDef().def("mllm.uint16_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt16); }); + refl::GlobalDef().def("mllm.uint32_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt32); }); + refl::GlobalDef().def("mllm.uint64_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt64); }); + + // Bool type + refl::GlobalDef().def("mllm.bool_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt8); }); } //===----------------------------------------------------------------------===// diff --git a/pymllm/__init__.py b/pymllm/__init__.py index 66240b714..1bd31cd6c 100644 --- a/pymllm/__init__.py +++ b/pymllm/__init__.py @@ -12,12 +12,27 @@ from . import service from . import backends from .ffi import ( + # Floating point types float32, float16, bfloat16, + # Signed integer types + int8, + int16, + int32, + int64, + # Unsigned integer types + uint8, + uint16, + uint32, + uint64, + # Bool type + boolean, + # Devices cpu, cuda, qnn, + # Tensor and utilities Tensor, empty, echo, @@ -26,7 +41,6 @@ is_numpy_available, from_torch, from_numpy, - empty, zeros, ones, arange, diff --git a/pymllm/backends/qualcomm/transformers/core/observer.py b/pymllm/backends/qualcomm/transformers/core/observer.py new file mode 100644 index 000000000..67a946b10 --- /dev/null +++ b/pymllm/backends/qualcomm/transformers/core/observer.py @@ -0,0 +1,56 @@ +import torch +from torchao.quantization.pt2e import UniformQuantizationObserverBase + + +class ConcatObserver(UniformQuantizationObserverBase): + """ + Fetch maximum data range of all tensors to be concatenated + """ + + def __init__( + self, + dtype=torch.uint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, # noqa: B008 + is_dynamic=False, + **kwargs, + ) -> None: + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + # get concat node and its inputs + self.input_observers = [] + + def add_observer(self, observer): + self.input_observers.append(observer) + + def forward(self, x_orig): + # calculate the min / max first + self.min_val = min(self.min_val, x_orig.min()) + self.max_val = max(self.max_val, x_orig.max()) + + # update min / max for all observers of input nodes + for observers in self.input_observers: + observers.min_val = self.min_val + observers.max_val = self.max_val + + return x_orig + + def calculate_qparams(self): + return self._calculate_qparams(self.min_val, self.max_val) diff --git a/pymllm/backends/qualcomm/transformers/core/qdq.py b/pymllm/backends/qualcomm/transformers/core/qdq.py index ce67729f4..c13011a51 100644 --- a/pymllm/backends/qualcomm/transformers/core/qdq.py +++ b/pymllm/backends/qualcomm/transformers/core/qdq.py @@ -1,6 +1,13 @@ import torch import torch.nn as nn -from torch.ao.quantization import FakeQuantize, MinMaxObserver +from torch.ao.quantization 
import ( + FakeQuantize, + MinMaxObserver, +) +from torch.ao.quantization.observer import FixedQParamsObserver + +DEFAULT_EPS_8BIT = 0.0001 / 255 +DEFAULT_EPS_16BIT = 0.0001 / 65535 class ActivationQDQ(nn.Module): @@ -30,16 +37,24 @@ def __init__(self, bits=8, qscheme=torch.per_tensor_affine): self.quant_min = 0 self.quant_max = (2**bits) - 1 + if bits == 8: + eps = DEFAULT_EPS_8BIT + elif bits == 16: + eps = DEFAULT_EPS_16BIT + else: + raise ValueError(f"Unsupported bit width: {bits}") + # 2. Initialize FakeQuantize - # MinMaxObserver calculates scale and zero_point based on observed tensors. + # MovingAverageMinMaxObserver calculates scale and zero_point based on observed tensors. # Passing quant_min/max to the observer ensures consistency. self.fake_quant = FakeQuantize( observer=MinMaxObserver.with_args( - qscheme=self.qscheme, dtype=self.dtype, + qscheme=self.qscheme, quant_min=self.quant_min, quant_max=self.quant_max, reduce_range=False, + eps=eps, ), quant_min=self.quant_min, quant_max=self.quant_max, @@ -63,12 +78,106 @@ def disable_observer(self): def enable_fakequant(self): """Enable simulation of quantization error.""" - self.fake_quant.enable_fakequant() + self.fake_quant.enable_fake_quant() def disable_fakequant(self): """Disable quantization simulation (act as identity).""" - self.fake_quant.disable_fakequant() + self.fake_quant.disable_fake_quant() def extra_repr(self): mode = "Symmetric" if "symmetric" in str(self.qscheme) else "Asymmetric" return f"bits={self.bits}, mode={mode}, q_range=({self.quant_min}, {self.quant_max}), dtype={self.dtype}" + + +class FixedActivationQDQ(nn.Module): + """ + Fixed activation Quantization-DeQuantization (QDQ) module. + Uses pre-determined scale and zero_point instead of dynamic observation. + Supports both Symmetric and Asymmetric (Affine) quantization. + Uses torch.qint32 as a unified type to support various bit-widths. + """ + + def __init__(self, scale, zero_point, bits=8, qscheme=torch.per_tensor_affine): + super().__init__() + self.bits = bits + self.qscheme = qscheme + + # Define the simulation dtype as qint32 to avoid overflow across different bit-widths + self.dtype = torch.qint32 + + # 1. Calculate quantization range based on bits and scheme + if qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]: + # Symmetric: range is [-(2^(bits-1)), 2^(bits-1) - 1] + # e.g., 8-bit: -128 to 127 + self.quant_min = -(2 ** (bits - 1)) + self.quant_max = 2 ** (bits - 1) - 1 + else: + # Asymmetric (Affine): range is [0, 2^bits - 1] + # e.g., 8-bit: 0 to 255 + self.quant_min = 0 + self.quant_max = (2**bits) - 1 + + if bits not in [8, 16]: + raise ValueError(f"Unsupported bit width: {bits}") + + # 2. Convert scale and zero_point to tensors if needed + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale, dtype=torch.float32) + if not isinstance(zero_point, torch.Tensor): + zero_point = torch.tensor(zero_point, dtype=torch.int32) + + # 3. Initialize FakeQuantize with fixed parameters + # Use FakeQuantize with FixedQParamsObserver for fixed scale and zero_point + self.fake_quant = FakeQuantize.with_args( + observer=FixedQParamsObserver.with_args( + scale=scale, + zero_point=zero_point, + ), + dtype=self.dtype, + qscheme=self.qscheme, + quant_min=self.quant_min, + quant_max=self.quant_max, + )() + + def forward(self, x): + # Applies fake quantization with fixed scale and zero_point: + # rounds to nearest integer and clamps to [min, max], + # then dequantizes back to float to simulate quantization noise. 
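On the new eps constants in qdq.py: DEFAULT_EPS_8BIT and DEFAULT_EPS_16BIT put a floor under the observer's scale so a near-constant activation cannot collapse into a degenerate quantization step. A rough, dependency-free illustration of affine qparams with such a floor (this mimics the usual MinMaxObserver arithmetic; treat it as a sketch, not the torch code path):

    def affine_qparams(min_val, max_val, qmin=0, qmax=65535, eps=0.0001 / 65535):
        min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)  # range must contain 0
        scale = max((max_val - min_val) / (qmax - qmin), eps)
        zero_point = int(round(qmin - min_val / scale))
        return scale, min(qmax, max(qmin, zero_point))

    # A ~1e-7-wide activation range: scale clamps to ~1.5e-9 instead of ~3e-12.
    print(affine_qparams(-1e-7, 1e-7))

The patch also corrects the enable_fakequant/disable_fakequant wrappers to call FakeQuantize.enable_fake_quant / disable_fake_quant, which match the method names torch's FakeQuantize actually exposes.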
+ return self.fake_quant(x) + + # Control methods for quantization-aware training (QAT) + # Note: FixedActivationQDQ doesn't have observer, so these methods + # only control fake quantization behavior + def enable_observer(self): + """No-op: FixedActivationQDQ doesn't use observer.""" + pass + + def disable_observer(self): + """No-op: FixedActivationQDQ doesn't use observer.""" + pass + + def enable_fakequant(self): + """Enable simulation of quantization error.""" + self.fake_quant.enable_fake_quant() + + def disable_fakequant(self): + """Disable quantization simulation (act as identity).""" + self.fake_quant.disable_fake_quant() + + @property + def scale(self): + """Get the fixed scale value.""" + return self.fake_quant.scale + + @property + def zero_point(self): + """Get the fixed zero_point value.""" + return self.fake_quant.zero_point + + def extra_repr(self): + mode = "Symmetric" if "symmetric" in str(self.qscheme) else "Asymmetric" + scale_val = self.scale.item() if self.scale.numel() == 1 else self.scale + zp_val = ( + self.zero_point.item() if self.zero_point.numel() == 1 else self.zero_point + ) + return f"bits={self.bits}, mode={mode}, scale={scale_val}, zero_point={zp_val}, q_range=({self.quant_min}, {self.quant_max}), dtype={self.dtype}" diff --git a/pymllm/backends/qualcomm/transformers/core/qlinear.py b/pymllm/backends/qualcomm/transformers/core/qlinear.py index d9c55e759..255f52ffb 100644 --- a/pymllm/backends/qualcomm/transformers/core/qlinear.py +++ b/pymllm/backends/qualcomm/transformers/core/qlinear.py @@ -296,7 +296,9 @@ def convert_to_conv2d_deploy_hwio(self): s1_permuted = ( s1.view(self.out_features, -1).t().contiguous() ) # [Out, Blocks] -> [Blocks, Out] - s1_hwio = s1_permuted.view(1, 1, -1, self.out_features) # Shape: [1, 1, Blocks, Out] + s1_hwio = s1_permuted.view( + 1, 1, -1, self.out_features + ) # Shape: [1, 1, Blocks, Out] del self.weight self.register_buffer("weight", w_hwio) diff --git a/pymllm/backends/qualcomm/transformers/core/rms_norm.py b/pymllm/backends/qualcomm/transformers/core/rms_norm.py index 0101d6aee..b3964469f 100644 --- a/pymllm/backends/qualcomm/transformers/core/rms_norm.py +++ b/pymllm/backends/qualcomm/transformers/core/rms_norm.py @@ -21,7 +21,9 @@ def __init__( # Quantization configuration for Weight self.weight_fake_quant = FakeQuantize( observer=MinMaxObserver.with_args( - qscheme=torch.per_tensor_affine, dtype=torch.qint32 + qscheme=torch.per_tensor_affine, + dtype=torch.qint32, + eps=0.0001 / 65535, ), quant_min=0, quant_max=2 ** (quant_bits) - 1, diff --git a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py index 9c0696328..92efaa06d 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py @@ -49,9 +49,12 @@ from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm from pymllm.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, - QLinearW8A16_PerChannelSym, ) -from pymllm.backends.qualcomm.transformers.core.qdq import ActivationQDQ +from pymllm.backends.qualcomm.transformers.core.qdq import ( + ActivationQDQ, + FixedActivationQDQ, +) +from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver class Qwen3MLP(nn.Module): @@ -76,7 +79,12 @@ def __init__(self, config): self.gate_proj_output_qdq = ActivationQDQ(bits=16) self.act_output_qdq = ActivationQDQ(bits=16) self.down_proj_input_qdq = ActivationQDQ(bits=16) - 
self.sigmoid_output_qdq = ActivationQDQ(bits=16) + # For sigmoid output: scale = 1 / (q_max - q_min + 1), zp = 0 + # For 16-bit: q_min = 0, q_max = 65535 + sigmoid_scale = 1.0 / (65535 - 0 + 1) # 1 / 65536 + self.sigmoid_output_qdq = FixedActivationQDQ( + scale=sigmoid_scale, zero_point=0, bits=16 + ) def forward(self, x): x = self.up_proj_input_qdq(x) @@ -93,11 +101,13 @@ def forward(self, x): return o -def rotate_half(x): +def rotate_half( + x, x_observer, x2_neg_fake_quant: ActivationQDQ, concat_observer: ConcatObserver +): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) + return concat_observer(torch.cat((x2_neg_fake_quant(-x2), x1), dim=-1)) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): @@ -207,6 +217,39 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.k_rope_mul_1_output_qdq = ActivationQDQ(bits=16) self.k_rope_add_0_output_qdq = ActivationQDQ(bits=16) + self.q_rope_concat_observer = ConcatObserver( + dtype=torch.int32, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2**16 - 1, + eps=0.0001 / 65535, + is_dynamic=False, + ) + self.q_rope_neg_half_qdq = ActivationQDQ(bits=16) + self.k_rope_concat_observer = ConcatObserver( + dtype=torch.int32, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2**16 - 1, + eps=0.0001 / 65535, + is_dynamic=False, + ) + self.k_rope_neg_half_qdq = ActivationQDQ(bits=16) + self.k_rope_concat_observer.add_observer( + self.k_norm_output_qdq.fake_quant.activation_post_process + ) + self.k_rope_concat_observer.add_observer( + self.k_rope_neg_half_qdq.fake_quant.activation_post_process + ) + self.q_rope_concat_observer.add_observer( + self.q_norm_output_qdq.fake_quant.activation_post_process + ) + self.q_rope_concat_observer.add_observer( + self.q_rope_neg_half_qdq.fake_quant.activation_post_process + ) + # In qnn, is uint8 sym. 
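The ConcatObserver wiring above is the Python-side mirror of the recursiveCheckConcatInputs check added to PTQPass.cpp: QNN wants every input of a concat to carry the same scale/zero-point, so one observer propagates a shared min/max into the observers of both concat operands (here the q/k_norm_output_qdq and the new *_rope_neg_half_qdq). A dependency-free sketch of that sharing mechanism; the class names are ours, the real ConcatObserver derives from torchao's UniformQuantizationObserverBase:

    class ChildObserver:
        def __init__(self):
            self.min_val, self.max_val = 0.0, 0.0

    class SharedRange:
        # Accumulates one global min/max and pushes it into the per-input observers.
        def __init__(self):
            self.min_val, self.max_val = float("inf"), float("-inf")
            self.children = []

        def add_observer(self, child):
            self.children.append(child)

        def observe(self, batch_min, batch_max):
            self.min_val = min(self.min_val, batch_min)
            self.max_val = max(self.max_val, batch_max)
            for c in self.children:            # every concat input sees the same range
                c.min_val, c.max_val = self.min_val, self.max_val

    x1, x2, shared = ChildObserver(), ChildObserver(), SharedRange()
    shared.add_observer(x1)
    shared.add_observer(x2)
    shared.observe(-0.5, 2.0)
    shared.observe(-3.0, 1.0)
    assert x1.min_val == x2.min_val == -3.0 and x1.max_val == x2.max_val == 2.0

With a shared range the qparams derived from (min, max) coincide across inputs, so the C++ consistency check should stay silent.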
self.k_cast_to_int8_qdq = ActivationQDQ( bits=8, qscheme=torch.per_tensor_symmetric @@ -224,6 +267,7 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.minus_0_output_qdq = ActivationQDQ(bits=16) self.softmax_output_qdq = ActivationQDQ(bits=16) self.attn_value_matmul_output_qdq = ActivationQDQ(bits=16) + self.where_attn_qdq = ActivationQDQ(bits=16) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( @@ -256,11 +300,27 @@ def forward( sin = sin.unsqueeze(1) query_states = self.q_rope_add_0_output_qdq( self.q_rope_mul_0_output_qdq(query_states * cos) - + self.q_rope_mul_1_output_qdq(rotate_half(query_states) * sin) + + self.q_rope_mul_1_output_qdq( + rotate_half( + query_states, + self.q_norm_output_qdq.fake_quant.activation_post_process, + self.q_rope_neg_half_qdq, + self.q_rope_concat_observer, + ) + * sin + ) ) key_states = self.k_rope_add_0_output_qdq( self.k_rope_mul_0_output_qdq(key_states * cos) - + self.k_rope_mul_1_output_qdq(rotate_half(key_states) * sin) + + self.k_rope_mul_1_output_qdq( + rotate_half( + key_states, + self.k_norm_output_qdq.fake_quant.activation_post_process, + self.k_rope_neg_half_qdq, + self.k_rope_concat_observer, + ) + * sin + ) ) key_states = self.k_cast_to_int8_qdq(key_states) @@ -281,7 +341,7 @@ def forward( torch.matmul(query_states, key_states.transpose(2, 3)) ) * self.scaling_qdq( - torch.ones(1, dtype=torch.bfloat16, device=value_states.device) + torch.ones(1, dtype=value_states.dtype, device=value_states.device) * self.scaling ) ) @@ -292,10 +352,13 @@ def forward( attn_vv = self.minus_0_output_qdq( attn_min + self.neg_20_qdq( - torch.ones(1, dtype=torch.bfloat16, device=value_states.device) * (-20) + torch.ones(1, dtype=value_states.dtype, device=value_states.device) + * (-20) ) ) - attn_weights = torch.where(attention_mask == 0, attn_weights, attn_vv) + attn_weights = self.where_attn_qdq( + torch.where(attention_mask == 0, attn_weights, attn_vv) + ) attn_weights = self.softmax_output_qdq( nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( @@ -315,6 +378,7 @@ def forward( class Qwen3DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen3Config, layer_idx: int): super().__init__() + self.layer_dix = layer_idx self.hidden_size = config.hidden_size self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx) @@ -362,6 +426,7 @@ def forward( position_embeddings=position_embeddings, **kwargs, ) + hidden_states = self.add_0_output_qdq( residual + self.add_0_lhs_input_qdq(hidden_states) ) @@ -567,6 +632,12 @@ def forward( self.mllm_max_cos_embedding, self.mllm_max_sin_embedding = self.rotary_emb( hidden_states, max_position_ids ) + self.mllm_max_cos_embedding = self.mllm_max_cos_embedding.to( + inputs_embeds.dtype + ) + self.mllm_max_sin_embedding = self.mllm_max_sin_embedding.to( + inputs_embeds.dtype + ) self.mllm_max_cos_embedding = self.cos_embedding_input_qdq( self.mllm_max_cos_embedding ) diff --git a/pymllm/backends/qualcomm/transformers/qwen3/runner.py b/pymllm/backends/qualcomm/transformers/qwen3/runner.py index 53ab40a9e..6565ca7e6 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/runner.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/runner.py @@ -2,7 +2,10 @@ from tqdm import tqdm from modelscope.msdatasets import MsDataset from transformers import AutoTokenizer -from pymllm.backends.qualcomm.transformers.core.qdq import ActivationQDQ +from pymllm.backends.qualcomm.transformers.core.qdq import ( + ActivationQDQ, + FixedActivationQDQ, +) 
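Why scale = 1/65536 is a reasonable fixed choice for the sigmoid output in modeling_qwen3.py: with zero_point 0 the 16-bit grid represents [0, 65535/65536], which lines up with sigmoid's (0, 1) range, so the parameters can be pinned instead of observed. Quick arithmetic check in plain Python, independent of the model code:

    bits = 16
    q_min, q_max = 0, 2**bits - 1
    scale, zero_point = 1.0 / (q_max - q_min + 1), 0           # 1 / 65536

    def fake_quant(x):
        q = min(q_max, max(q_min, round(x / scale) + zero_point))
        return (q - zero_point) * scale                        # dequantized value

    assert fake_quant(0.0) == 0.0
    assert abs(fake_quant(1.0) - q_max * scale) < 1e-12        # saturates at 65535/65536
    assert fake_quant(0.5) == 0.5                              # mid-range is exact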
from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm from pymllm.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, @@ -31,6 +34,16 @@ def enable_qdq_observer(m): m.enable_observer() +def enable_fake_quant(m): + if isinstance(m, ActivationQDQ) or isinstance(m, FixedActivationQDQ): + m.enable_fakequant() + + +def disable_fake_quant(m): + if isinstance(m, ActivationQDQ) or isinstance(m, FixedActivationQDQ): + m.disable_fakequant() + + def convert_weight(m): if isinstance(m, QLinearLPBQ) or isinstance(m, QLinearW8A16_PerChannelSym): m.convert_to_conv2d_deploy_hwio() @@ -44,6 +57,7 @@ def __init__(self, model_path: str, mllm_qualcomm_max_length=2048): self.model = Qwen3ForCausalLM.from_pretrained( model_path, attn_implementation="eager", + dtype=torch.float32, ) self.model.cuda() self.mllm_qualcomm_max_length = mllm_qualcomm_max_length @@ -60,6 +74,12 @@ def freeze_activation(self): def enable_activation_update(self): self.model.apply(enable_qdq_observer) + def enable_fake_quant(self): + self.model.apply(enable_fake_quant) + + def disable_fake_quant(self): + self.model.apply(disable_fake_quant) + def compile(self): print("Compile Start.") self.model = torch.compile( diff --git a/pymllm/backends/qualcomm/transformers/qwen3/train.py b/pymllm/backends/qualcomm/transformers/qwen3/train.py index 13ad2785a..25361f372 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/train.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/train.py @@ -37,13 +37,17 @@ def main(): args = parser.parse_args() m = Qwen3Quantizer(args.model_path, mllm_qualcomm_max_length=args.max_length) + + # FIXME: Should disable or not. + m.disable_fake_quant() m.calibrate(num_samples=args.num_samples, max_seq_length=args.max_length) - # m.compile() + m.enable_fake_quant() m.infer(args.infer_text) # !!! # Things below is for deploy. We will turn all fp32 weights and some buffers(rope) to quantized dtype. # !!! + # This line maybe error. we need use quantized weight!!! not embed_tokens.weight!!! m.model.lm_head.weight = torch.nn.Parameter( m.model.model.embed_tokens.weight.clone() ) diff --git a/pymllm/convertor/model_file_v2.py b/pymllm/convertor/model_file_v2.py index 302e3e21b..976c04411 100644 --- a/pymllm/convertor/model_file_v2.py +++ b/pymllm/convertor/model_file_v2.py @@ -24,6 +24,14 @@ MLLM_MODEL_FILE_V2_TENSOR_SHAPE_LENGTH = 16 +def _torch_tensor_bytes(tensor: "torch.Tensor") -> bytes: + # Use uint8 view to preserve raw bytes for dtypes not supported by numpy. 
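The new enable_fake_quant/disable_fake_quant hooks in runner.py let train.py calibrate on clean fp32 activations and only then switch quantization noise on for inference. A sketch of the intended driver flow, assuming Qwen3Quantizer lives in runner.py as the diff suggests; the model path and sample count are placeholders, and whether calibration should really run with fake quant disabled is still an open FIXME in the patch:

    from pymllm.backends.qualcomm.transformers.qwen3.runner import Qwen3Quantizer

    quantizer = Qwen3Quantizer("/path/to/qwen3", mllm_qualcomm_max_length=2048)

    quantizer.disable_fake_quant()        # observers see un-noised activations
    quantizer.calibrate(num_samples=16, max_seq_length=2048)

    quantizer.freeze_activation()         # stop updating min/max (optional)
    quantizer.enable_fake_quant()         # simulate quantization error from here on
    quantizer.infer("Give me a short introduction to large language models.")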
+ t = tensor.detach().cpu().contiguous() + if t.dim() == 0: + t = t.reshape(1) + return t.view(torch.uint8).numpy().tobytes() + + class ModelFileV2Descriptor: SIZE = 532 @@ -132,7 +140,7 @@ def streaming_write(self, tensor_name, tensor_obj): if MLLM_FIND_TORCH_AVAILABLE and isinstance(tensor_obj, torch.Tensor): # PyTorch tensor shape = list(tensor_obj.shape) - tensor_data = tensor_obj.detach().cpu().numpy().tobytes() + tensor_data = _torch_tensor_bytes(tensor_obj) true_dtype = MLLM_TYPE_MAPPING[tensor_obj.dtype] elif MLLM_FIND_NUMPY_AVAILABLE and isinstance(tensor_obj, np.ndarray): # Numpy array @@ -203,7 +211,7 @@ def static_write(self, tensor_obj): if MLLM_FIND_TORCH_AVAILABLE and isinstance(tensor, torch.Tensor): # PyTorch tensor shape = list(tensor.shape) - tensor_data = tensor.detach().cpu().numpy().tobytes() + tensor_data = _torch_tensor_bytes(tensor) true_dtype = MLLM_TYPE_MAPPING[tensor.dtype] elif MLLM_FIND_NUMPY_AVAILABLE and isinstance(tensor, np.ndarray): # Numpy array diff --git a/pymllm/ffi/__init__.py b/pymllm/ffi/__init__.py index 17bd04c19..9780eabb0 100644 --- a/pymllm/ffi/__init__.py +++ b/pymllm/ffi/__init__.py @@ -48,6 +48,10 @@ def to_pod(self) -> int: return tvm_ffi.get_global_func("mllm.DType.to_pod")(self) +# ============================================================================= +# DType factory functions +# ============================================================================= +# Floating point types def float32_() -> DType: return _ffi_api.float32_() @@ -60,6 +64,45 @@ def bfloat16_() -> DType: return _ffi_api.bfloat16_() +# Signed integer types +def int8_() -> DType: + return _ffi_api.int8_() + + +def int16_() -> DType: + return _ffi_api.int16_() + + +def int32_() -> DType: + return _ffi_api.int32_() + + +def int64_() -> DType: + return _ffi_api.int64_() + + +# Unsigned integer types +def uint8_() -> DType: + return _ffi_api.uint8_() + + +def uint16_() -> DType: + return _ffi_api.uint16_() + + +def uint32_() -> DType: + return _ffi_api.uint32_() + + +def uint64_() -> DType: + return _ffi_api.uint64_() + + +# Bool type (backed by uint8) +def bool_() -> DType: + return _ffi_api.bool_() + + def cpu_() -> Device: return _ffi_api.cpu_() @@ -219,10 +262,32 @@ def is_contiguous(self): return tvm_ffi.get_global_func("mllm.Tensor.is_contiguous")(self) -# Global dtypes +# ============================================================================= +# Global dtype instances +# ============================================================================= +# Floating point types float32: DType = float32_() float16: DType = float16_() bfloat16: DType = bfloat16_() + +# Signed integer types +int8: DType = int8_() +int16: DType = int16_() +int32: DType = int32_() +int64: DType = int64_() + +# Unsigned integer types +uint8: DType = uint8_() +uint16: DType = uint16_() +uint32: DType = uint32_() +uint64: DType = uint64_() + +# Bool type (use 'boolean' to avoid shadowing Python's built-in 'bool') +boolean: DType = bool_() + +# ============================================================================= +# Global device instances +# ============================================================================= cpu: Device = cpu_() cuda: Device = cuda_() qnn: Device = qnn_()
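Back-reference to the RMSNorm.cpp change: the fake runtime-bias recipe is now a symmetric int16 spec with scale 1/32767 instead of an all-zero int32 spec, i.e. it describes a representable range of roughly [-1.00003, 1.0]. Quick check of that range:

    scale = 1.0 / 32767
    q_min, q_max = -32768, 32767
    lo, hi = q_min * scale, q_max * scale
    assert abs(hi - 1.0) < 1e-12 and abs(lo + 1.0) < 1e-4   # ≈ [-1.0000305, 1.0]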
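On the convertor fix in model_file_v2.py: tensor.numpy() raises for dtypes numpy cannot represent (bfloat16 being the practical case here), while a uint8 view preserves the raw bytes regardless of dtype. A small round-trip check of the same idea; the helper is restated locally rather than imported from pymllm.convertor:

    import torch

    def torch_tensor_bytes(t: torch.Tensor) -> bytes:
        t = t.detach().cpu().contiguous()
        if t.dim() == 0:
            t = t.reshape(1)
        return t.view(torch.uint8).numpy().tobytes()   # reinterpret, do not convert

    x = torch.randn(4, dtype=torch.bfloat16)
    raw = torch_tensor_bytes(x)
    back = torch.frombuffer(bytearray(raw), dtype=torch.bfloat16)
    assert torch.equal(back, x) and len(raw) == x.numel() * 2   # 2 bytes per bf16 value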
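Finally, the new FFI registrations make the integer and boolean dtype singletons importable straight from the package root. A hypothetical usage sketch; the exact arange/zeros call shapes are assumptions on our part, only the dtype and device singletons themselves come from the diff:

    import pymllm as mllm

    print(mllm.int32, mllm.uint16, mllm.boolean)   # boolean is backed by kUInt8

    # Assumed call shapes, mirroring how the float dtypes are already used.
    ids = mllm.arange(0, 16, 1, dtype=mllm.int64, device=mllm.cpu)
    mask = mllm.zeros([1, 16], dtype=mllm.boolean, device=mllm.cpu)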