Fix QMoE blockwise quantization support for TRT-RTX execution provider #1926
base: main
Changes from all commits: 581a564, 9f88bcd, 9ae34f6, 6d4ebca, 60dfcdf, 18edebb, 73c67ed
@@ -359,19 +359,26 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         int4_algo_config = self.make_int4_algo_config(extra_options.get("int4_algo_config", "default"))
         self.int4_block_size = extra_options.get("int4_block_size", 32)

-        # Validate that only CPU and WebGPU EPs support int4_block_size for QMoE
-        if self.ep not in ["cpu", "webgpu"] and "int4_block_size" in extra_options and moe_op_type == "QMoE":
+        # CPU, WebGPU, and TRT-RTX support block-wise quantization for QMoE.
+        # TRT-RTX defaults to 128; CPU/WebGPU default to 0 (tensor-level) for backward compatibility.
+        supported_blockwise_eps = ["cpu", "webgpu", "trt-rtx", "NvTensorRtRtx"]
+        default_qmoe_block_size = 128 if self.ep in ["trt-rtx", "NvTensorRtRtx"] else 0
+        self.int4_qmoe_block_size = extra_options.get("int4_qmoe_block_size", default_qmoe_block_size)
+
+        # Validate that unsupported EPs don't explicitly request block-wise quantization
+        if self.ep not in supported_blockwise_eps and "int4_qmoe_block_size" in extra_options and moe_op_type == "QMoE":
             raise ValueError(
-                f"The 'int4_block_size' option is not supported for {self.ep} execution provider with QMoE. "
-                "Block-wise quantization (block_size attribute) is only supported for CPU and WebGPU execution providers."
+                f"The 'int4_qmoe_block_size' option is not supported for {self.ep} execution provider with QMoE. "
+                f"Block-wise quantization is only supported for: {', '.join(supported_blockwise_eps)}."
             )

         self.quant_attrs = {
             "int4": {
                 "accuracy_level": int(
                     extra_options.get("int4_accuracy_level", 4 if self.ep in ["cpu", "webgpu"] else 0)
                 ),
-                "block_size": int(self.int4_block_size),
+                "block_size": int(self.int4_qmoe_block_size),
+                "qdq_block_size": int(self.int4_block_size),
                 "is_symmetric": extra_options.get("int4_is_symmetric", True),
                 "op_types_to_quantize": extra_options.get("int4_op_types_to_quantize", ("MatMul",)),
                 "nodes_to_exclude": extra_options.get("int4_nodes_to_exclude", []),
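A minimal standalone sketch of the option handling introduced above, restated from the diff; `ep`, `extra_options`, and `moe_op_type` stand in for the builder's attributes:

```python
def resolve_qmoe_block_size(ep: str, extra_options: dict, moe_op_type: str) -> int:
    """Sketch of the new int4_qmoe_block_size resolution, restated from this diff."""
    supported_blockwise_eps = ["cpu", "webgpu", "trt-rtx", "NvTensorRtRtx"]
    # TRT-RTX gets block-wise quantization (block_size=128) by default;
    # CPU/WebGPU keep tensor-level quantization (block_size=0) unless overridden.
    default_block_size = 128 if ep in ["trt-rtx", "NvTensorRtRtx"] else 0
    block_size = int(extra_options.get("int4_qmoe_block_size", default_block_size))

    if ep not in supported_blockwise_eps and "int4_qmoe_block_size" in extra_options and moe_op_type == "QMoE":
        raise ValueError(f"'int4_qmoe_block_size' is not supported for the {ep} execution provider with QMoE.")
    return block_size

# Example: CUDA keeps tensor-level quantization, TRT-RTX defaults to 128-wide blocks.
assert resolve_qmoe_block_size("cuda", {}, "QMoE") == 0
assert resolve_qmoe_block_size("NvTensorRtRtx", {}, "QMoE") == 128
```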
@@ -380,11 +387,11 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
                 "use_qdq": extra_options.get("use_qdq", False),
             }

-        # Propagate block_size to MoE/QMoE op when supported and requested.
-        # QMoE on CPU/WebGPU supports block-wise quantization via the 'block_size' attribute.
+        # Propagate block_size to MoE/QMoE op when supported.
+        # QMoE on supported EPs uses block-wise quantization via the 'block_size' attribute.
         # Ensure the attribute is set on the MoE op so runtime kernels can honor it.
-        if self.moe_attrs.get("op_type") == "QMoE" and self.ep in ["cpu", "webgpu"]:
-            self.moe_attrs["block_size"] = int(self.int4_block_size)
+        if self.moe_attrs.get("op_type") == "QMoE" and self.ep in supported_blockwise_eps:
+            self.moe_attrs["block_size"] = int(self.int4_qmoe_block_size)
         if self.quant_type is not None:
             # Create quantized attributes from quantization config
             self.quant_attrs["config"] = config.quantization_config
@@ -501,6 +508,7 @@ def is_gqa_supported(self) -> bool:
             ("webgpu", ir.DataType.FLOAT16),
             ("webgpu", ir.DataType.FLOAT),
             ("trt-rtx", ir.DataType.FLOAT16),
+            ("trt-rtx", ir.DataType.BFLOAT16),
         }
         return (self.ep, self.io_dtype) in valid_gqa_configurations
@@ -703,7 +711,7 @@ def make_int4_algo_config(self, quant_method: str):
     def to_int4(self) -> ir.Model:
         quant = MatMulNBitsQuantizer(
             model=ir.to_proto(self.model),
-            block_size=self.quant_attrs["int4"]["block_size"],
+            block_size=self.quant_attrs["int4"]["qdq_block_size"],
             is_symmetric=self.quant_attrs["int4"]["is_symmetric"],
             accuracy_level=self.quant_attrs["int4"]["accuracy_level"],
             nodes_to_exclude=self.quant_attrs["int4"]["nodes_to_exclude"],
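To keep the two block-size knobs straight, the mapping below is inferred from the diff itself rather than from separate documentation: `int4_block_size` now only drives MatMulNBits QDQ quantization, while `int4_qmoe_block_size` drives the QMoE op's attribute.

```python
# Mapping inferred from this diff (illustrative, not a documented API):
extra_options = {
    "int4_block_size": 32,        # -> quant_attrs["int4"]["qdq_block_size"] -> MatMulNBitsQuantizer(block_size=...)
    "int4_qmoe_block_size": 128,  # -> quant_attrs["int4"]["block_size"]     -> QMoE 'block_size' attribute
}
```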
@@ -712,7 +720,22 @@ def to_int4(self) -> ir.Model:
             algo_config=self.quant_attrs["int4"]["algo_config"],
         )
         quant.process()
-        return ir.from_proto(quant.model.model)
+        model = ir.from_proto(quant.model.model)
+
+        # Convert float32 scales to bfloat16 if io_dtype is bfloat16.
+        # MatMulNBitsQuantizer doesn't natively support bfloat16, so we saved weights as float32
+        # for quantization and now convert the resulting scales to the target io_dtype.
+        if self.io_dtype == ir.DataType.BFLOAT16:
+            for initializer in model.graph.initializers.values():
+                # Scale tensors are named with "_scales" or "_DQ_scales" suffix
+                if initializer.name.endswith("_scales") or initializer.name.endswith("_DQ_scales"):
+                    if initializer.dtype == ir.DataType.FLOAT:
+                        # Convert float32 scales to bfloat16
+                        float32_data = initializer.const_value.numpy()
+                        bfloat16_data = torch.from_numpy(float32_data).to(torch.bfloat16)
+                        initializer.const_value = TorchTensor(bfloat16_data, name=initializer.name)
+
+        return model

     def save_model(self, out_dir):
         print(f"Saving ONNX model in {out_dir}")

Contributor (inline comment, truncated in this capture): "In this PR, I added support for bfloat16 in the […]"
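The reason the conversion above goes through torch is that NumPy has no native bfloat16 dtype, so the float32 scale data has to be rounded to bfloat16 via a torch tensor before being wrapped back into an ONNX IR initializer. A self-contained sketch of that conversion step (the function name and dictionary interface are hypothetical, chosen only to make the snippet runnable):

```python
import numpy as np
import torch

def scales_to_bfloat16(named_arrays: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
    """Convert float32 scale arrays to bfloat16 tensors via torch."""
    converted = {}
    for name, array in named_arrays.items():
        if name.endswith("_scales") and array.dtype == np.float32:
            converted[name] = torch.from_numpy(array).to(torch.bfloat16)
    return converted

# Example: a 4-element scale vector gets truncated to bfloat16 precision.
scales = {"layer.0.weight_scales": np.array([0.0123, 1.5, 3.25, 0.5], dtype=np.float32)}
print(scales_to_bfloat16(scales)["layer.0.weight_scales"].dtype)  # torch.bfloat16
```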
@@ -1056,7 +1079,14 @@ def make_matmul_op(self, matmul, basename, root_input, **kwargs):

     def make_matmul_float(self, matmul, name, root_input, **kwargs):
         weight = name[1:].replace("/", ".") + ".weight"
-        self.make_initializer(matmul.weight.T, weight, to=self.io_dtype)
+        # When onnx_dtype is INT4/UINT4, weights will be quantized by MatMulNBitsQuantizer later.
+        # MatMulNBitsQuantizer doesn't properly support BFLOAT16 inputs, so we need to save
+        # weights as FLOAT32 to ensure correct quantization with proper scales.
+        if self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4} and self.io_dtype == ir.DataType.BFLOAT16:
+            weight_dtype = ir.DataType.FLOAT
+        else:
+            weight_dtype = self.io_dtype
+        self.make_initializer(matmul.weight.T, weight, to=weight_dtype)

         last_dim = matmul.weight.shape[0]
         output = "logits" if kwargs.get("logits", False) else f"{name}/output_0"
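Saving bfloat16 weights as float32 for the quantizer is lossless in the upcast direction, since every bfloat16 value is exactly representable in float32; only the later float32-to-bfloat16 conversion of the scales rounds. A short illustration of that property (generic torch, not code from the PR):

```python
import torch

# Every bfloat16 value is exactly representable in float32, so upcasting the
# weights before MatMulNBitsQuantizer runs loses no information.
w_bf16 = torch.randn(4, 4).to(torch.bfloat16)
w_fp32 = w_bf16.to(torch.float32)
assert torch.equal(w_fp32.to(torch.bfloat16), w_bf16)  # round-trips exactly
```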
@@ -3227,11 +3257,15 @@ def make_qmoe_op(self, name, **kwargs):
             kwargs.get("weight3", ""),
             kwargs.get("scales3", ""),
             kwargs.get("bias3", ""),
-            kwargs.get("zero_points1", ""),
-            kwargs.get("zero_points2", ""),
-            kwargs.get("zero_points3", ""),
         ]

+        # Only add zero_points inputs if they are provided (for Quark asymmetric quantization)
+        zero_points1 = kwargs.get("zero_points1", "")
+        zero_points2 = kwargs.get("zero_points2", "")
+        zero_points3 = kwargs.get("zero_points3", "")
+        if zero_points1 or zero_points2 or zero_points3:
+            inputs.extend([zero_points1, zero_points2, zero_points3])
+
        output = f"{name}/output_0"

        extra_kwargs = (

Contributor (inline comment on the removed zero_points inputs): This change can be reverted. The new op spec for QMoE includes optional zero points. Those optional inputs are stored as empty strings if unused and empty string inputs should not affect other models that don't support zero points.
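The reviewer's point rests on the standard ONNX convention that an optional input can either be passed as an empty string placeholder or, if it is trailing, omitted entirely. A minimal sketch with `onnx.helper`; the input names here are illustrative placeholders, not the full QMoE signature, and required attributes are left out:

```python
from onnx import helper

# Optional inputs in ONNX can be skipped in two equivalent ways:
# 1) pass "" as a placeholder so later inputs keep their positions,
# 2) simply drop trailing optional inputs.
node_with_placeholders = helper.make_node(
    "QMoE",
    inputs=["hidden", "router_probs", "w1", "s1", "", "w2", "s2", "", "", ""],  # "" = unused optional input
    outputs=["moe_out"],
    domain="com.microsoft",
)
node_without_trailing = helper.make_node(
    "QMoE",
    inputs=["hidden", "router_probs", "w1", "s1", "", "w2", "s2", ""],  # trailing optionals dropped
    outputs=["moe_out"],
    domain="com.microsoft",
)
```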
@@ -3264,21 +3298,19 @@ def make_qmoe_weights(self, weights):
         dtype = torch.quint4x2 if self.moe_attrs["expert_weight_bits"] == 4 else torch.int8
         qweight, scales = None, None

-        # For QMoE, only use block-wise quantization when explicitly requested
-        # via int4_block_size and when using CPU or WebGPU execution providers, since
-        # block_size is only supported for these EPs in the QMoE operator.
-        use_blockwise_quant = "int4_block_size" in self.extra_options and self.ep in ["cpu", "webgpu"]
+        # Get block size from quantization attributes
+        block_size = self.quant_attrs["int4"]["block_size"]

-        if use_blockwise_quant:
-            block_size = self.quant_attrs["int4"]["block_size"]
+        # Use block-wise quantization if block_size > 0
+        if block_size > 0:
             try:
                 qweight, scales = self._symmetric_blockwise_quantize(weights, block_size)
                 self.moe_attrs["block_size"] = block_size
                 return qweight, scales.to(torch.float16)
             except Exception as e:
                 raise RuntimeError(f"Block-wise quantization failed with block_size={block_size}: {e}")

-        # Use tensor-level quantization (default for QMoE)
+        # block_size is 0, so we're using tensor-level quantization
         self.moe_attrs["block_size"] = 0

         # Existing tensor-level quantization implementation (fallback)

Contributor (inline comment): These changes look to be reverting to before this PR was made. There were issues discovered with the old approach that necessitated the linked PR.
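For context on what `_symmetric_blockwise_quantize` is doing, here is a minimal sketch of symmetric block-wise int4 quantization: pad the last dimension to a multiple of `block_size`, derive one scale per block from its absolute maximum, and round into [-8, 7]. It mirrors the general technique, not the PR's exact implementation, and the `absmax / 7` scale formula is one common choice rather than a confirmed detail.

```python
import torch

def symmetric_blockwise_int4(weights: torch.Tensor, block_size: int):
    """Sketch of symmetric block-wise int4 quantization (illustrative)."""
    *lead, last = weights.shape
    pad = (-last) % block_size                       # pad last dim to a multiple of block_size
    w = torch.nn.functional.pad(weights.float(), (0, pad))
    num_blocks = w.shape[-1] // block_size
    w = w.view(*lead, num_blocks, block_size)

    # One scale per block; symmetric int4 values live in [-8, 7].
    absmax = w.abs().amax(dim=-1, keepdim=True)
    scales = (absmax / 7.0).clamp(min=1e-8)
    q = torch.clamp(torch.round(w / scales), -8, 7).to(torch.int8)

    # Flatten the blocks back and drop the padding, mirroring the PR's reshaping step.
    q = q.view(*lead, num_blocks * block_size)
    if pad:
        q = q[..., :last]
    return q, scales.squeeze(-1)                     # scales shape: (*lead, num_blocks)

w = torch.randn(8, 100)                              # last dim not a multiple of 32 -> padding path
q, s = symmetric_blockwise_int4(w, block_size=32)
print(q.shape, s.shape)                              # torch.Size([8, 100]) torch.Size([8, 4])
```

Dequantization multiplies each block back by its scale (w ≈ q * scale), which is why a block-wise QMoE kernel only needs the packed weights plus a per-block scale tensor.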
@@ -3354,6 +3386,7 @@ def _symmetric_blockwise_quantize(self, weights, block_size):
         quantized_flat = quantized_int8.view(*original_shape[:-1], num_blocks * block_size)

         # remove padding
         if pad_size > 0:
             quantized_flat = quantized_flat[..., :-pad_size]
@@ -70,7 +70,7 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
         self.window_size = original_window_size

     def make_moe(self, layer_id, mlp, root_input):
-        if self.ep in {"cpu", "cuda"}:
+        if self.ep in {"cpu", "cuda", "NvTensorRtRtx", "trt-rtx"}:
             self.make_moe_fused(layer_id, mlp, root_input)
         else:
             self.make_moe_decomposed(layer_id, mlp, root_input)

Contributor (inline comment on the EP check, truncated in this capture, with a suggested change attached): "We use […]"
@@ -639,24 +639,32 @@ def make_moe_fused(self, layer_id, mlp, root_input):
         down_proj_qweight_tensor = torch.stack(down_proj_qweight_list, dim=0).to(torch.uint8)
         down_proj_scales_tensor = torch.stack(down_proj_scales_list, dim=0)

-        # qweight tensors always use the same shape regardless of quantization method
+        # Determine shape based on Quark vs non-Quark
         pack_size = 8 // self.moe_attrs["expert_weight_bits"]
+        if has_quark_experts:
+            hidden_size_padded = self.hidden_size
+            intermediate_size_padded = self.intermediate_size
+        else:
+            hidden_size_padded = gate_up_proj_qweight_list[0].shape[-1] * pack_size
+            intermediate_size_padded = down_proj_qweight_list[0].shape[-1] * pack_size
+
+        # Save qweight tensors
         self.make_initializer(
-            gate_up_proj_qweight_tensor.view(self.moe_attrs["num_experts"], -1, self.hidden_size // pack_size),
+            gate_up_proj_qweight_tensor.view(self.moe_attrs["num_experts"], -1, hidden_size_padded // pack_size),
             gate_up_proj_weight,
         )
         self.make_initializer(
             down_proj_qweight_tensor.view(
-                self.moe_attrs["num_experts"], self.hidden_size, self.intermediate_size // pack_size
+                self.moe_attrs["num_experts"], self.hidden_size, intermediate_size_padded // pack_size
             ),
             down_proj_weight,
         )

-        # scales tensors have different shapes depending on quantization method
+        # Save scales tensors
         self.make_initializer(gate_up_proj_scales_tensor, gate_up_proj_scales, to=self.io_dtype)
         self.make_initializer(down_proj_scales_tensor, down_proj_scales, to=self.io_dtype)

-        # Save MoE biases as initializers
+        # Save biases (shared for all paths)
         if has_quark_experts:
             gate_up_bias = self.combine_quark_gate_up_biases_from_experts(mlp.experts)
             down_bias = self.combine_quark_down_biases_from_experts(mlp.experts)
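The `pack_size` arithmetic above is just the int4 packing factor: with `expert_weight_bits == 4`, `pack_size = 8 // 4 = 2`, so two 4-bit values share one uint8 byte and a logical last dimension of `hidden_size` is stored as `hidden_size // 2` bytes. A small sketch of that packing; the nibble order is an assumption for illustration, not necessarily the layout the QMoE kernel expects:

```python
import torch

def pack_int4_pairs(q: torch.Tensor) -> torch.Tensor:
    """Pack int4 values (in [-8, 7]) pairwise along the last dim into uint8 bytes."""
    assert q.shape[-1] % 2 == 0, "last dim must be even to pack pairs"
    q16 = q.to(torch.int16)                 # widen so shifting cannot overflow
    lo = q16[..., 0::2] & 0x0F              # two's-complement nibble of even elements (assumed low nibble)
    hi = q16[..., 1::2] & 0x0F              # two's-complement nibble of odd elements (assumed high nibble)
    return ((hi << 4) | lo).to(torch.uint8)

# A [num_experts, rows, cols] int4 tensor becomes [num_experts, rows, cols // 2] uint8.
q = torch.randint(-8, 8, (2, 3, 8)).to(torch.int8)
packed = pack_int4_pairs(q)
print(packed.shape)  # torch.Size([2, 3, 4])
```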
@@ -667,7 +675,11 @@ def make_moe_fused(self, layer_id, mlp, root_input):
             self.make_initializer(gate_up_bias, gate_up_proj_bias, to=self.io_dtype)
             self.make_initializer(down_bias, down_proj_bias, to=self.io_dtype)

+        # Single make_moe_op call with EP-based zero_points
+        # TRT-RTX doesn't support zero_points inputs
         moe_name = f"{basename}/{op_type}"
+        use_zero_points = has_quark_experts and self.ep not in {"NvTensorRtRtx", "trt-rtx"}
+
         self.make_moe_op(
             moe_name,
             root_input=root_input,
@@ -678,8 +690,8 @@ def make_moe_fused(self, layer_id, mlp, root_input):
             weight2=down_proj_weight,
             scales2=down_proj_scales,
             bias2=down_proj_bias,
-            zero_points1=gate_up_proj_zero_points if has_quark_experts else "",
-            zero_points2=down_proj_zero_points if has_quark_experts else "",
+            zero_points1=gate_up_proj_zero_points if use_zero_points else "",
+            zero_points2=down_proj_zero_points if use_zero_points else "",
         )

         # Assign output 0 of previous MoE as root input to next SkipLayerNorm
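For context on why the zero_points inputs are gated per EP: zero points only matter for asymmetric quantization (as produced by Quark), while a symmetric scheme reconstructs weights from the scale alone, so an EP without zero-point support can still run the symmetric path. A generic illustration of the two dequantization formulas, not the QMoE kernel's exact math:

```python
import torch

def dequant_symmetric(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Symmetric: w ≈ q * scale (no zero point input needed).
    return q.to(torch.float32) * scale

def dequant_asymmetric(q: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor:
    # Asymmetric (e.g. Quark): w ≈ (q - zero_point) * scale, which is why
    # the op needs the extra zero_points inputs.
    return (q.to(torch.float32) - zero_point.to(torch.float32)) * scale
```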
Review comment: What is the reason for using a different block size in MatMulNBits and QMoE?