From 0b1b0bbeeca656546926cf38325a4ecf44dd5a76 Mon Sep 17 00:00:00 2001 From: ZhangLirong-amd Date: Fri, 6 Mar 2026 12:58:51 +0800 Subject: [PATCH 1/3] support PTPC fp8 in Moe --- atom/model_ops/linear.py | 16 ++- atom/model_ops/moe.py | 280 +++++++++++++++++++++------------------ 2 files changed, 165 insertions(+), 131 deletions(-) diff --git a/atom/model_ops/linear.py b/atom/model_ops/linear.py index a3d7b4ef..495ca80a 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -314,7 +314,10 @@ def weight_loader_process( and param.data.element_size() == loaded_weight.element_size() ): param.data = param.data.view(loaded_weight.dtype) - param.data.copy_(post_process_func(loaded_weight)) + loaded_weight = post_process_func(loaded_weight) + if loaded_weight.shape != param.data.shape and loaded_weight.numel() == param.data.numel(): + loaded_weight = loaded_weight.reshape(param.data.shape) + param.data.copy_(loaded_weight) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): param_data = param.data @@ -716,9 +719,14 @@ def __init__( def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): param_data = param.data if param is not getattr(self, "bias", None): - shard_size = param_data.size(self.tp_dim) if len(loaded_weight.shape) == 0: loaded_weight = loaded_weight.view(1, 1) + if loaded_weight.ndim <= self.tp_dim: + # dims < tp_dim (1D per-channel scale with + # tp_dim=1) + param.weight_loader_process(param_data, loaded_weight) + return + shard_size = param_data.size(self.tp_dim) if loaded_weight.size(self.tp_dim) == 1 and self.tp_size > 1: loaded_weight = loaded_weight.repeat(1, self.tp_size) start_idx = self.tp_rank * shard_size @@ -768,6 +776,10 @@ def weight_loader( elif self.quant_type == QuantType.per_Tensor: shard_offset = loaded_shard_id shard_size = 1 + else: + # Per-channel same layout as weights + shard_offset = sum(self.output_sizes[:loaded_shard_id]) + shard_size = self.output_sizes[loaded_shard_id] else: 
shard_offset = sum(self.output_sizes[:loaded_shard_id]) shard_size = self.output_sizes[loaded_shard_id] diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index b6820cdd..925e0877 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -1358,8 +1358,10 @@ def apply( class Fp8MoEMethod(FusedMoEMethodBase): """MoE method for FP8. - Supports loading FP8 checkpoints with static weight scale and - dynamic/static activation scale. + Supports three quantization strategies: + - per_Tensor: per-tensor weight scale, static/dynamic activation scale + - per_Token (PTPC): per-channel weight scale, dynamic per-token activation + - per_1x128 / per_1x32 (block): block-wise weight scale, dynamic activation Also supports loading quantized FP16/BF16 model checkpoints with dynamic activation scaling. The weight scaling factor will be initialized after @@ -1378,6 +1380,7 @@ def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig): self.quant_type == QuantType.per_1x128 or self.quant_type == QuantType.per_1x32 ) + self.channel_quant = self.quant_type == QuantType.per_Token self.need_normalize_e4m3fn_to_e4m3fnuz = ( self.quant_dtype == torch.float8_e4m3fnuz ) @@ -1447,18 +1450,26 @@ def create_weights( set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES - if not self.block_quant: - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. + if self.channel_quant: + # Per-channel (PTPC): one scale per output channel per expert.
+ # w13: [E, 2*N], w2: [E, hidden_size] w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + dtype=torch.float32, + ), + requires_grad=False, ) w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False + torch.ones( + num_experts, hidden_size, dtype=torch.float32 + ), + requires_grad=False, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) - else: + elif self.block_quant: w13_weight_scale = torch.nn.Parameter( torch.ones( num_experts, @@ -1480,21 +1491,26 @@ def create_weights( layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) assert self.quant_config["is_dynamic"] + else: + # Per-tensor + w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) - # Add the quantization method used (per tensor/grouped/channel) - # to ensure the weight scales are loaded in properly - # extra_weight_attrs.update( - # {"quant_method": FusedMoeWeightScaleSupported.BLOCK. - # value} if self.block_quant else - # {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) - # If loading fp8 checkpoint, pass the weight loaders. - # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES - if not self.quant_config["is_dynamic"]: + # Per-channel uses dynamic per-token activation, no static input scales. 
+ if self.channel_quant or self.quant_config["is_dynamic"]: + layer.w13_input_scale = None + layer.w2_input_scale = None + else: w13_input_scale = torch.nn.Parameter( torch.ones(num_experts, dtype=torch.float32), requires_grad=False ) @@ -1505,134 +1521,139 @@ def create_weights( ) layer.register_parameter("w2_input_scale", w2_input_scale) set_weight_attrs(w2_input_scale, extra_weight_attrs) - else: - layer.w13_input_scale = None - layer.w2_input_scale = None + + def _normalize_weights_and_scales(self, layer: nn.Module): + if not self.need_normalize_e4m3fn_to_e4m3fnuz: + return + w13_weight, w13_weight_scale, w13_input_scale = ( + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale + ) + ) + w2_weight, w2_weight_scale, w2_input_scale = ( + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale + ) + ) + layer.w13_weight = nn.Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale = nn.Parameter(w13_weight_scale, requires_grad=False) + layer.w2_weight = nn.Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale = nn.Parameter(w2_weight_scale, requires_grad=False) + if w13_input_scale is not None: + layer.w13_input_scale = nn.Parameter( + w13_input_scale, requires_grad=False + ) + if w2_input_scale is not None: + layer.w2_input_scale = nn.Parameter( + w2_input_scale, requires_grad=False + ) def process_weights_after_loading(self, layer: nn.Module) -> None: - # Lazy import to avoid importing triton too early. 
- # from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - # is_rocm_aiter_moe_enabled, shuffle_weights) + if self.block_quant: + self._process_block_quant(layer) + elif self.channel_quant: + self._process_channel_quant(layer) + else: + self._process_tensor_quant(layer) - # self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() - # self.rocm_aiter_use_asm = (self.rocm_aiter_moe_enabled - # and envs.VLLM_ROCM_USE_AITER_ASMMOE) + def _process_block_quant(self, layer: nn.Module) -> None: + assert self.quant_config["is_dynamic"] + self._normalize_weights_and_scales(layer) - # TODO (rob): refactor block quant into separate class. - if self.block_quant: - assert self.quant_config["is_dynamic"] - if self.need_normalize_e4m3fn_to_e4m3fnuz: - w13_weight, w13_weight_scale, w13_input_scale = ( - normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale - ) - ) - w2_weight, w2_weight_scale, w2_input_scale = ( - normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale - ) - ) - else: - w13_weight = layer.w13_weight.data - w13_weight_scale = layer.w13_weight_scale.data - w2_weight = layer.w2_weight - w2_weight_scale = layer.w2_weight_scale + if not self.need_normalize_e4m3fn_to_e4m3fnuz: + layer.w13_weight = nn.Parameter( + layer.w13_weight.data, requires_grad=False + ) + layer.w13_weight_scale = nn.Parameter( + layer.w13_weight_scale.data, requires_grad=False + ) + layer.w2_weight = nn.Parameter( + layer.w2_weight.data, requires_grad=False + ) + layer.w2_weight_scale = nn.Parameter( + layer.w2_weight_scale.data, requires_grad=False + ) - # torch.compile() cannot use Parameter subclasses. 
- layer.w13_weight = nn.Parameter(w13_weight, requires_grad=False) - layer.w13_weight_scale = nn.Parameter(w13_weight_scale, requires_grad=False) - layer.w2_weight = nn.Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale = nn.Parameter(w2_weight_scale, requires_grad=False) + shuffle_weights(layer.w13_weight, layer.w2_weight) - shuffle_weights(layer.w13_weight, layer.w2_weight) + def _process_channel_quant(self, layer: nn.Module) -> None: + """Process weights for PTPC (per-channel weight scale, per-token activation).""" + self._normalize_weights_and_scales(layer) - return - else: - # Fp8 moe kernels require a single activation scale. - # We take the max of all the scales in case they differ. - if not self.quant_config["is_dynamic"]: - if layer.w13_input_scale is None or layer.w2_input_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None." - ) - # if (not all_close_1d(layer.w13_input_scale) - # or not all_close_1d(layer.w2_input_scale)): - # print( - # "Found input_scales that are not equal for " - # "fp8 MoE layer.
Using the maximum across experts " - # "for each layer.") - layer.w13_input_scale = torch.nn.Parameter( - layer.w13_input_scale.max(), requires_grad=False - ) - layer.w2_input_scale = torch.nn.Parameter( - layer.w2_input_scale.max(), requires_grad=False + if layer.w13_weight.data.dtype in (torch.bfloat16, torch.float16): + quant_func = get_hip_quant(QuantType.per_Token) + for expert_id in range(layer.local_num_experts): + w13_q, w13_s = quant_func( + layer.w13_weight.data[expert_id], quant_dtype=dtypes.fp8 ) - if self.need_normalize_e4m3fn_to_e4m3fnuz: - # Normalize the weights and scales - w13_weight, w13_weight_scale, w13_input_scale = ( - normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale - ) + layer.w13_weight.data[expert_id] = w13_q + layer.w13_weight_scale.data[expert_id] = w13_s.squeeze(-1) + + w2_q, w2_s = quant_func( + layer.w2_weight.data[expert_id], quant_dtype=dtypes.fp8 ) - w2_weight, w2_weight_scale, w2_input_scale = ( - normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale - ) + layer.w2_weight.data[expert_id] = w2_q + layer.w2_weight_scale.data[expert_id] = w2_s.squeeze(-1) + + shuffle_weights(layer.w13_weight, layer.w2_weight) + + def _process_tensor_quant(self, layer: nn.Module) -> None: + if not self.quant_config["is_dynamic"]: + if layer.w13_input_scale is None or layer.w2_input_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." 
) - # Reset the parameter - layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) - layer.w13_weight_scale = torch.nn.Parameter( - w13_weight_scale, requires_grad=False + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False + ) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False + ) + + self._normalize_weights_and_scales(layer) + + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.local_num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id], ) - if w13_input_scale is not None: - layer.w13_input_scale = torch.nn.Parameter( - w13_input_scale, requires_grad=False - ) - layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - layer.w2_weight_scale = torch.nn.Parameter( - w2_weight_scale, requires_grad=False + quant_func = get_hip_quant(self.quant_type) + layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( + quant_func(dq_weight, max_w13_scales[expert_id]) ) - if w2_input_scale is not None: - layer.w2_input_scale = torch.nn.Parameter( - w2_input_scale, requires_grad=False - ) + start += shard_size - # Fp8 moe kernel needs single weight scale for w13 per expert. - # We take the max then dequant and requant each expert. 
- assert layer.w13_weight_scale is not None - shard_size = layer.intermediate_size_per_partition - max_w13_scales = layer.w13_weight_scale.max(dim=1).values - for expert_id in range(layer.local_num_experts): - start = 0 - for shard_id in range(2): - dq_weight = per_tensor_dequantize( - layer.w13_weight[expert_id][start : start + shard_size, :], - layer.w13_weight_scale[expert_id][shard_id], - ) - quant_func = get_hip_quant(self.quant_type) - layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( - quant_func(dq_weight, max_w13_scales[expert_id]) - ) - start += shard_size - - shuffle_weights(layer.w13_weight, layer.w2_weight) + shuffle_weights(layer.w13_weight, layer.w2_weight) - layer.w13_weight_scale = torch.nn.Parameter( - max_w13_scales, requires_grad=False - ) - return + layer.w13_weight_scale = torch.nn.Parameter( + max_w13_scales, requires_grad=False + ) def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: - return fp8_w8a8_moe_quant_config( - w1_scale=(layer.w13_weight_scale), - w2_scale=(layer.w2_weight_scale), - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - block_shape=None, - ) + if self.channel_quant: + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + per_act_token_quant=True, + ) + else: + return fp8_w8a8_moe_quant_config( + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=None, + ) def apply( self, @@ -1669,7 +1690,8 @@ def apply( num_fused_shared_experts=layer.num_fused_shared_experts, routed_scaling_factor=layer.routed_scaling_factor, ) - # per_Tensor not support num_local_tokens so not use mori + # per_Tensor doesn't support num_local_tokens, so fallback to + # rocm_aiter_fused_moe when using per-tensor or no modular kernel. 
if self.quant_type == QuantType.per_Tensor or self.fused_experts is None: return torch.ops.aiter.rocm_aiter_fused_moe( x, From e1c030b9d5ada73a77710c9a3b27b1d64af2654f Mon Sep 17 00:00:00 2001 From: ZhangLirong-amd Date: Fri, 6 Mar 2026 13:23:22 +0800 Subject: [PATCH 2/3] add per channel in config --- atom/model_ops/fused_moe/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/atom/model_ops/fused_moe/config.py b/atom/model_ops/fused_moe/config.py index ff3a2526..46f198e1 100644 --- a/atom/model_ops/fused_moe/config.py +++ b/atom/model_ops/fused_moe/config.py @@ -198,7 +198,7 @@ def make( if weight_dtype is None: weight_dtype = quant_dtype - a_shape, w_shape = _quant_flags_to_group_shape(quant_dtype, False, block_shape) + a_shape, w_shape = _quant_flags_to_group_shape(quant_dtype, per_act_token_quant, block_shape) quant_config = FusedMoEQuantConfig( _a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale), _a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale), From 0d24431a23f2434e9f186c5e18299b318c126b48 Mon Sep 17 00:00:00 2001 From: ZhangLirong-amd Date: Fri, 6 Mar 2026 15:41:19 +0800 Subject: [PATCH 3/3] format --- atom/model_ops/fused_moe/config.py | 4 +++- atom/model_ops/linear.py | 5 ++++- atom/model_ops/moe.py | 36 +++++++++--------------------- 3 files changed, 17 insertions(+), 28 deletions(-) diff --git a/atom/model_ops/fused_moe/config.py b/atom/model_ops/fused_moe/config.py index 46f198e1..5ab04b54 100644 --- a/atom/model_ops/fused_moe/config.py +++ b/atom/model_ops/fused_moe/config.py @@ -198,7 +198,9 @@ def make( if weight_dtype is None: weight_dtype = quant_dtype - a_shape, w_shape = _quant_flags_to_group_shape(quant_dtype, per_act_token_quant, block_shape) + a_shape, w_shape = _quant_flags_to_group_shape( + quant_dtype, per_act_token_quant, block_shape + ) quant_config = FusedMoEQuantConfig( _a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale), _a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale), diff --git 
a/atom/model_ops/linear.py b/atom/model_ops/linear.py index 495ca80a..8f470ef5 100644 --- a/atom/model_ops/linear.py +++ b/atom/model_ops/linear.py @@ -315,7 +315,10 @@ def weight_loader_process( ): param.data = param.data.view(loaded_weight.dtype) loaded_weight = post_process_func(loaded_weight) - if loaded_weight.shape != param.data.shape and loaded_weight.numel() == param.data.numel(): + if ( + loaded_weight.shape != param.data.shape + and loaded_weight.numel() == param.data.numel() + ): loaded_weight = loaded_weight.reshape(param.data.shape) param.data.copy_(loaded_weight) diff --git a/atom/model_ops/moe.py b/atom/model_ops/moe.py index 925e0877..5953c770 100644 --- a/atom/model_ops/moe.py +++ b/atom/model_ops/moe.py @@ -1462,9 +1462,7 @@ def create_weights( requires_grad=False, ) w2_weight_scale = torch.nn.Parameter( - torch.ones( - num_experts, hidden_size, dtype=torch.float32 - ), + torch.ones(num_experts, hidden_size, dtype=torch.float32), requires_grad=False, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) @@ -1525,28 +1523,20 @@ def create_weights( def _normalize_weights_and_scales(self, layer: nn.Module): if not self.need_normalize_e4m3fn_to_e4m3fnuz: return - w13_weight, w13_weight_scale, w13_input_scale = ( - normalize_e4m3fn_to_e4m3fnuz( - layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale - ) + w13_weight, w13_weight_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale ) - w2_weight, w2_weight_scale, w2_input_scale = ( - normalize_e4m3fn_to_e4m3fnuz( - layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale - ) + w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale ) layer.w13_weight = nn.Parameter(w13_weight, requires_grad=False) layer.w13_weight_scale = nn.Parameter(w13_weight_scale, requires_grad=False) layer.w2_weight = nn.Parameter(w2_weight, 
requires_grad=False) layer.w2_weight_scale = nn.Parameter(w2_weight_scale, requires_grad=False) if w13_input_scale is not None: - layer.w13_input_scale = nn.Parameter( - w13_input_scale, requires_grad=False - ) + layer.w13_input_scale = nn.Parameter(w13_input_scale, requires_grad=False) if w2_input_scale is not None: - layer.w2_input_scale = nn.Parameter( - w2_input_scale, requires_grad=False - ) + layer.w2_input_scale = nn.Parameter(w2_input_scale, requires_grad=False) def process_weights_after_loading(self, layer: nn.Module) -> None: if self.block_quant: @@ -1561,15 +1551,11 @@ def _process_block_quant(self, layer: nn.Module) -> None: self._normalize_weights_and_scales(layer) if not self.need_normalize_e4m3fn_to_e4m3fnuz: - layer.w13_weight = nn.Parameter( - layer.w13_weight.data, requires_grad=False - ) + layer.w13_weight = nn.Parameter(layer.w13_weight.data, requires_grad=False) layer.w13_weight_scale = nn.Parameter( layer.w13_weight_scale.data, requires_grad=False ) - layer.w2_weight = nn.Parameter( - layer.w2_weight.data, requires_grad=False - ) + layer.w2_weight = nn.Parameter(layer.w2_weight.data, requires_grad=False) layer.w2_weight_scale = nn.Parameter( layer.w2_weight_scale.data, requires_grad=False ) @@ -1631,9 +1617,7 @@ def _process_tensor_quant(self, layer: nn.Module) -> None: shuffle_weights(layer.w13_weight, layer.w2_weight) - layer.w13_weight_scale = torch.nn.Parameter( - max_w13_scales, requires_grad=False - ) + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) def get_fused_moe_quant_config( self, layer: torch.nn.Module