From fcf6af72d23f5ec7d99e12f241671056244b57ee Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Sat, 17 Jan 2026 02:31:56 +0000 Subject: [PATCH 1/5] fix: Enhance quantization modules. Introduced FixedActivationQDQ for fixed quantization parameters, updated ActivationQDQ to use MovingAverageMinMaxObserver, and adjusted eps values for better precision. Modified Qwen3 model to utilize FixedActivationQDQ for sigmoid output and ensured dtype consistency in attention calculations. --- .../qualcomm/transformers/core/qdq.py | 117 +++++++++++++++++- .../qualcomm/transformers/core/rms_norm.py | 4 +- .../transformers/qwen3/modeling_qwen3.py | 34 ++++- .../qualcomm/transformers/qwen3/runner.py | 1 + .../qualcomm/transformers/qwen3/train.py | 1 + 5 files changed, 147 insertions(+), 10 deletions(-) diff --git a/pymllm/backends/qualcomm/transformers/core/qdq.py b/pymllm/backends/qualcomm/transformers/core/qdq.py index ce67729f4..8a4f90687 100644 --- a/pymllm/backends/qualcomm/transformers/core/qdq.py +++ b/pymllm/backends/qualcomm/transformers/core/qdq.py @@ -1,6 +1,13 @@ import torch import torch.nn as nn -from torch.ao.quantization import FakeQuantize, MinMaxObserver +from torch.ao.quantization import ( + FakeQuantize, + MovingAverageMinMaxObserver, +) +from torch.ao.quantization.observer import FixedQParamsObserver + +DEFAULT_EPS_8BIT = 0.0001 / 255 +DEFAULT_EPS_16BIT = 0.0001 / 65535 class ActivationQDQ(nn.Module): @@ -30,16 +37,24 @@ def __init__(self, bits=8, qscheme=torch.per_tensor_affine): self.quant_min = 0 self.quant_max = (2**bits) - 1 + if bits == 8: + eps = DEFAULT_EPS_8BIT + elif bits == 16: + eps = DEFAULT_EPS_16BIT + else: + raise ValueError(f"Unsupported bit width: {bits}") + # 2. Initialize FakeQuantize - # MinMaxObserver calculates scale and zero_point based on observed tensors. + # MovingAverageMinMaxObserver calculates scale and zero_point based on observed tensors. # Passing quant_min/max to the observer ensures consistency. self.fake_quant = FakeQuantize( - observer=MinMaxObserver.with_args( - qscheme=self.qscheme, + observer=MovingAverageMinMaxObserver.with_args( dtype=self.dtype, + qscheme=self.qscheme, quant_min=self.quant_min, quant_max=self.quant_max, reduce_range=False, + eps=eps, ), quant_min=self.quant_min, quant_max=self.quant_max, @@ -72,3 +87,97 @@ def disable_fakequant(self): def extra_repr(self): mode = "Symmetric" if "symmetric" in str(self.qscheme) else "Asymmetric" return f"bits={self.bits}, mode={mode}, q_range=({self.quant_min}, {self.quant_max}), dtype={self.dtype}" + + +class FixedActivationQDQ(nn.Module): + """ + Fixed activation Quantization-DeQuantization (QDQ) module. + Uses pre-determined scale and zero_point instead of dynamic observation. + Supports both Symmetric and Asymmetric (Affine) quantization. + Uses torch.qint32 as a unified type to support various bit-widths. + """ + + def __init__(self, scale, zero_point, bits=8, qscheme=torch.per_tensor_affine): + super().__init__() + self.bits = bits + self.qscheme = qscheme + + # Define the simulation dtype as qint32 to avoid overflow across different bit-widths + self.dtype = torch.qint32 + + # 1. Calculate quantization range based on bits and scheme + if qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]: + # Symmetric: range is [-(2^(bits-1)), 2^(bits-1) - 1] + # e.g., 8-bit: -128 to 127 + self.quant_min = -(2 ** (bits - 1)) + self.quant_max = 2 ** (bits - 1) - 1 + else: + # Asymmetric (Affine): range is [0, 2^bits - 1] + # e.g., 8-bit: 0 to 255 + self.quant_min = 0 + self.quant_max = (2**bits) - 1 + + if bits not in [8, 16]: + raise ValueError(f"Unsupported bit width: {bits}") + + # 2. Convert scale and zero_point to tensors if needed + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale, dtype=torch.float32) + if not isinstance(zero_point, torch.Tensor): + zero_point = torch.tensor(zero_point, dtype=torch.int32) + + # 3. Initialize FakeQuantize with fixed parameters + # Use FakeQuantize with FixedQParamsObserver for fixed scale and zero_point + self.fake_quant = FakeQuantize.with_args( + observer=FixedQParamsObserver.with_args( + scale=scale, + zero_point=zero_point, + ), + dtype=self.dtype, + qscheme=self.qscheme, + quant_min=self.quant_min, + quant_max=self.quant_max, + )() + + def forward(self, x): + # Applies fake quantization with fixed scale and zero_point: + # rounds to nearest integer and clamps to [min, max], + # then dequantizes back to float to simulate quantization noise. + return self.fake_quant(x) + + # Control methods for quantization-aware training (QAT) + # Note: FixedActivationQDQ doesn't have observer, so these methods + # only control fake quantization behavior + def enable_observer(self): + """No-op: FixedActivationQDQ doesn't use observer.""" + pass + + def disable_observer(self): + """No-op: FixedActivationQDQ doesn't use observer.""" + pass + + def enable_fakequant(self): + """Enable simulation of quantization error.""" + self.fake_quant.enable_fakequant() + + def disable_fakequant(self): + """Disable quantization simulation (act as identity).""" + self.fake_quant.disable_fakequant() + + @property + def scale(self): + """Get the fixed scale value.""" + return self.fake_quant.scale + + @property + def zero_point(self): + """Get the fixed zero_point value.""" + return self.fake_quant.zero_point + + def extra_repr(self): + mode = "Symmetric" if "symmetric" in str(self.qscheme) else "Asymmetric" + scale_val = self.scale.item() if self.scale.numel() == 1 else self.scale + zp_val = ( + self.zero_point.item() if self.zero_point.numel() == 1 else self.zero_point + ) + return f"bits={self.bits}, mode={mode}, scale={scale_val}, zero_point={zp_val}, q_range=({self.quant_min}, {self.quant_max}), dtype={self.dtype}" diff --git a/pymllm/backends/qualcomm/transformers/core/rms_norm.py b/pymllm/backends/qualcomm/transformers/core/rms_norm.py index 0101d6aee..b3964469f 100644 --- a/pymllm/backends/qualcomm/transformers/core/rms_norm.py +++ b/pymllm/backends/qualcomm/transformers/core/rms_norm.py @@ -21,7 +21,9 @@ def __init__( # Quantization configuration for Weight self.weight_fake_quant = FakeQuantize( observer=MinMaxObserver.with_args( - qscheme=torch.per_tensor_affine, dtype=torch.qint32 + qscheme=torch.per_tensor_affine, + dtype=torch.qint32, + eps=0.0001 / 65535, ), quant_min=0, quant_max=2 ** (quant_bits) - 1, diff --git a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py index 9c0696328..0bbcbffd8 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py @@ -49,9 +49,11 @@ from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm from pymllm.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, - QLinearW8A16_PerChannelSym, ) -from pymllm.backends.qualcomm.transformers.core.qdq import ActivationQDQ +from pymllm.backends.qualcomm.transformers.core.qdq import ( + ActivationQDQ, + FixedActivationQDQ, +) class Qwen3MLP(nn.Module): @@ -76,7 +78,12 @@ def __init__(self, config): self.gate_proj_output_qdq = ActivationQDQ(bits=16) self.act_output_qdq = ActivationQDQ(bits=16) self.down_proj_input_qdq = ActivationQDQ(bits=16) - self.sigmoid_output_qdq = ActivationQDQ(bits=16) + # For sigmoid output: scale = 1 / (q_max - q_min + 1), zp = 0 + # For 16-bit: q_min = 0, q_max = 65535 + sigmoid_scale = 1.0 / (65535 - 0 + 1) # 1 / 65536 + self.sigmoid_output_qdq = FixedActivationQDQ( + scale=sigmoid_scale, zero_point=0, bits=16 + ) def forward(self, x): x = self.up_proj_input_qdq(x) @@ -281,7 +288,7 @@ def forward( torch.matmul(query_states, key_states.transpose(2, 3)) ) * self.scaling_qdq( - torch.ones(1, dtype=torch.bfloat16, device=value_states.device) + torch.ones(1, dtype=value_states.dtype, device=value_states.device) * self.scaling ) ) @@ -292,7 +299,8 @@ def forward( attn_vv = self.minus_0_output_qdq( attn_min + self.neg_20_qdq( - torch.ones(1, dtype=torch.bfloat16, device=value_states.device) * (-20) + torch.ones(1, dtype=value_states.dtype, device=value_states.device) + * (-20) ) ) attn_weights = torch.where(attention_mask == 0, attn_weights, attn_vv) @@ -315,6 +323,7 @@ def forward( class Qwen3DecoderLayer(GradientCheckpointingLayer): def __init__(self, config: Qwen3Config, layer_idx: int): super().__init__() + self.layer_dix = layer_idx self.hidden_size = config.hidden_size self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx) @@ -362,6 +371,15 @@ def forward( position_embeddings=position_embeddings, **kwargs, ) + + if self.layer_dix == 2: + print("1", hidden_states.min(), hidden_states.max()) + print( + "2", + self.add_0_lhs_input_qdq(hidden_states).min(), + self.add_0_lhs_input_qdq(hidden_states).max(), + ) + hidden_states = self.add_0_output_qdq( residual + self.add_0_lhs_input_qdq(hidden_states) ) @@ -567,6 +585,12 @@ def forward( self.mllm_max_cos_embedding, self.mllm_max_sin_embedding = self.rotary_emb( hidden_states, max_position_ids ) + self.mllm_max_cos_embedding = self.mllm_max_cos_embedding.to( + inputs_embeds.dtype + ) + self.mllm_max_sin_embedding = self.mllm_max_sin_embedding.to( + inputs_embeds.dtype + ) self.mllm_max_cos_embedding = self.cos_embedding_input_qdq( self.mllm_max_cos_embedding ) diff --git a/pymllm/backends/qualcomm/transformers/qwen3/runner.py b/pymllm/backends/qualcomm/transformers/qwen3/runner.py index 53ab40a9e..88f5ce84e 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/runner.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/runner.py @@ -44,6 +44,7 @@ def __init__(self, model_path: str, mllm_qualcomm_max_length=2048): self.model = Qwen3ForCausalLM.from_pretrained( model_path, attn_implementation="eager", + dtype=torch.bfloat16, ) self.model.cuda() self.mllm_qualcomm_max_length = mllm_qualcomm_max_length diff --git a/pymllm/backends/qualcomm/transformers/qwen3/train.py b/pymllm/backends/qualcomm/transformers/qwen3/train.py index 13ad2785a..33351918f 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/train.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/train.py @@ -44,6 +44,7 @@ def main(): # !!! # Things below is for deploy. We will turn all fp32 weights and some buffers(rope) to quantized dtype. # !!! + # This line maybe error. we need use quantized weight!!! not embed_tokens.weight!!! m.model.lm_head.weight = torch.nn.Parameter( m.model.model.embed_tokens.weight.clone() ) From c111b645055f8e5694874bf1b48c7bdfc41c9e07 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Sat, 17 Jan 2026 03:29:56 +0000 Subject: [PATCH 2/5] fix: Suppress deprecated comma-subscript warnings in CMake and remove debug print statements from Qwen3DecoderLayer --- mllm/CMakeLists.txt | 4 ++++ .../qualcomm/transformers/qwen3/modeling_qwen3.py | 10 ++-------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/mllm/CMakeLists.txt b/mllm/CMakeLists.txt index 9df6b7741..fd796f95a 100644 --- a/mllm/CMakeLists.txt +++ b/mllm/CMakeLists.txt @@ -56,6 +56,10 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "App endif() endif() +# FIXME: @oreomaker Need to remove comma features in slice! +# Suppress comma-subscript warnings (deprecated C++ feature that will be removed in C++26) +target_compile_options(MllmRT PUBLIC -Wno-comma-subscript) + # ONLY APPLE CAN DO ! # Processing OpenMP if(MLLM_KERNEL_USE_THREADS AND MLLM_KERNEL_THREADS_VENDOR_OPENMP) diff --git a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py index 0bbcbffd8..dc6486043 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py @@ -372,14 +372,6 @@ def forward( **kwargs, ) - if self.layer_dix == 2: - print("1", hidden_states.min(), hidden_states.max()) - print( - "2", - self.add_0_lhs_input_qdq(hidden_states).min(), - self.add_0_lhs_input_qdq(hidden_states).max(), - ) - hidden_states = self.add_0_output_qdq( residual + self.add_0_lhs_input_qdq(hidden_states) ) @@ -388,6 +380,8 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) + # if self.layer_dix == 2: + # print(hidden_states.min(), hidden_states.max()) hidden_states = residual + self.add_1_lhs_input_qdq(hidden_states) return hidden_states From fb471e50a23f5eaf7d28f185bbde2adb955a5c0b Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Mon, 19 Jan 2026 07:45:23 +0000 Subject: [PATCH 3/5] feat(qualcomm): Add installation targets for flatbuffers and MllmQNNBackend in CMake, enhance PTQPass with unsolved tensor value checks, and update quantization specifications in RMSNorm and model file conversion. --- CMakeLists.txt | 7 +++ .../qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp | 3 +- mllm/backends/qnn/CMakeLists.txt | 7 +++ mllm/backends/qnn/aot/passes/PTQPass.cpp | 44 +++++++++++++++++++ mllm/backends/qnn/aot/visitor/RMSNorm.cpp | 5 ++- .../qualcomm/transformers/core/qdq.py | 4 +- .../qualcomm/transformers/core/qlinear.py | 4 +- .../transformers/qwen3/modeling_qwen3.py | 2 - .../qualcomm/transformers/qwen3/runner.py | 2 +- pymllm/convertor/model_file_v2.py | 12 ++++- 10 files changed, 80 insertions(+), 10 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 298b412c0..fca470ee5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -332,6 +332,13 @@ install( ARCHIVE DESTINATION lib RUNTIME DESTINATION bin) +install( + TARGETS flatbuffers + EXPORT MllmTargets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin) + if(MLLM_BUILD_SDK_C_BINDING) install( TARGETS MllmSdkC diff --git a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp index 9eed37267..f1b20a1a2 100644 --- a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp +++ b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp @@ -272,7 +272,8 @@ class Qwen3Attention final : public nn::Module { auto attn_min = ptq::QDQ(this, attn.min(-1, true), "reduce_min_output_qdq"); auto minus_value = Tensor::constant(-20, kFloat32); minus_value = ptq::QDQ(this, minus_value, "neg_20_qdq"); - attn = nn::functional::where(causal_mask.equal(0.f), attn, attn_min.addConstant(minus_value)); + auto attn_vv = ptq::QDQ(this, attn_min.addConstant(minus_value), "minus_0_output_qdq"); + attn = nn::functional::where(causal_mask.equal(0.f), attn, attn_vv); attn = ptq::QDQ(this, nn::functional::softmax(attn, -1), "softmax_output_qdq"); auto y = ptq::QDQ(this, nn::functional::matmul(attn, vh), "attn_value_matmul_output_qdq"); y = y.transpose(1, 2).view({1, 1, -1, num_attention_heads_ * head_dim_}, /*ssa=*/true); diff --git a/mllm/backends/qnn/CMakeLists.txt b/mllm/backends/qnn/CMakeLists.txt index 0ad833792..83b4a43f9 100644 --- a/mllm/backends/qnn/CMakeLists.txt +++ b/mllm/backends/qnn/CMakeLists.txt @@ -44,3 +44,10 @@ get_property(current_includes DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INC message(STATUS "MLLM_QNN INCLUDES: ${current_includes}") #print include directories target_link_libraries(MllmQNNBackend PUBLIC MllmRT) + +install( + TARGETS MllmQNNBackend + EXPORT MllmTargets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin) diff --git a/mllm/backends/qnn/aot/passes/PTQPass.cpp b/mllm/backends/qnn/aot/passes/PTQPass.cpp index 1d42d58d3..7172db475 100644 --- a/mllm/backends/qnn/aot/passes/PTQPass.cpp +++ b/mllm/backends/qnn/aot/passes/PTQPass.cpp @@ -300,6 +300,45 @@ void recursiveSolveNormal(const std::shared_ptr& ir_ctx, const ir }); } +void recursiveCheckUnsolved(const std::shared_ptr& ir_ctx, const ir::graph::SubGraphOp::ptr_t& call_op) { + auto wow = ir::IRWriter(ir_ctx, call_op->getTopRegion()); + wow.walk([&](ir::IRWriter& w, const ir::Op::ptr_t& op) -> ir::IRWriter::WalkResult { + if (op->isa_()) { + auto linalg_op = op->cast_(); + std::string op_name = linalg_op->getAOp()->getName(); + + auto inputs = op->inputs(); + auto outputs = op->outputs(); + + for (auto iii : inputs) { + if (!iii->isa_()) continue; + auto tv = iii->cast_(); + if (!tv->getAttr("quant_recipe")) continue; + auto f_spec = tv->getAttr("quant_recipe")->cast_(); + if (!f_spec->spec_->solved) { + MLLM_WARN("PTQPass: TensorValue '{}' is not solved, used by Op: '{}'", tv->name(), op_name); + } + } + + for (auto ooo : outputs) { + if (!ooo->isa_()) continue; + auto tv = ooo->cast_(); + if (!tv->getAttr("quant_recipe")) continue; + auto f_spec = tv->getAttr("quant_recipe")->cast_(); + if (!f_spec->spec_->solved) { + MLLM_WARN("PTQPass: TensorValue '{}' is not solved, produced by Op: '{}'", tv->name(), op_name); + } + } + } + + if (op->isa_()) { + auto ns = op->cast_()->getSymbolAttr()->str(); + recursiveCheckUnsolved(w.getContext(), w.getContext()->lookupSymbolTable(ns)->cast_()); + } + return ir::IRWriter::WALK_CONTINUE; + }); +} + } // namespace uint8_t PTQPass::run(const ir::node_ptr_t& op) { @@ -330,6 +369,11 @@ uint8_t PTQPass::run(const ir::node_ptr_t& op) { getCtx()->lookupSymbolTable(call_main_graph_op->getSymbolAttr()->str())->cast_(), pf); + // Check for unsolved tensorValues and warn + recursiveCheckUnsolved( + writer.getContext(), + getCtx()->lookupSymbolTable(call_main_graph_op->getSymbolAttr()->str())->cast_()); + return ir::PASS_RET_SUCCESS; } diff --git a/mllm/backends/qnn/aot/visitor/RMSNorm.cpp b/mllm/backends/qnn/aot/visitor/RMSNorm.cpp index 27f72e2e2..351e2562a 100644 --- a/mllm/backends/qnn/aot/visitor/RMSNorm.cpp +++ b/mllm/backends/qnn/aot/visitor/RMSNorm.cpp @@ -47,9 +47,12 @@ bool QnnAOTRMSNormPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& op) auto bias_tensor = mllm::Tensor::zeros(weight->tensor_.shape(), weight->tensor_.dtype()); auto bias_node = ir::tensor::TensorValue::build(writer.getContext().get(), bias_tensor); bias_node->tensor_.setName(a->getName() + "_runtime_bias"); + bias_node->name() = a->getName() + "_runtime_bias"; // fake bias quant recipe - auto quant_spec = mllm::ir::linalg::QuantizationSpecSymPerTensor::create(0, 0, kInt32, kFloat32, Tensor::ones({1})); + auto bias_scale = Tensor::ones({1}); + bias_scale.at({0}) = 1.0 / 32767; + auto quant_spec = mllm::ir::linalg::QuantizationSpecSymPerTensor::create(-32768, 32767, kInt16, kFloat32, bias_scale); auto quant_attr = mllm::ir::linalg::LinalgIRQuantizatonSpecAttr::build(writer.getContext().get(), quant_spec); bias_node->setAttr("quant_recipe", quant_attr); diff --git a/pymllm/backends/qualcomm/transformers/core/qdq.py b/pymllm/backends/qualcomm/transformers/core/qdq.py index 8a4f90687..f1c4d20dc 100644 --- a/pymllm/backends/qualcomm/transformers/core/qdq.py +++ b/pymllm/backends/qualcomm/transformers/core/qdq.py @@ -2,7 +2,7 @@ import torch.nn as nn from torch.ao.quantization import ( FakeQuantize, - MovingAverageMinMaxObserver, + MinMaxObserver, ) from torch.ao.quantization.observer import FixedQParamsObserver @@ -48,7 +48,7 @@ def __init__(self, bits=8, qscheme=torch.per_tensor_affine): # MovingAverageMinMaxObserver calculates scale and zero_point based on observed tensors. # Passing quant_min/max to the observer ensures consistency. self.fake_quant = FakeQuantize( - observer=MovingAverageMinMaxObserver.with_args( + observer=MinMaxObserver.with_args( dtype=self.dtype, qscheme=self.qscheme, quant_min=self.quant_min, diff --git a/pymllm/backends/qualcomm/transformers/core/qlinear.py b/pymllm/backends/qualcomm/transformers/core/qlinear.py index d9c55e759..255f52ffb 100644 --- a/pymllm/backends/qualcomm/transformers/core/qlinear.py +++ b/pymllm/backends/qualcomm/transformers/core/qlinear.py @@ -296,7 +296,9 @@ def convert_to_conv2d_deploy_hwio(self): s1_permuted = ( s1.view(self.out_features, -1).t().contiguous() ) # [Out, Blocks] -> [Blocks, Out] - s1_hwio = s1_permuted.view(1, 1, -1, self.out_features) # Shape: [1, 1, Blocks, Out] + s1_hwio = s1_permuted.view( + 1, 1, -1, self.out_features + ) # Shape: [1, 1, Blocks, Out] del self.weight self.register_buffer("weight", w_hwio) diff --git a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py index dc6486043..2f099088e 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py @@ -380,8 +380,6 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - # if self.layer_dix == 2: - # print(hidden_states.min(), hidden_states.max()) hidden_states = residual + self.add_1_lhs_input_qdq(hidden_states) return hidden_states diff --git a/pymllm/backends/qualcomm/transformers/qwen3/runner.py b/pymllm/backends/qualcomm/transformers/qwen3/runner.py index 88f5ce84e..ed302f215 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/runner.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/runner.py @@ -44,7 +44,7 @@ def __init__(self, model_path: str, mllm_qualcomm_max_length=2048): self.model = Qwen3ForCausalLM.from_pretrained( model_path, attn_implementation="eager", - dtype=torch.bfloat16, + dtype=torch.float32, ) self.model.cuda() self.mllm_qualcomm_max_length = mllm_qualcomm_max_length diff --git a/pymllm/convertor/model_file_v2.py b/pymllm/convertor/model_file_v2.py index 302e3e21b..976c04411 100644 --- a/pymllm/convertor/model_file_v2.py +++ b/pymllm/convertor/model_file_v2.py @@ -24,6 +24,14 @@ MLLM_MODEL_FILE_V2_TENSOR_SHAPE_LENGTH = 16 +def _torch_tensor_bytes(tensor: "torch.Tensor") -> bytes: + # Use uint8 view to preserve raw bytes for dtypes not supported by numpy. + t = tensor.detach().cpu().contiguous() + if t.dim() == 0: + t = t.reshape(1) + return t.view(torch.uint8).numpy().tobytes() + + class ModelFileV2Descriptor: SIZE = 532 @@ -132,7 +140,7 @@ def streaming_write(self, tensor_name, tensor_obj): if MLLM_FIND_TORCH_AVAILABLE and isinstance(tensor_obj, torch.Tensor): # PyTorch tensor shape = list(tensor_obj.shape) - tensor_data = tensor_obj.detach().cpu().numpy().tobytes() + tensor_data = _torch_tensor_bytes(tensor_obj) true_dtype = MLLM_TYPE_MAPPING[tensor_obj.dtype] elif MLLM_FIND_NUMPY_AVAILABLE and isinstance(tensor_obj, np.ndarray): # Numpy array @@ -203,7 +211,7 @@ def static_write(self, tensor_obj): if MLLM_FIND_TORCH_AVAILABLE and isinstance(tensor, torch.Tensor): # PyTorch tensor shape = list(tensor.shape) - tensor_data = tensor.detach().cpu().numpy().tobytes() + tensor_data = _torch_tensor_bytes(tensor) true_dtype = MLLM_TYPE_MAPPING[tensor.dtype] elif MLLM_FIND_NUMPY_AVAILABLE and isinstance(tensor, np.ndarray): # Numpy array From 7f8f0f2ab69220e7e81e42a343c0271c1fc706c9 Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Mon, 19 Jan 2026 13:08:59 +0000 Subject: [PATCH 4/5] feat(qualcomm): Refactor Qwen3 model to integrate ConcatObserver for improved quantization, enhance rotate_half function to utilize observers, and ensure consistent scale and zero_point across concatenated inputs. --- .../qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp | 35 +++---- .../qnn/aot/passes/LLMQuantRecipePass.cpp | 17 +++- mllm/backends/qnn/aot/passes/PTQPass.cpp | 93 +++++++++++++++++++ .../qualcomm/transformers/core/observer.py | 56 +++++++++++ .../qualcomm/transformers/core/qdq.py | 8 +- .../transformers/qwen3/modeling_qwen3.py | 65 ++++++++++++- .../qualcomm/transformers/qwen3/runner.py | 21 ++++- .../qualcomm/transformers/qwen3/train.py | 5 +- 8 files changed, 268 insertions(+), 32 deletions(-) create mode 100644 pymllm/backends/qualcomm/transformers/core/observer.py diff --git a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp index f1b20a1a2..a2d054bad 100644 --- a/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp +++ b/examples/qwen3_qnn_aot/modeling_qwen_qnn_aot.hpp @@ -15,14 +15,6 @@ namespace mllm::models::qwen3 { -Tensor rotateHalf(Tensor x) { // NOLINT - // X is [x, x, x, D] - auto D = x.size(-1); - auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true); - auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true); - return nn::functional::concat({-x2, x1}, -1); -} - namespace ptq { Tensor QDQ(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch) { @@ -112,6 +104,14 @@ Tensor QDQ_ROPE(nn::Module* m, Tensor in, const std::string& qdq_name_in_pytorch } // namespace ptq +Tensor rotateHalf(Tensor x, nn::Module* m, const std::string& qdq_name_in_pytorch) { // NOLINT + // X is [x, x, x, D] + auto D = x.size(-1); + auto x1 = x.slice({kAll, kAll, kAll, {kAll, D / 2}}, /*ssa=*/true); + auto x2 = x.slice({kAll, kAll, kAll, {D / 2, kAll}}, /*ssa=*/true); + return nn::functional::concat({ptq::QDQ(m, -x2, qdq_name_in_pytorch), x1}, -1); +} + using vi32 = std::vector; #define CONV2D_PROPERTY vi32{1, 1}, vi32{1, 1}, vi32{0, 0}, vi32{1, 1}, false, aops::Conv2DOpImplType::kQNN_LPBQ_w4a16o16_G32 @@ -232,14 +232,16 @@ class Qwen3Attention final : public nn::Module { // [B, H, S, D] auto cos = llm_embedding_cos.unsqueeze(1); auto sin = llm_embedding_sin.unsqueeze(1); - query_states = ptq::QDQ(this, - ptq::QDQ(this, query_states * cos, "q_rope_mul_0_output_qdq") - + ptq::QDQ(this, rotateHalf(query_states) * sin, "q_rope_mul_1_output_qdq"), - "q_rope_add_0_output_qdq"); - key_states = ptq::QDQ(this, - ptq::QDQ(this, key_states * cos, "k_rope_mul_0_output_qdq") - + ptq::QDQ(this, rotateHalf(key_states) * sin, "k_rope_mul_1_output_qdq"), - "k_rope_add_0_output_qdq"); + query_states = + ptq::QDQ(this, + ptq::QDQ(this, query_states * cos, "q_rope_mul_0_output_qdq") + + ptq::QDQ(this, rotateHalf(query_states, this, "q_rope_neg_half_qdq") * sin, "q_rope_mul_1_output_qdq"), + "q_rope_add_0_output_qdq"); + key_states = + ptq::QDQ(this, + ptq::QDQ(this, key_states * cos, "k_rope_mul_0_output_qdq") + + ptq::QDQ(this, rotateHalf(key_states, this, "k_rope_neg_half_qdq") * sin, "k_rope_mul_1_output_qdq"), + "k_rope_add_0_output_qdq"); // De-quantization and quantization again key_states = key_states.to(kFloat32); @@ -274,6 +276,7 @@ class Qwen3Attention final : public nn::Module { minus_value = ptq::QDQ(this, minus_value, "neg_20_qdq"); auto attn_vv = ptq::QDQ(this, attn_min.addConstant(minus_value), "minus_0_output_qdq"); attn = nn::functional::where(causal_mask.equal(0.f), attn, attn_vv); + attn = ptq::QDQ(this, attn, "where_attn_qdq"); attn = ptq::QDQ(this, nn::functional::softmax(attn, -1), "softmax_output_qdq"); auto y = ptq::QDQ(this, nn::functional::matmul(attn, vh), "attn_value_matmul_output_qdq"); y = y.transpose(1, 2).view({1, 1, -1, num_attention_heads_ * head_dim_}, /*ssa=*/true); diff --git a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp index 90ee4ad72..957fdf321 100644 --- a/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp +++ b/mllm/backends/qnn/aot/passes/LLMQuantRecipePass.cpp @@ -369,8 +369,7 @@ bool LLMQuantRecipeNegPattern::isMatch(const mllm::ir::op_ptr_t& op) { } bool LLMQuantRecipeNegPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_t& node) { - return shareQuantSpecSingleInputToSingleOutputAndSetOpQuantAnnoAttr(writer.getContext(), - node->cast_()); + return noSharingSingleInAndSingleOutQuantAnnoAttr(writer.getContext(), node->cast_()); } //===----------------------------------------------------------------------===// @@ -651,8 +650,15 @@ bool LLMQuantRecipeConcatPattern::rewrite(ir::IRWriter& writer, const ir::op_ptr return false; } - MLLM_RETURN_FALSE_IF_NOT(i_0->getAttr("quant_recipe")); - MLLM_RETURN_FALSE_IF_NOT(i_1->getAttr("quant_recipe")); + // Create quant_recipe if not present + if (!i_0->getAttr("quant_recipe")) { + auto i_0_spec = genSimpleQuantizationSpecAttr(writer.getContext(), i_0->cast_()); + i_0->setAttr("quant_recipe", i_0_spec); + } + if (!i_1->getAttr("quant_recipe")) { + auto i_1_spec = genSimpleQuantizationSpecAttr(writer.getContext(), i_1->cast_()); + i_1->setAttr("quant_recipe", i_1_spec); + } o_0->setAttr("quant_recipe", i_0->getAttr("quant_recipe")); @@ -795,7 +801,8 @@ bool LLMQuantRecipeWherePattern::rewrite(ir::IRWriter& writer, const ir::op_ptr_ MLLM_RETURN_FALSE_IF_NOT(i_1->getAttr("quant_recipe")); MLLM_RETURN_FALSE_IF_NOT(i_2->getAttr("quant_recipe")); - o_0->setAttr("quant_recipe", i_2->getAttr("quant_recipe")); + auto o_0_spec = genSimpleQuantizationSpecAttr(writer.getContext(), o_0->cast_()); + o_0->setAttr("quant_recipe", o_0_spec); auto annotation_attr = writer.create(); annotation_attr->annotation_.inputs.emplace_back( diff --git a/mllm/backends/qnn/aot/passes/PTQPass.cpp b/mllm/backends/qnn/aot/passes/PTQPass.cpp index 7172db475..82869ab16 100644 --- a/mllm/backends/qnn/aot/passes/PTQPass.cpp +++ b/mllm/backends/qnn/aot/passes/PTQPass.cpp @@ -339,6 +339,94 @@ void recursiveCheckUnsolved(const std::shared_ptr& ir_ctx, const }); } +void recursiveCheckConcatInputs(const std::shared_ptr& ir_ctx, const ir::graph::SubGraphOp::ptr_t& call_op) { + auto wow = ir::IRWriter(ir_ctx, call_op->getTopRegion()); + wow.walk([&](ir::IRWriter& w, const ir::Op::ptr_t& op) -> ir::IRWriter::WalkResult { + if (op->isa_()) { + auto concat_op = op->cast_(); + std::string op_name = concat_op->getAOp()->getName(); + + auto inputs = op->inputs(); + if (inputs.empty()) { return ir::IRWriter::WALK_CONTINUE; } + + // Get first input's scale and zero_point as reference + Tensor ref_scale; + Tensor ref_zero_point; + bool has_ref = false; + std::string ref_input_name; + + for (auto iii : inputs) { + if (!iii->isa_()) continue; + auto tv = iii->cast_(); + if (!tv->getAttr("quant_recipe")) continue; + auto f_spec = tv->getAttr("quant_recipe")->cast_(); + + if (f_spec->spec_->type == ir::linalg::QuantizationSpecType::kAsymPerTensor) { + auto this_spec = std::static_pointer_cast(f_spec->spec_); + if (!this_spec->solved) continue; + + if (!has_ref) { + ref_scale = this_spec->scale; + ref_zero_point = this_spec->zero_point; + ref_input_name = tv->name(); + has_ref = true; + } else { + // Check if scale and zero_point match + auto cur_scale = this_spec->scale; + auto cur_zero_point = this_spec->zero_point; + + MLLM_RT_ASSERT_EQ(ref_scale.numel(), 1); + MLLM_RT_ASSERT_EQ(cur_scale.numel(), 1); + MLLM_RT_ASSERT_EQ(ref_zero_point.numel(), 1); + MLLM_RT_ASSERT_EQ(cur_zero_point.numel(), 1); + + auto ref_scale_v = ref_scale.item(); + auto cur_scale_v = cur_scale.item(); + auto ref_zp_v = ref_zero_point.item(); + auto cur_zp_v = cur_zero_point.item(); + + if (std::abs(ref_scale_v - cur_scale_v) > 1e-6 || ref_zp_v != cur_zp_v) { + MLLM_ERROR("PTQPass: ConcatOp '{}' has mismatched scale/zp between inputs. " + "Input '{}': scale={}, zp={}; Input '{}': scale={}, zp={}", + op_name, ref_input_name, ref_scale_v, ref_zp_v, tv->name(), cur_scale_v, cur_zp_v); + } + } + } else if (f_spec->spec_->type == ir::linalg::QuantizationSpecType::kSymPerTensor) { + auto this_spec = std::static_pointer_cast(f_spec->spec_); + if (!this_spec->solved) continue; + + if (!has_ref) { + ref_scale = this_spec->scale; + ref_input_name = tv->name(); + has_ref = true; + } else { + // Check if scale matches + auto cur_scale = this_spec->scale; + + MLLM_RT_ASSERT_EQ(ref_scale.numel(), 1); + MLLM_RT_ASSERT_EQ(cur_scale.numel(), 1); + + auto ref_scale_v = ref_scale.item(); + auto cur_scale_v = cur_scale.item(); + + if (std::abs(ref_scale_v - cur_scale_v) > 1e-6) { + MLLM_ERROR("PTQPass: ConcatOp '{}' has mismatched scale between inputs. " + "Input '{}': scale={}; Input '{}': scale={}", + op_name, ref_input_name, ref_scale_v, tv->name(), cur_scale_v); + } + } + } + } + } + + if (op->isa_()) { + auto ns = op->cast_()->getSymbolAttr()->str(); + recursiveCheckConcatInputs(w.getContext(), w.getContext()->lookupSymbolTable(ns)->cast_()); + } + return ir::IRWriter::WALK_CONTINUE; + }); +} + } // namespace uint8_t PTQPass::run(const ir::node_ptr_t& op) { @@ -374,6 +462,11 @@ uint8_t PTQPass::run(const ir::node_ptr_t& op) { writer.getContext(), getCtx()->lookupSymbolTable(call_main_graph_op->getSymbolAttr()->str())->cast_()); + // Check Concat inputs have consistent scale and zero_point + recursiveCheckConcatInputs( + writer.getContext(), + getCtx()->lookupSymbolTable(call_main_graph_op->getSymbolAttr()->str())->cast_()); + return ir::PASS_RET_SUCCESS; } diff --git a/pymllm/backends/qualcomm/transformers/core/observer.py b/pymllm/backends/qualcomm/transformers/core/observer.py new file mode 100644 index 000000000..67a946b10 --- /dev/null +++ b/pymllm/backends/qualcomm/transformers/core/observer.py @@ -0,0 +1,56 @@ +import torch +from torchao.quantization.pt2e import UniformQuantizationObserverBase + + +class ConcatObserver(UniformQuantizationObserverBase): + """ + Fetch maximum data range of all tensors to be concatenated + """ + + def __init__( + self, + dtype=torch.uint8, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=None, + quant_max=None, + factory_kwargs=None, + eps=torch.finfo(torch.float32).eps, # noqa: B008 + is_dynamic=False, + **kwargs, + ) -> None: + super().__init__( + dtype=dtype, + qscheme=qscheme, + reduce_range=reduce_range, + quant_min=quant_min, + quant_max=quant_max, + factory_kwargs=factory_kwargs, + eps=eps, + is_dynamic=is_dynamic, + **kwargs, + ) + + factory_kwargs = torch.nn.factory_kwargs(factory_kwargs) + self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs)) + self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs)) + # get concat node and its inputs + self.input_observers = [] + + def add_observer(self, observer): + self.input_observers.append(observer) + + def forward(self, x_orig): + # calculate the min / max first + self.min_val = min(self.min_val, x_orig.min()) + self.max_val = max(self.max_val, x_orig.max()) + + # update min / max for all observers of input nodes + for observers in self.input_observers: + observers.min_val = self.min_val + observers.max_val = self.max_val + + return x_orig + + def calculate_qparams(self): + return self._calculate_qparams(self.min_val, self.max_val) diff --git a/pymllm/backends/qualcomm/transformers/core/qdq.py b/pymllm/backends/qualcomm/transformers/core/qdq.py index f1c4d20dc..c13011a51 100644 --- a/pymllm/backends/qualcomm/transformers/core/qdq.py +++ b/pymllm/backends/qualcomm/transformers/core/qdq.py @@ -78,11 +78,11 @@ def disable_observer(self): def enable_fakequant(self): """Enable simulation of quantization error.""" - self.fake_quant.enable_fakequant() + self.fake_quant.enable_fake_quant() def disable_fakequant(self): """Disable quantization simulation (act as identity).""" - self.fake_quant.disable_fakequant() + self.fake_quant.disable_fake_quant() def extra_repr(self): mode = "Symmetric" if "symmetric" in str(self.qscheme) else "Asymmetric" @@ -158,11 +158,11 @@ def disable_observer(self): def enable_fakequant(self): """Enable simulation of quantization error.""" - self.fake_quant.enable_fakequant() + self.fake_quant.enable_fake_quant() def disable_fakequant(self): """Disable quantization simulation (act as identity).""" - self.fake_quant.disable_fakequant() + self.fake_quant.disable_fake_quant() @property def scale(self): diff --git a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py index 2f099088e..92efaa06d 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py @@ -54,6 +54,7 @@ ActivationQDQ, FixedActivationQDQ, ) +from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver class Qwen3MLP(nn.Module): @@ -100,11 +101,13 @@ def forward(self, x): return o -def rotate_half(x): +def rotate_half( + x, x_observer, x2_neg_fake_quant: ActivationQDQ, concat_observer: ConcatObserver +): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) + return concat_observer(torch.cat((x2_neg_fake_quant(-x2), x1), dim=-1)) def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): @@ -214,6 +217,39 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.k_rope_mul_1_output_qdq = ActivationQDQ(bits=16) self.k_rope_add_0_output_qdq = ActivationQDQ(bits=16) + self.q_rope_concat_observer = ConcatObserver( + dtype=torch.int32, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2**16 - 1, + eps=0.0001 / 65535, + is_dynamic=False, + ) + self.q_rope_neg_half_qdq = ActivationQDQ(bits=16) + self.k_rope_concat_observer = ConcatObserver( + dtype=torch.int32, + qscheme=torch.per_tensor_affine, + reduce_range=False, + quant_min=0, + quant_max=2**16 - 1, + eps=0.0001 / 65535, + is_dynamic=False, + ) + self.k_rope_neg_half_qdq = ActivationQDQ(bits=16) + self.k_rope_concat_observer.add_observer( + self.k_norm_output_qdq.fake_quant.activation_post_process + ) + self.k_rope_concat_observer.add_observer( + self.k_rope_neg_half_qdq.fake_quant.activation_post_process + ) + self.q_rope_concat_observer.add_observer( + self.q_norm_output_qdq.fake_quant.activation_post_process + ) + self.q_rope_concat_observer.add_observer( + self.q_rope_neg_half_qdq.fake_quant.activation_post_process + ) + # In qnn, is uint8 sym. self.k_cast_to_int8_qdq = ActivationQDQ( bits=8, qscheme=torch.per_tensor_symmetric @@ -231,6 +267,7 @@ def __init__(self, config: Qwen3Config, layer_idx: int): self.minus_0_output_qdq = ActivationQDQ(bits=16) self.softmax_output_qdq = ActivationQDQ(bits=16) self.attn_value_matmul_output_qdq = ActivationQDQ(bits=16) + self.where_attn_qdq = ActivationQDQ(bits=16) @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( @@ -263,11 +300,27 @@ def forward( sin = sin.unsqueeze(1) query_states = self.q_rope_add_0_output_qdq( self.q_rope_mul_0_output_qdq(query_states * cos) - + self.q_rope_mul_1_output_qdq(rotate_half(query_states) * sin) + + self.q_rope_mul_1_output_qdq( + rotate_half( + query_states, + self.q_norm_output_qdq.fake_quant.activation_post_process, + self.q_rope_neg_half_qdq, + self.q_rope_concat_observer, + ) + * sin + ) ) key_states = self.k_rope_add_0_output_qdq( self.k_rope_mul_0_output_qdq(key_states * cos) - + self.k_rope_mul_1_output_qdq(rotate_half(key_states) * sin) + + self.k_rope_mul_1_output_qdq( + rotate_half( + key_states, + self.k_norm_output_qdq.fake_quant.activation_post_process, + self.k_rope_neg_half_qdq, + self.k_rope_concat_observer, + ) + * sin + ) ) key_states = self.k_cast_to_int8_qdq(key_states) @@ -303,7 +356,9 @@ def forward( * (-20) ) ) - attn_weights = torch.where(attention_mask == 0, attn_weights, attn_vv) + attn_weights = self.where_attn_qdq( + torch.where(attention_mask == 0, attn_weights, attn_vv) + ) attn_weights = self.softmax_output_qdq( nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( diff --git a/pymllm/backends/qualcomm/transformers/qwen3/runner.py b/pymllm/backends/qualcomm/transformers/qwen3/runner.py index ed302f215..6565ca7e6 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/runner.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/runner.py @@ -2,7 +2,10 @@ from tqdm import tqdm from modelscope.msdatasets import MsDataset from transformers import AutoTokenizer -from pymllm.backends.qualcomm.transformers.core.qdq import ActivationQDQ +from pymllm.backends.qualcomm.transformers.core.qdq import ( + ActivationQDQ, + FixedActivationQDQ, +) from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm from pymllm.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, @@ -31,6 +34,16 @@ def enable_qdq_observer(m): m.enable_observer() +def enable_fake_quant(m): + if isinstance(m, ActivationQDQ) or isinstance(m, FixedActivationQDQ): + m.enable_fakequant() + + +def disable_fake_quant(m): + if isinstance(m, ActivationQDQ) or isinstance(m, FixedActivationQDQ): + m.disable_fakequant() + + def convert_weight(m): if isinstance(m, QLinearLPBQ) or isinstance(m, QLinearW8A16_PerChannelSym): m.convert_to_conv2d_deploy_hwio() @@ -61,6 +74,12 @@ def freeze_activation(self): def enable_activation_update(self): self.model.apply(enable_qdq_observer) + def enable_fake_quant(self): + self.model.apply(enable_fake_quant) + + def disable_fake_quant(self): + self.model.apply(disable_fake_quant) + def compile(self): print("Compile Start.") self.model = torch.compile( diff --git a/pymllm/backends/qualcomm/transformers/qwen3/train.py b/pymllm/backends/qualcomm/transformers/qwen3/train.py index 33351918f..25361f372 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/train.py +++ b/pymllm/backends/qualcomm/transformers/qwen3/train.py @@ -37,8 +37,11 @@ def main(): args = parser.parse_args() m = Qwen3Quantizer(args.model_path, mllm_qualcomm_max_length=args.max_length) + + # FIXME: Should disable or not. + m.disable_fake_quant() m.calibrate(num_samples=args.num_samples, max_seq_length=args.max_length) - # m.compile() + m.enable_fake_quant() m.infer(args.infer_text) # !!! From 65591c4af590b0fcf181ab92abfdcf2bce8461fb Mon Sep 17 00:00:00 2001 From: chenghuaWang <2923277184@qq.com> Date: Tue, 20 Jan 2026 09:16:50 +0000 Subject: [PATCH 5/5] feat(cpu): Implement fill operations for various data types including zeros, ones, specific values, arange, and random fills. Introduce a new fill-inl.hpp file for optimized implementations and update kernel dispatch to include these operations. Enhance CPUFillOp to utilize the new fill functions for better performance and maintainability. --- mllm/backends/cpu/kernels/common/fill-inl.hpp | 363 ++++++++++++++++++ .../cpu/kernels/common/kernel_dispatch.cpp | 180 ++++++++- .../cpu/kernels/common/kernel_dispatch.hpp | 217 +++++++++++ mllm/backends/cpu/ops/FillOp.cpp | 118 +++--- mllm/backends/qnn/aot/passes/PTQPass.cpp | 6 +- mllm/ffi/Extension.cc | 16 + pymllm/__init__.py | 16 +- pymllm/ffi/__init__.py | 67 +++- 8 files changed, 928 insertions(+), 55 deletions(-) create mode 100644 mllm/backends/cpu/kernels/common/fill-inl.hpp diff --git a/mllm/backends/cpu/kernels/common/fill-inl.hpp b/mllm/backends/cpu/kernels/common/fill-inl.hpp new file mode 100644 index 000000000..4c799daf6 --- /dev/null +++ b/mllm/backends/cpu/kernels/common/fill-inl.hpp @@ -0,0 +1,363 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +// NOTE: Do NOT use #pragma once here! +// Highway's foreach_target.h mechanism requires -inl.hpp files to be included +// multiple times, once for each target architecture (AVX3_DL, AVX10_2, etc.). + +#include +#include +#include "mllm/core/DataTypes.hpp" + +HWY_BEFORE_NAMESPACE(); +namespace mllm::cpu::common { // NOLINT +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +//===----------------------------------------------------------------------===// +// Fill Zeros +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_zeros_impl(T* HWY_RESTRICT dst, size_t count) { + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + const hn::Vec zero = hn::Zero(d); + size_t idx = 0; + + for (; idx + N <= count; idx += N) { hn::StoreU(zero, d, dst + idx); } + + if (idx < count) { hn::StoreN(zero, d, dst + idx, count - idx); } +} + +// Specialization for types not supported by Highway SIMD, use memset +template +HWY_INLINE void fill_zeros_scalar(T* HWY_RESTRICT dst, size_t count) { + if constexpr (std::is_trivial_v) { + std::memset(dst, 0, count * sizeof(T)); + } else { + T zero_val{}; + for (size_t i = 0; i < count; ++i) { dst[i] = zero_val; } + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_fp64(mllm_fp64_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_zeros_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size) { + fill_zeros_impl(dst, size); +} + +//===----------------------------------------------------------------------===// +// Fill Ones +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_ones_impl(T* HWY_RESTRICT dst, size_t count) { + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + const hn::Vec one = hn::Set(d, static_cast(1)); + size_t idx = 0; + + for (; idx + N <= count; idx += N) { hn::StoreU(one, d, dst + idx); } + + if (idx < count) { hn::StoreN(one, d, dst + idx, count - idx); } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_fp64(mllm_fp64_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_ones_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size) { + fill_ones_impl(dst, size); +} + +//===----------------------------------------------------------------------===// +// Fill Specific Value +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_value_impl(T* HWY_RESTRICT dst, size_t count, T value) { + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + const hn::Vec v = hn::Set(d, value); + size_t idx = 0; + + for (; idx + N <= count; idx += N) { hn::StoreU(v, d, dst + idx); } + + if (idx < count) { hn::StoreN(v, d, dst + idx, count - idx); } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_fp64(mllm_fp64_t* HWY_RESTRICT dst, size_t size, mllm_fp64_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size, mllm_int32_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size, mllm_uint32_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size, mllm_int64_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size, mllm_uint64_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size, mllm_int16_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size, mllm_uint16_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size, mllm_int8_t value) { + fill_value_impl(dst, size, value); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_value_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size, mllm_uint8_t value) { + fill_value_impl(dst, size, value); +} + +//===----------------------------------------------------------------------===// +// Fill Arange (start, end, step) +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_arange_impl(T* HWY_RESTRICT dst, size_t count, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + if (step == 0) { + fill_value_impl(dst, count, static_cast(start)); + return; + } + + // Calculate the actual number of elements to fill + size_t n = 0; + if ((step > 0 && start < end) || (step < 0 && start > end)) { + mllm_fp32_t n_float = (end - start) / step; + if (n_float > 0) { + n = static_cast(std::ceil(n_float)); + if (step > 0) { + if (start + (n - 1) * step >= end) --n; + } else { + if (start + (n - 1) * step <= end) --n; + } + n = std::min(n, count); + } + } + + // Use SIMD for float types where we can vectorize the computation + if constexpr (std::is_same_v) { + const hn::ScalableTag d; + const size_t N = hn::Lanes(d); + + // Create increment vector: [0, 1, 2, 3, ...] * step + const hn::Vec step_vec = hn::Set(d, step); + const hn::Vec n_step_vec = hn::Set(d, step * static_cast(N)); + + // Create base offsets [0, 1, 2, 3, ...] + hn::Vec base = hn::Iota(d, 0); + base = hn::Mul(base, step_vec); + hn::Vec current_start = hn::Add(hn::Set(d, start), base); + + size_t idx = 0; + for (; idx + N <= n; idx += N) { + hn::StoreU(current_start, d, dst + idx); + current_start = hn::Add(current_start, n_step_vec); + } + + // Handle remaining elements + for (; idx < n; ++idx) { dst[idx] = static_cast(start + idx * step); } + } else { + // Scalar fallback for other types + for (size_t i = 0; i < n; ++i) { dst[i] = static_cast(start + i * step); } + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_arange_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, mllm_fp32_t step) { + fill_arange_impl(dst, size, start, end, step); +} + +//===----------------------------------------------------------------------===// +// Fill Random (using LCG random number generator) +//===----------------------------------------------------------------------===// +template +HWY_INLINE void fill_random_impl(T* HWY_RESTRICT dst, size_t count, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + const uint64_t multiplier = 1103515245ULL; + const uint64_t increment = 12345ULL; + const uint64_t modulus = 1ULL << 31; // 2^31 + const mllm_fp32_t range = end - start; + + if (range == 0) { + fill_value_impl(dst, count, static_cast(start)); + return; + } + + uint64_t state = seed; + state = (multiplier * state + increment) % modulus; + + for (size_t i = 0; i < count; ++i) { + state = (multiplier * state + increment) % modulus; + const mllm_fp32_t random_value = static_cast(state) / static_cast(modulus - 1); + dst[i] = static_cast(start + random_value * range); + } +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_fp32(mllm_fp32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_i32(mllm_int32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_u32(mllm_uint32_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_i64(mllm_int64_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_u64(mllm_uint64_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_i16(mllm_int16_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_u16(mllm_uint16_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_i8(mllm_int8_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +static HWY_NOINLINE HWY_MAYBE_UNUSED void fill_random_u8(mllm_uint8_t* HWY_RESTRICT dst, size_t size, mllm_fp32_t start, + mllm_fp32_t end, uint64_t seed) { + fill_random_impl(dst, size, start, end, seed); +} + +} // namespace HWY_NAMESPACE +} // namespace mllm::cpu::common +HWY_AFTER_NAMESPACE(); diff --git a/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp b/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp index 1ad3cee93..7e81adfdf 100644 --- a/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp +++ b/mllm/backends/cpu/kernels/common/kernel_dispatch.cpp @@ -17,6 +17,7 @@ // Include all inline implementations here #include "mllm/backends/cpu/kernels/common/elewise-inl.hpp" +#include "mllm/backends/cpu/kernels/common/fill-inl.hpp" #if HWY_ONCE namespace mllm::cpu::common { @@ -69,11 +70,188 @@ HWY_DLLEXPORT void call_elewise_div_scalar_fp32(mllm_fp32_t* out, const mllm_fp3 // GELU //===----------------------------------------------------------------------===// // HWY_EXPORT(gelu_fp32); -// +// // HWY_DLLEXPORT void call_gelu_fp32(mllm_fp32_t* out, const mllm_fp32_t* in, size_t n) { // HWY_DYNAMIC_DISPATCH(gelu_fp32)(out, in, n); // } +//===----------------------------------------------------------------------===// +// Fill Zeros +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_zeros_fp32); +HWY_EXPORT(fill_zeros_fp64); +HWY_EXPORT(fill_zeros_i32); +HWY_EXPORT(fill_zeros_u32); +HWY_EXPORT(fill_zeros_i64); +HWY_EXPORT(fill_zeros_u64); +HWY_EXPORT(fill_zeros_i16); +HWY_EXPORT(fill_zeros_u16); +HWY_EXPORT(fill_zeros_i8); +HWY_EXPORT(fill_zeros_u8); + +HWY_DLLEXPORT void call_fill_zeros_fp32(mllm_fp32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_fp32)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_fp64(mllm_fp64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_fp64)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_i32(mllm_int32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_i32)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_u32(mllm_uint32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_u32)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_i64(mllm_int64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_i64)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_u64(mllm_uint64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_u64)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_i16(mllm_int16_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_i16)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_u16(mllm_uint16_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_u16)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_i8(mllm_int8_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_i8)(dst, n); } +HWY_DLLEXPORT void call_fill_zeros_u8(mllm_uint8_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_zeros_u8)(dst, n); } + +//===----------------------------------------------------------------------===// +// Fill Ones +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_ones_fp32); +HWY_EXPORT(fill_ones_fp64); +HWY_EXPORT(fill_ones_i32); +HWY_EXPORT(fill_ones_u32); +HWY_EXPORT(fill_ones_i64); +HWY_EXPORT(fill_ones_u64); +HWY_EXPORT(fill_ones_i16); +HWY_EXPORT(fill_ones_u16); +HWY_EXPORT(fill_ones_i8); +HWY_EXPORT(fill_ones_u8); + +HWY_DLLEXPORT void call_fill_ones_fp32(mllm_fp32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_fp32)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_fp64(mllm_fp64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_fp64)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_i32(mllm_int32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_i32)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_u32(mllm_uint32_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_u32)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_i64(mllm_int64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_i64)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_u64(mllm_uint64_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_u64)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_i16(mllm_int16_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_i16)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_u16(mllm_uint16_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_u16)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_i8(mllm_int8_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_i8)(dst, n); } +HWY_DLLEXPORT void call_fill_ones_u8(mllm_uint8_t* dst, size_t n) { HWY_DYNAMIC_DISPATCH(fill_ones_u8)(dst, n); } + +//===----------------------------------------------------------------------===// +// Fill Specific Value +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_value_fp32); +HWY_EXPORT(fill_value_fp64); +HWY_EXPORT(fill_value_i32); +HWY_EXPORT(fill_value_u32); +HWY_EXPORT(fill_value_i64); +HWY_EXPORT(fill_value_u64); +HWY_EXPORT(fill_value_i16); +HWY_EXPORT(fill_value_u16); +HWY_EXPORT(fill_value_i8); +HWY_EXPORT(fill_value_u8); + +HWY_DLLEXPORT void call_fill_value_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_fp32)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_fp64(mllm_fp64_t* dst, size_t n, mllm_fp64_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_fp64)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_i32(mllm_int32_t* dst, size_t n, mllm_int32_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_i32)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_u32(mllm_uint32_t* dst, size_t n, mllm_uint32_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_u32)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_i64(mllm_int64_t* dst, size_t n, mllm_int64_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_i64)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_u64(mllm_uint64_t* dst, size_t n, mllm_uint64_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_u64)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_i16(mllm_int16_t* dst, size_t n, mllm_int16_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_i16)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_u16(mllm_uint16_t* dst, size_t n, mllm_uint16_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_u16)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_i8(mllm_int8_t* dst, size_t n, mllm_int8_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_i8)(dst, n, value); +} +HWY_DLLEXPORT void call_fill_value_u8(mllm_uint8_t* dst, size_t n, mllm_uint8_t value) { + HWY_DYNAMIC_DISPATCH(fill_value_u8)(dst, n, value); +} + +//===----------------------------------------------------------------------===// +// Fill Arange +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_arange_fp32); +HWY_EXPORT(fill_arange_i32); +HWY_EXPORT(fill_arange_u32); +HWY_EXPORT(fill_arange_i64); +HWY_EXPORT(fill_arange_u64); +HWY_EXPORT(fill_arange_i16); +HWY_EXPORT(fill_arange_u16); +HWY_EXPORT(fill_arange_i8); +HWY_EXPORT(fill_arange_u8); + +HWY_DLLEXPORT void call_fill_arange_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_fp32)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_i32(mllm_int32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_i32)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_u32(mllm_uint32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_u32)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_i64(mllm_int64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_i64)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_u64(mllm_uint64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_u64)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_i16(mllm_int16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_i16)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_u16(mllm_uint16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_u16)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_i8(mllm_int8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_i8)(dst, n, start, end, step); +} +HWY_DLLEXPORT void call_fill_arange_u8(mllm_uint8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + HWY_DYNAMIC_DISPATCH(fill_arange_u8)(dst, n, start, end, step); +} + +//===----------------------------------------------------------------------===// +// Fill Random +//===----------------------------------------------------------------------===// +HWY_EXPORT(fill_random_fp32); +HWY_EXPORT(fill_random_i32); +HWY_EXPORT(fill_random_u32); +HWY_EXPORT(fill_random_i64); +HWY_EXPORT(fill_random_u64); +HWY_EXPORT(fill_random_i16); +HWY_EXPORT(fill_random_u16); +HWY_EXPORT(fill_random_i8); +HWY_EXPORT(fill_random_u8); + +HWY_DLLEXPORT void call_fill_random_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_fp32)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_i32(mllm_int32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_i32)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_u32(mllm_uint32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_u32)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_i64(mllm_int64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_i64)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_u64(mllm_uint64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_u64)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_i16(mllm_int16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_i16)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_u16(mllm_uint16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_u16)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_i8(mllm_int8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_i8)(dst, n, start, end, seed); +} +HWY_DLLEXPORT void call_fill_random_u8(mllm_uint8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + HWY_DYNAMIC_DISPATCH(fill_random_u8)(dst, n, start, end, seed); +} + } // namespace mllm::cpu::common #endif // HWY_ONCE diff --git a/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp b/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp index eb100ac43..4df34db0e 100644 --- a/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp +++ b/mllm/backends/cpu/kernels/common/kernel_dispatch.hpp @@ -7,6 +7,7 @@ #include "mllm/utils/CPUArchHelper.hpp" #if !(defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM)) +#include #include "mllm/core/DataTypes.hpp" // Platform-specific definitions used for declaring an interface, independent of @@ -30,6 +31,222 @@ HWY_DLLEXPORT void call_elewise_sub_scalar_fp32(mllm_fp32_t* out, const mllm_fp3 HWY_DLLEXPORT void call_elewise_mul_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); HWY_DLLEXPORT void call_elewise_div_scalar_fp32(mllm_fp32_t* out, const mllm_fp32_t* x, mllm_fp32_t y, size_t n); +//===----------------------------------------------------------------------===// +// Fill Zeros +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_zeros_fp32(mllm_fp32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_fp64(mllm_fp64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_i32(mllm_int32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_u32(mllm_uint32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_i64(mllm_int64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_u64(mllm_uint64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_i16(mllm_int16_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_u16(mllm_uint16_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_i8(mllm_int8_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_zeros_u8(mllm_uint8_t* dst, size_t n); + +//===----------------------------------------------------------------------===// +// Fill Ones +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_ones_fp32(mllm_fp32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_fp64(mllm_fp64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_i32(mllm_int32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_u32(mllm_uint32_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_i64(mllm_int64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_u64(mllm_uint64_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_i16(mllm_int16_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_u16(mllm_uint16_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_i8(mllm_int8_t* dst, size_t n); +HWY_DLLEXPORT void call_fill_ones_u8(mllm_uint8_t* dst, size_t n); + +//===----------------------------------------------------------------------===// +// Fill Specific Value +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_value_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t value); +HWY_DLLEXPORT void call_fill_value_fp64(mllm_fp64_t* dst, size_t n, mllm_fp64_t value); +HWY_DLLEXPORT void call_fill_value_i32(mllm_int32_t* dst, size_t n, mllm_int32_t value); +HWY_DLLEXPORT void call_fill_value_u32(mllm_uint32_t* dst, size_t n, mllm_uint32_t value); +HWY_DLLEXPORT void call_fill_value_i64(mllm_int64_t* dst, size_t n, mllm_int64_t value); +HWY_DLLEXPORT void call_fill_value_u64(mllm_uint64_t* dst, size_t n, mllm_uint64_t value); +HWY_DLLEXPORT void call_fill_value_i16(mllm_int16_t* dst, size_t n, mllm_int16_t value); +HWY_DLLEXPORT void call_fill_value_u16(mllm_uint16_t* dst, size_t n, mllm_uint16_t value); +HWY_DLLEXPORT void call_fill_value_i8(mllm_int8_t* dst, size_t n, mllm_int8_t value); +HWY_DLLEXPORT void call_fill_value_u8(mllm_uint8_t* dst, size_t n, mllm_uint8_t value); + +//===----------------------------------------------------------------------===// +// Fill Arange +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_arange_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_i32(mllm_int32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_u32(mllm_uint32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_i64(mllm_int64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_u64(mllm_uint64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_i16(mllm_int16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_u16(mllm_uint16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_i8(mllm_int8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); +HWY_DLLEXPORT void call_fill_arange_u8(mllm_uint8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step); + +//===----------------------------------------------------------------------===// +// Fill Random +//===----------------------------------------------------------------------===// +HWY_DLLEXPORT void call_fill_random_fp32(mllm_fp32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed); +HWY_DLLEXPORT void call_fill_random_i32(mllm_int32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed); +HWY_DLLEXPORT void call_fill_random_u32(mllm_uint32_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed); +HWY_DLLEXPORT void call_fill_random_i64(mllm_int64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed); +HWY_DLLEXPORT void call_fill_random_u64(mllm_uint64_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed); +HWY_DLLEXPORT void call_fill_random_i16(mllm_int16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed); +HWY_DLLEXPORT void call_fill_random_u16(mllm_uint16_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed); +HWY_DLLEXPORT void call_fill_random_i8(mllm_int8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed); +HWY_DLLEXPORT void call_fill_random_u8(mllm_uint8_t* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed); + +//===----------------------------------------------------------------------===// +// Template wrapper for generic fill operations +//===----------------------------------------------------------------------===// +template +inline void fill_zeros_anytype(T* dst, size_t n) { + if constexpr (std::is_same_v) { + call_fill_zeros_fp32(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_zeros_fp64(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_zeros_i32(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_zeros_u32(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_zeros_i64(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_zeros_u64(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_zeros_i16(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_zeros_u16(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_zeros_i8(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_zeros_u8(dst, n); + } else { + // Fallback for unsupported types + std::memset(dst, 0, n * sizeof(T)); + } +} + +template +inline void fill_ones_anytype(T* dst, size_t n) { + if constexpr (std::is_same_v) { + call_fill_ones_fp32(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_ones_fp64(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_ones_i32(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_ones_u32(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_ones_i64(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_ones_u64(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_ones_i16(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_ones_u16(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_ones_i8(dst, n); + } else if constexpr (std::is_same_v) { + call_fill_ones_u8(dst, n); + } else { + // Fallback + for (size_t i = 0; i < n; ++i) { dst[i] = static_cast(1); } + } +} + +template +inline void fill_value_anytype(T* dst, size_t n, mllm_fp32_t value) { + if constexpr (std::is_same_v) { + call_fill_value_fp32(dst, n, value); + } else if constexpr (std::is_same_v) { + call_fill_value_fp64(dst, n, static_cast(value)); + } else if constexpr (std::is_same_v) { + call_fill_value_i32(dst, n, static_cast(value)); + } else if constexpr (std::is_same_v) { + call_fill_value_u32(dst, n, static_cast(value)); + } else if constexpr (std::is_same_v) { + call_fill_value_i64(dst, n, static_cast(value)); + } else if constexpr (std::is_same_v) { + call_fill_value_u64(dst, n, static_cast(value)); + } else if constexpr (std::is_same_v) { + call_fill_value_i16(dst, n, static_cast(value)); + } else if constexpr (std::is_same_v) { + call_fill_value_u16(dst, n, static_cast(value)); + } else if constexpr (std::is_same_v) { + call_fill_value_i8(dst, n, static_cast(value)); + } else if constexpr (std::is_same_v) { + call_fill_value_u8(dst, n, static_cast(value)); + } else { + // Fallback + for (size_t i = 0; i < n; ++i) { dst[i] = static_cast(value); } + } +} + +template +inline void fill_arange_anytype(T* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, mllm_fp32_t step) { + if constexpr (std::is_same_v) { + call_fill_arange_fp32(dst, n, start, end, step); + } else if constexpr (std::is_same_v) { + call_fill_arange_i32(dst, n, start, end, step); + } else if constexpr (std::is_same_v) { + call_fill_arange_u32(dst, n, start, end, step); + } else if constexpr (std::is_same_v) { + call_fill_arange_i64(dst, n, start, end, step); + } else if constexpr (std::is_same_v) { + call_fill_arange_u64(dst, n, start, end, step); + } else if constexpr (std::is_same_v) { + call_fill_arange_i16(dst, n, start, end, step); + } else if constexpr (std::is_same_v) { + call_fill_arange_u16(dst, n, start, end, step); + } else if constexpr (std::is_same_v) { + call_fill_arange_i8(dst, n, start, end, step); + } else if constexpr (std::is_same_v) { + call_fill_arange_u8(dst, n, start, end, step); + } else { + // Fallback + for (size_t i = 0; i < n; ++i) { dst[i] = static_cast(start + i * step); } + } +} + +template +inline void fill_random_anytype(T* dst, size_t n, mllm_fp32_t start, mllm_fp32_t end, uint64_t seed) { + if constexpr (std::is_same_v) { + call_fill_random_fp32(dst, n, start, end, seed); + } else if constexpr (std::is_same_v) { + call_fill_random_i32(dst, n, start, end, seed); + } else if constexpr (std::is_same_v) { + call_fill_random_u32(dst, n, start, end, seed); + } else if constexpr (std::is_same_v) { + call_fill_random_i64(dst, n, start, end, seed); + } else if constexpr (std::is_same_v) { + call_fill_random_u64(dst, n, start, end, seed); + } else if constexpr (std::is_same_v) { + call_fill_random_i16(dst, n, start, end, seed); + } else if constexpr (std::is_same_v) { + call_fill_random_u16(dst, n, start, end, seed); + } else if constexpr (std::is_same_v) { + call_fill_random_i8(dst, n, start, end, seed); + } else if constexpr (std::is_same_v) { + call_fill_random_u8(dst, n, start, end, seed); + } else { + // Fallback using LCG + const uint64_t multiplier = 1103515245ULL; + const uint64_t increment = 12345ULL; + const uint64_t modulus = 1ULL << 31; + const mllm_fp32_t range = end - start; + uint64_t state = seed; + for (size_t i = 0; i < n; ++i) { + state = (multiplier * state + increment) % modulus; + const mllm_fp32_t random_value = static_cast(state) / static_cast(modulus - 1); + dst[i] = static_cast(start + random_value * range); + } + } +} + } // namespace mllm::cpu::common #endif diff --git a/mllm/backends/cpu/ops/FillOp.cpp b/mllm/backends/cpu/ops/FillOp.cpp index e4d935f51..cf5cee47e 100644 --- a/mllm/backends/cpu/ops/FillOp.cpp +++ b/mllm/backends/cpu/ops/FillOp.cpp @@ -21,7 +21,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - x86::fill_zeros(dst.ptr(), dst.numel(), threads); + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros(dst.ptr(), dst.numel(), threads); #endif @@ -29,7 +29,8 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + std::memset(dst.ptr(), 0, dst.numel() * sizeof(mllm_fp16_t)); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_fp16(dst.ptr(), dst.numel(), threads); #endif @@ -37,7 +38,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -45,7 +46,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -53,7 +54,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -61,7 +62,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -69,7 +70,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -77,7 +78,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -85,7 +86,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -93,7 +94,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_zeros_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_zeros_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -110,7 +111,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - x86::fill_ones(dst.ptr(), dst.numel(), threads); + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones(dst.ptr(), dst.numel(), threads); #endif @@ -118,7 +119,9 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + auto ptr = dst.ptr(); + for (size_t i = 0; i < dst.numel(); ++i) { ptr[i] = static_cast(1.0f); } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_fp16(dst.ptr(), dst.numel(), threads); #endif @@ -126,7 +129,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -134,7 +137,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -142,7 +145,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -150,7 +153,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -158,7 +161,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -166,7 +169,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -174,7 +177,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -182,7 +185,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_ones_anytype(dst.ptr(), dst.numel()); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_ones_anytype(dst.ptr(), dst.numel(), threads); #endif @@ -199,7 +202,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - x86::fill_arange(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); #endif @@ -207,7 +210,9 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + auto ptr = dst.ptr(); + for (size_t i = 0; i < dst.numel(); ++i) { ptr[i] = static_cast(options_.start + i * options_.step); } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_fp16(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); #endif @@ -215,7 +220,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -224,7 +229,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -233,7 +238,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -242,7 +247,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -251,7 +256,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -260,7 +265,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -269,7 +274,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -278,7 +283,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_arange_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.step, threads); @@ -295,7 +300,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - x86::fill_random(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -303,7 +308,18 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + const uint64_t multiplier = 1103515245ULL; + const uint64_t increment = 12345ULL; + const uint64_t modulus = 1ULL << 31; + const mllm_fp32_t range = options_.end - options_.start; + uint64_t state = options_.seed; + auto ptr = dst.ptr(); + for (size_t i = 0; i < dst.numel(); ++i) { + state = (multiplier * state + increment) % modulus; + const mllm_fp32_t random_value = static_cast(state) / static_cast(modulus - 1); + ptr[i] = static_cast(options_.start + random_value * range); + } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_fp16(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -311,7 +327,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -319,7 +335,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -327,7 +343,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -335,7 +351,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -343,7 +359,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -351,7 +367,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -359,7 +375,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -367,7 +383,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_random_anytype(dst.ptr(), dst.numel(), options_.start, options_.end, options_.seed, threads); #endif @@ -383,7 +399,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& switch (dst.dtype()) { case kFloat32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - x86::fill_specific_value(dst.ptr(), dst.numel(), options_.value, threads); + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -391,7 +407,9 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kFloat16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + // FP16 not directly supported by Highway on x86, use scalar fallback + auto ptr = dst.ptr(); + for (size_t i = 0; i < dst.numel(); ++i) { ptr[i] = static_cast(options_.value); } #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_fp16(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -399,7 +417,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -407,7 +425,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -415,7 +433,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -423,7 +441,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -431,7 +449,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt64: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -439,7 +457,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt32: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -447,7 +465,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt16: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif @@ -455,7 +473,7 @@ void CPUFillOp::forward(const std::vector& inputs, std::vector& } case kUInt8: { #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) - // TODO + common::fill_value_anytype(dst.ptr(), dst.numel(), options_.value); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::fill_specific_value_anytype(dst.ptr(), dst.numel(), options_.value, threads); #endif diff --git a/mllm/backends/qnn/aot/passes/PTQPass.cpp b/mllm/backends/qnn/aot/passes/PTQPass.cpp index 82869ab16..0d34a51b2 100644 --- a/mllm/backends/qnn/aot/passes/PTQPass.cpp +++ b/mllm/backends/qnn/aot/passes/PTQPass.cpp @@ -387,8 +387,10 @@ void recursiveCheckConcatInputs(const std::shared_ptr& ir_ctx, co if (std::abs(ref_scale_v - cur_scale_v) > 1e-6 || ref_zp_v != cur_zp_v) { MLLM_ERROR("PTQPass: ConcatOp '{}' has mismatched scale/zp between inputs. " - "Input '{}': scale={}, zp={}; Input '{}': scale={}, zp={}", - op_name, ref_input_name, ref_scale_v, ref_zp_v, tv->name(), cur_scale_v, cur_zp_v); + "Input '{}': scale={}, zp={}, scale_name={}, zp_name={}; Input '{}': scale={}, zp={}, scale_name={}, " + "zp_name={}", + op_name, ref_input_name, ref_scale_v, ref_zp_v, ref_scale.name(), ref_zero_point.name(), tv->name(), + cur_scale_v, cur_zp_v, cur_scale.name(), cur_zero_point.name()); } } } else if (f_spec->spec_->type == ir::linalg::QuantizationSpecType::kSymPerTensor) { diff --git a/mllm/ffi/Extension.cc b/mllm/ffi/Extension.cc index 22449f883..cb999191d 100644 --- a/mllm/ffi/Extension.cc +++ b/mllm/ffi/Extension.cc @@ -53,9 +53,25 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("mllm.cpu_", []() -> mllm::ffi::Device { return mllm::ffi::Device(::mllm::DeviceTypes::kCPU); }); refl::GlobalDef().def("mllm.cuda_", []() -> mllm::ffi::Device { return mllm::ffi::Device(::mllm::DeviceTypes::kCUDA); }); refl::GlobalDef().def("mllm.qnn_", []() -> mllm::ffi::Device { return mllm::ffi::Device(::mllm::DeviceTypes::kQNN); }); + // Floating point types refl::GlobalDef().def("mllm.float32_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kFloat32); }); refl::GlobalDef().def("mllm.float16_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kFloat16); }); refl::GlobalDef().def("mllm.bfloat16_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kBFloat16); }); + + // Signed integer types + refl::GlobalDef().def("mllm.int8_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kInt8); }); + refl::GlobalDef().def("mllm.int16_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kInt16); }); + refl::GlobalDef().def("mllm.int32_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kInt32); }); + refl::GlobalDef().def("mllm.int64_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kInt64); }); + + // Unsigned integer types + refl::GlobalDef().def("mllm.uint8_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt8); }); + refl::GlobalDef().def("mllm.uint16_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt16); }); + refl::GlobalDef().def("mllm.uint32_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt32); }); + refl::GlobalDef().def("mllm.uint64_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt64); }); + + // Bool type + refl::GlobalDef().def("mllm.bool_", []() -> mllm::ffi::DType { return mllm::ffi::DType(::mllm::DataTypes::kUInt8); }); } //===----------------------------------------------------------------------===// diff --git a/pymllm/__init__.py b/pymllm/__init__.py index 66240b714..1bd31cd6c 100644 --- a/pymllm/__init__.py +++ b/pymllm/__init__.py @@ -12,12 +12,27 @@ from . import service from . import backends from .ffi import ( + # Floating point types float32, float16, bfloat16, + # Signed integer types + int8, + int16, + int32, + int64, + # Unsigned integer types + uint8, + uint16, + uint32, + uint64, + # Bool type + boolean, + # Devices cpu, cuda, qnn, + # Tensor and utilities Tensor, empty, echo, @@ -26,7 +41,6 @@ is_numpy_available, from_torch, from_numpy, - empty, zeros, ones, arange, diff --git a/pymllm/ffi/__init__.py b/pymllm/ffi/__init__.py index 17bd04c19..9780eabb0 100644 --- a/pymllm/ffi/__init__.py +++ b/pymllm/ffi/__init__.py @@ -48,6 +48,10 @@ def to_pod(self) -> int: return tvm_ffi.get_global_func("mllm.DType.to_pod")(self) +# ============================================================================= +# DType factory functions +# ============================================================================= +# Floating point types def float32_() -> DType: return _ffi_api.float32_() @@ -60,6 +64,45 @@ def bfloat16_() -> DType: return _ffi_api.bfloat16_() +# Signed integer types +def int8_() -> DType: + return _ffi_api.int8_() + + +def int16_() -> DType: + return _ffi_api.int16_() + + +def int32_() -> DType: + return _ffi_api.int32_() + + +def int64_() -> DType: + return _ffi_api.int64_() + + +# Unsigned integer types +def uint8_() -> DType: + return _ffi_api.uint8_() + + +def uint16_() -> DType: + return _ffi_api.uint16_() + + +def uint32_() -> DType: + return _ffi_api.uint32_() + + +def uint64_() -> DType: + return _ffi_api.uint64_() + + +# Bool type (backed by uint8) +def bool_() -> DType: + return _ffi_api.bool_() + + def cpu_() -> Device: return _ffi_api.cpu_() @@ -219,10 +262,32 @@ def is_contiguous(self): return tvm_ffi.get_global_func("mllm.Tensor.is_contiguous")(self) -# Global dtypes +# ============================================================================= +# Global dtype instances +# ============================================================================= +# Floating point types float32: DType = float32_() float16: DType = float16_() bfloat16: DType = bfloat16_() + +# Signed integer types +int8: DType = int8_() +int16: DType = int16_() +int32: DType = int32_() +int64: DType = int64_() + +# Unsigned integer types +uint8: DType = uint8_() +uint16: DType = uint16_() +uint32: DType = uint32_() +uint64: DType = uint64_() + +# Bool type (use 'boolean' to avoid shadowing Python's built-in 'bool') +boolean: DType = bool_() + +# ============================================================================= +# Global device instances +# ============================================================================= cpu: Device = cpu_() cuda: Device = cuda_() qnn: Device = qnn_()