diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index e0fa1bbb0..8b86651b7 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -135,7 +135,7 @@ def parse_hf_token(hf_token): def set_io_dtype(precision, execution_provider, extra_options) -> ir.DataType: int4_cpu = precision == "int4" and execution_provider == "cpu" fp32_webgpu = execution_provider == "webgpu" and extra_options.get("use_webgpu_fp32", False) - bf16_cuda = precision == "int4" and execution_provider == "cuda" and extra_options.get("use_cuda_bf16", False) + bf16_cuda = precision == "int4" and execution_provider in {"cuda", "trt-rtx"} and extra_options.get("use_cuda_bf16", False) if precision in {"int8", "fp32"} or int4_cpu or fp32_webgpu: # FP32 precision @@ -403,8 +403,10 @@ def get_args(): 2 is fp16. 1 is fp32. Default is 4 for the CPU EP and 0 for non-CPU EPs. - int4_block_size = 16/32/64/128/256: Specify the block size for int4 quantization. + int4_block_size = 16/32/64/128/256: Specify the block size for int4 quantization (MatMulNBits). Default value is 32. + int4_qmoe_block_size = 16/32/64/128/256: Specify the block size for QMoE expert weights quantization. + Default is 128 for trt-rtx, 0 (tensor-level) for others. Supported EPs: cpu, webgpu, trt-rtx. int4_is_symmetric = Quantize the weights symmetrically. Default is true. If true, quantization is done to int4. If false, quantization is done to uint4. int4_op_types_to_quantize = MatMul/Gather: Specify op types to target for int4 quantization. @@ -469,7 +471,7 @@ def get_args(): args = parser.parse_args() print( - "Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, FP16 DML, BF16 CUDA, FP16 TRT-RTX, INT4 CPU, INT4 CUDA, INT4 DML, INT4 WebGPU" + "Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, FP16 DML, BF16 CUDA, FP16 TRT-RTX, BF16 TRT-RTX, INT4 CPU, INT4 CUDA, INT4 DML, INT4 WebGPU" ) return args diff --git a/src/python/py/models/builders/base.py b/src/python/py/models/builders/base.py index 9267a5ec5..849c425dc 100644 --- a/src/python/py/models/builders/base.py +++ b/src/python/py/models/builders/base.py @@ -359,11 +359,17 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): int4_algo_config = self.make_int4_algo_config(extra_options.get("int4_algo_config", "default")) self.int4_block_size = extra_options.get("int4_block_size", 32) - # Validate that only CPU and WebGPU EPs support int4_block_size for QMoE - if self.ep not in ["cpu", "webgpu"] and "int4_block_size" in extra_options and moe_op_type == "QMoE": + # CPU, WebGPU, and TRT-RTX support block-wise quantization for QMoE. + # TRT-RTX defaults to 128; CPU/WebGPU default to 0 (tensor-level) for backward compatibility. + supported_blockwise_eps = ["cpu", "webgpu", "trt-rtx", "NvTensorRtRtx"] + default_qmoe_block_size = 128 if self.ep in ["trt-rtx", "NvTensorRtRtx"] else 0 + self.int4_qmoe_block_size = extra_options.get("int4_qmoe_block_size", default_qmoe_block_size) + + # Validate that unsupported EPs don't explicitly request block-wise quantization + if self.ep not in supported_blockwise_eps and "int4_qmoe_block_size" in extra_options and moe_op_type == "QMoE": raise ValueError( - f"The 'int4_block_size' option is not supported for {self.ep} execution provider with QMoE. " - "Block-wise quantization (block_size attribute) is only supported for CPU and WebGPU execution providers." + f"The 'int4_qmoe_block_size' option is not supported for {self.ep} execution provider with QMoE. " + f"Block-wise quantization is only supported for: {', '.join(supported_blockwise_eps)}." ) self.quant_attrs = { @@ -371,7 +377,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): "accuracy_level": int( extra_options.get("int4_accuracy_level", 4 if self.ep in ["cpu", "webgpu"] else 0) ), - "block_size": int(self.int4_block_size), + "block_size": int(self.int4_qmoe_block_size), + "qdq_block_size": int(self.int4_block_size), "is_symmetric": extra_options.get("int4_is_symmetric", True), "op_types_to_quantize": extra_options.get("int4_op_types_to_quantize", ("MatMul",)), "nodes_to_exclude": extra_options.get("int4_nodes_to_exclude", []), @@ -380,11 +387,11 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): "use_qdq": extra_options.get("use_qdq", False), } - # Propagate block_size to MoE/QMoE op when supported and requested. - # QMoE on CPU/WebGPU supports block-wise quantization via the 'block_size' attribute. + # Propagate block_size to MoE/QMoE op when supported. + # QMoE on supported EPs uses block-wise quantization via the 'block_size' attribute. # Ensure the attribute is set on the MoE op so runtime kernels can honor it. - if self.moe_attrs.get("op_type") == "QMoE" and self.ep in ["cpu", "webgpu"]: - self.moe_attrs["block_size"] = int(self.int4_block_size) + if self.moe_attrs.get("op_type") == "QMoE" and self.ep in supported_blockwise_eps: + self.moe_attrs["block_size"] = int(self.int4_qmoe_block_size) if self.quant_type is not None: # Create quantized attributes from quantization config self.quant_attrs["config"] = config.quantization_config @@ -501,6 +508,7 @@ def is_gqa_supported(self) -> bool: ("webgpu", ir.DataType.FLOAT16), ("webgpu", ir.DataType.FLOAT), ("trt-rtx", ir.DataType.FLOAT16), + ("trt-rtx", ir.DataType.BFLOAT16), } return (self.ep, self.io_dtype) in valid_gqa_configurations @@ -703,7 +711,7 @@ def make_int4_algo_config(self, quant_method: str): def to_int4(self) -> ir.Model: quant = MatMulNBitsQuantizer( model=ir.to_proto(self.model), - block_size=self.quant_attrs["int4"]["block_size"], + block_size=self.quant_attrs["int4"]["qdq_block_size"], is_symmetric=self.quant_attrs["int4"]["is_symmetric"], accuracy_level=self.quant_attrs["int4"]["accuracy_level"], nodes_to_exclude=self.quant_attrs["int4"]["nodes_to_exclude"], @@ -712,7 +720,22 @@ def to_int4(self) -> ir.Model: algo_config=self.quant_attrs["int4"]["algo_config"], ) quant.process() - return ir.from_proto(quant.model.model) + model = ir.from_proto(quant.model.model) + + # Convert float32 scales to bfloat16 if io_dtype is bfloat16. + # MatMulNBitsQuantizer doesn't natively support bfloat16, so we saved weights as float32 + # for quantization and now convert the resulting scales to the target io_dtype. + if self.io_dtype == ir.DataType.BFLOAT16: + for initializer in model.graph.initializers.values(): + # Scale tensors are named with "_scales" or "_DQ_scales" suffix + if initializer.name.endswith("_scales") or initializer.name.endswith("_DQ_scales"): + if initializer.dtype == ir.DataType.FLOAT: + # Convert float32 scales to bfloat16 + float32_data = initializer.const_value.numpy() + bfloat16_data = torch.from_numpy(float32_data).to(torch.bfloat16) + initializer.const_value = TorchTensor(bfloat16_data, name=initializer.name) + + return model def save_model(self, out_dir): print(f"Saving ONNX model in {out_dir}") @@ -1056,7 +1079,14 @@ def make_matmul_op(self, matmul, basename, root_input, **kwargs): def make_matmul_float(self, matmul, name, root_input, **kwargs): weight = name[1:].replace("/", ".") + ".weight" - self.make_initializer(matmul.weight.T, weight, to=self.io_dtype) + # When onnx_dtype is INT4/UINT4, weights will be quantized by MatMulNBitsQuantizer later. + # MatMulNBitsQuantizer doesn't properly support BFLOAT16 inputs, so we need to save + # weights as FLOAT32 to ensure correct quantization with proper scales. + if self.onnx_dtype in {ir.DataType.INT4, ir.DataType.UINT4} and self.io_dtype == ir.DataType.BFLOAT16: + weight_dtype = ir.DataType.FLOAT + else: + weight_dtype = self.io_dtype + self.make_initializer(matmul.weight.T, weight, to=weight_dtype) last_dim = matmul.weight.shape[0] output = "logits" if kwargs.get("logits", False) else f"{name}/output_0" @@ -3227,11 +3257,15 @@ def make_qmoe_op(self, name, **kwargs): kwargs.get("weight3", ""), kwargs.get("scales3", ""), kwargs.get("bias3", ""), - kwargs.get("zero_points1", ""), - kwargs.get("zero_points2", ""), - kwargs.get("zero_points3", ""), ] + # Only add zero_points inputs if they are provided (for Quark asymmetric quantization) + zero_points1 = kwargs.get("zero_points1", "") + zero_points2 = kwargs.get("zero_points2", "") + zero_points3 = kwargs.get("zero_points3", "") + if zero_points1 or zero_points2 or zero_points3: + inputs.extend([zero_points1, zero_points2, zero_points3]) + output = f"{name}/output_0" extra_kwargs = ( @@ -3264,13 +3298,11 @@ def make_qmoe_weights(self, weights): dtype = torch.quint4x2 if self.moe_attrs["expert_weight_bits"] == 4 else torch.int8 qweight, scales = None, None - # For QMoE, only use block-wise quantization when explicitly requested - # via int4_block_size and when using CPU or WebGPU execution providers, since - # block_size is only supported for these EPs in the QMoE operator. - use_blockwise_quant = "int4_block_size" in self.extra_options and self.ep in ["cpu", "webgpu"] + # Get block size from quantization attributes + block_size = self.quant_attrs["int4"]["block_size"] - if use_blockwise_quant: - block_size = self.quant_attrs["int4"]["block_size"] + # Use block-wise quantization if block_size > 0 + if block_size > 0: try: qweight, scales = self._symmetric_blockwise_quantize(weights, block_size) self.moe_attrs["block_size"] = block_size @@ -3278,7 +3310,7 @@ def make_qmoe_weights(self, weights): except Exception as e: raise RuntimeError(f"Block-wise quantization failed with block_size={block_size}: {e}") - # Use tensor-level quantization (default for QMoE) + # block_size is 0, so we're using tensor-level quantization self.moe_attrs["block_size"] = 0 # Existing tensor-level quantization implementation (fallback) @@ -3354,6 +3386,7 @@ def _symmetric_blockwise_quantize(self, weights, block_size): quantized_flat = quantized_int8.view(*original_shape[:-1], num_blocks * block_size) + # remove padding if pad_size > 0: quantized_flat = quantized_flat[..., :-pad_size] diff --git a/src/python/py/models/builders/gptoss.py b/src/python/py/models/builders/gptoss.py index 65eb7a94a..324caf160 100644 --- a/src/python/py/models/builders/gptoss.py +++ b/src/python/py/models/builders/gptoss.py @@ -70,7 +70,7 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): self.window_size = original_window_size def make_moe(self, layer_id, mlp, root_input): - if self.ep in {"cpu", "cuda"}: + if self.ep in {"cpu", "cuda", "NvTensorRtRtx", "trt-rtx"}: self.make_moe_fused(layer_id, mlp, root_input) else: self.make_moe_decomposed(layer_id, mlp, root_input) @@ -639,24 +639,32 @@ def make_moe_fused(self, layer_id, mlp, root_input): down_proj_qweight_tensor = torch.stack(down_proj_qweight_list, dim=0).to(torch.uint8) down_proj_scales_tensor = torch.stack(down_proj_scales_list, dim=0) - # qweight tensors always use the same shape regardless of quantization method + # Determine shape based on Quark vs non-Quark pack_size = 8 // self.moe_attrs["expert_weight_bits"] + if has_quark_experts: + hidden_size_padded = self.hidden_size + intermediate_size_padded = self.intermediate_size + else: + hidden_size_padded = gate_up_proj_qweight_list[0].shape[-1] * pack_size + intermediate_size_padded = down_proj_qweight_list[0].shape[-1] * pack_size + + # Save qweight tensors self.make_initializer( - gate_up_proj_qweight_tensor.view(self.moe_attrs["num_experts"], -1, self.hidden_size // pack_size), + gate_up_proj_qweight_tensor.view(self.moe_attrs["num_experts"], -1, hidden_size_padded // pack_size), gate_up_proj_weight, ) self.make_initializer( down_proj_qweight_tensor.view( - self.moe_attrs["num_experts"], self.hidden_size, self.intermediate_size // pack_size + self.moe_attrs["num_experts"], self.hidden_size, intermediate_size_padded // pack_size ), down_proj_weight, ) - # scales tensors have different shapes depending on quantization method + # Save scales tensors self.make_initializer(gate_up_proj_scales_tensor, gate_up_proj_scales, to=self.io_dtype) self.make_initializer(down_proj_scales_tensor, down_proj_scales, to=self.io_dtype) - # Save MoE biases as initializers + # Save biases (shared for all paths) if has_quark_experts: gate_up_bias = self.combine_quark_gate_up_biases_from_experts(mlp.experts) down_bias = self.combine_quark_down_biases_from_experts(mlp.experts) @@ -667,7 +675,11 @@ def make_moe_fused(self, layer_id, mlp, root_input): self.make_initializer(gate_up_bias, gate_up_proj_bias, to=self.io_dtype) self.make_initializer(down_bias, down_proj_bias, to=self.io_dtype) + # Single make_moe_op call with EP-based zero_points + # TRT-RTX doesn't support zero_points inputs moe_name = f"{basename}/{op_type}" + use_zero_points = has_quark_experts and self.ep not in {"NvTensorRtRtx", "trt-rtx"} + self.make_moe_op( moe_name, root_input=root_input, @@ -678,8 +690,8 @@ def make_moe_fused(self, layer_id, mlp, root_input): weight2=down_proj_weight, scales2=down_proj_scales, bias2=down_proj_bias, - zero_points1=gate_up_proj_zero_points if has_quark_experts else "", - zero_points2=down_proj_zero_points if has_quark_experts else "", + zero_points1=gate_up_proj_zero_points if use_zero_points else "", + zero_points2=down_proj_zero_points if use_zero_points else "", ) # Assign output 0 of previous MoE as root input to next SkipLayerNorm