Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion atom/model_ops/fused_moe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,9 @@ def make(
if weight_dtype is None:
weight_dtype = quant_dtype

a_shape, w_shape = _quant_flags_to_group_shape(quant_dtype, False, block_shape)
a_shape, w_shape = _quant_flags_to_group_shape(
quant_dtype, per_act_token_quant, block_shape
)
quant_config = FusedMoEQuantConfig(
_a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale),
_a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale),
Expand Down
19 changes: 17 additions & 2 deletions atom/model_ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,13 @@ def weight_loader_process(
and param.data.element_size() == loaded_weight.element_size()
):
param.data = param.data.view(loaded_weight.dtype)
param.data.copy_(post_process_func(loaded_weight))
loaded_weight = post_process_func(loaded_weight)
if (
loaded_weight.shape != param.data.shape
and loaded_weight.numel() == param.data.numel()
):
loaded_weight = loaded_weight.reshape(param.data.shape)
param.data.copy_(loaded_weight)

def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
Expand Down Expand Up @@ -716,9 +722,14 @@ def __init__(
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
if param is not getattr(self, "bias", None):
shard_size = param_data.size(self.tp_dim)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.view(1, 1)
if loaded_weight.ndim <= self.tp_dim:
# dims < tp_dim (1D per-channel scale with
# tp_dim=1)
param.weight_loader_process(param_data, loaded_weight)
return
shard_size = param_data.size(self.tp_dim)
if loaded_weight.size(self.tp_dim) == 1 and self.tp_size > 1:
loaded_weight = loaded_weight.repeat(1, self.tp_size)
start_idx = self.tp_rank * shard_size
Expand Down Expand Up @@ -768,6 +779,10 @@ def weight_loader(
elif self.quant_type == QuantType.per_Tensor:
shard_offset = loaded_shard_id
shard_size = 1
else:
# Per-channel same layout as weights
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
Expand Down
264 changes: 135 additions & 129 deletions atom/model_ops/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,8 +1358,10 @@ def apply(

class Fp8MoEMethod(FusedMoEMethodBase):
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Supports three quantization strategies:
- per_Tensor: per-tensor weight scale, static/dynamic activation scale
- per_Token (PTPTC): per-channel weight scale, dynamic per-token activation
- per_1x128 / per_1x32 (block): block-wise weight scale, dynamic activation

Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
Expand All @@ -1378,6 +1380,7 @@ def __init__(self, quant_config: QuantizationConfig, moe: FusedMoEConfig):
self.quant_type == QuantType.per_1x128
or self.quant_type == QuantType.per_1x32
)
self.channel_quant = self.quant_type == QuantType.per_Token
self.need_normalize_e4m3fn_to_e4m3fnuz = (
self.quant_dtype == torch.float8_e4m3fnuz
)
Expand Down Expand Up @@ -1447,18 +1450,24 @@ def create_weights(
set_weight_attrs(w2_weight, extra_weight_attrs)

# WEIGHT_SCALES
if not self.block_quant:
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
if self.channel_quant:
# Per-channel (PTPTC): one scale per output channel per expert.
# w13: [E, 2*N], w2: [E, hidden_size]
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
torch.ones(
num_experts,
2 * intermediate_size_per_partition,
dtype=torch.float32,
),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
torch.ones(num_experts, hidden_size, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
else:
elif self.block_quant:
w13_weight_scale = torch.nn.Parameter(
torch.ones(
num_experts,
Expand All @@ -1480,21 +1489,26 @@ def create_weights(
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
assert self.quant_config["is_dynamic"]
else:
# Per-tensor
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)

# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
# extra_weight_attrs.update(
# {"quant_method": FusedMoeWeightScaleSupported.BLOCK.
# value} if self.block_quant else
# {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)

# INPUT_SCALES
if not self.quant_config["is_dynamic"]:
# Per-channel uses dynamic per-token activation, no static input scales.
if self.channel_quant or self.quant_config["is_dynamic"]:
layer.w13_input_scale = None
layer.w2_input_scale = None
else:
w13_input_scale = torch.nn.Parameter(
torch.ones(num_experts, dtype=torch.float32), requires_grad=False
)
Expand All @@ -1505,134 +1519,125 @@ def create_weights(
)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None

    def _normalize_weights_and_scales(self, layer: nn.Module) -> None:
        """Normalize `layer`'s FP8 weights/scales from e4m3fn to e4m3fnuz in place.

        No-op unless ``self.need_normalize_e4m3fn_to_e4m3fnuz`` is set
        (i.e. the target weight dtype is ``torch.float8_e4m3fnuz``).
        Converts both the fused w13 (gate/up) and w2 (down) expert weights
        via ``normalize_e4m3fn_to_e4m3fnuz`` — a helper defined elsewhere;
        presumably it rescales values between the two e4m3 variants — and
        rebinds the results on ``layer`` as plain ``nn.Parameter`` objects
        with ``requires_grad=False``.
        """
        if not self.need_normalize_e4m3fn_to_e4m3fnuz:
            return
        # Helper returns (weight, weight_scale, input_scale); input_scale
        # may be None (the None checks below show dynamic-activation setups
        # carry no static input scale).
        w13_weight, w13_weight_scale, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
            layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
        )
        w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
            layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
        )
        # Rebind as fresh Parameters so downstream code sees the converted
        # tensors rather than the originals.
        layer.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
        layer.w13_weight_scale = nn.Parameter(w13_weight_scale, requires_grad=False)
        layer.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
        layer.w2_weight_scale = nn.Parameter(w2_weight_scale, requires_grad=False)
        # Only overwrite input scales when the helper produced them;
        # otherwise leave whatever the layer already holds (possibly None).
        if w13_input_scale is not None:
            layer.w13_input_scale = nn.Parameter(w13_input_scale, requires_grad=False)
        if w2_input_scale is not None:
            layer.w2_input_scale = nn.Parameter(w2_input_scale, requires_grad=False)

def process_weights_after_loading(self, layer: nn.Module) -> None:
# Lazy import to avoid importing triton too early.
# from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
# is_rocm_aiter_moe_enabled, shuffle_weights)
if self.block_quant:
self._process_block_quant(layer)
elif self.channel_quant:
self._process_channel_quant(layer)
else:
self._process_tensor_quant(layer)

# self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
# self.rocm_aiter_use_asm = (self.rocm_aiter_moe_enabled
# and envs.VLLM_ROCM_USE_AITER_ASMMOE)
def _process_block_quant(self, layer: nn.Module) -> None:
assert self.quant_config["is_dynamic"]
self._normalize_weights_and_scales(layer)

# TODO (rob): refactor block quant into separate class.
if self.block_quant:
assert self.quant_config["is_dynamic"]
if self.need_normalize_e4m3fn_to_e4m3fnuz:
w13_weight, w13_weight_scale, w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
)
)
w2_weight, w2_weight_scale, w2_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
)
)
else:
w13_weight = layer.w13_weight.data
w13_weight_scale = layer.w13_weight_scale.data
w2_weight = layer.w2_weight
w2_weight_scale = layer.w2_weight_scale
if not self.need_normalize_e4m3fn_to_e4m3fnuz:
layer.w13_weight = nn.Parameter(layer.w13_weight.data, requires_grad=False)
layer.w13_weight_scale = nn.Parameter(
layer.w13_weight_scale.data, requires_grad=False
)
layer.w2_weight = nn.Parameter(layer.w2_weight.data, requires_grad=False)
layer.w2_weight_scale = nn.Parameter(
layer.w2_weight_scale.data, requires_grad=False
)

# torch.compile() cannot use Parameter subclasses.
layer.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale = nn.Parameter(w13_weight_scale, requires_grad=False)
layer.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale = nn.Parameter(w2_weight_scale, requires_grad=False)
shuffle_weights(layer.w13_weight, layer.w2_weight)

shuffle_weights(layer.w13_weight, layer.w2_weight)
def _process_channel_quant(self, layer: nn.Module) -> None:
"""PTPTC"""
self._normalize_weights_and_scales(layer)

return
else:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if not self.quant_config["is_dynamic"]:
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
# if (not all_close_1d(layer.w13_input_scale)
# or not all_close_1d(layer.w2_input_scale)):
# print(
# "Found input_scales that are not equal for "
# "fp8 MoE layer. Using the maximum across experts "
# "for each layer.")
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False
)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False
if layer.w13_weight.data.dtype in (torch.bfloat16, torch.float16):
quant_func = get_hip_quant(QuantType.per_Token)
for expert_id in range(layer.local_num_experts):
w13_q, w13_s = quant_func(
layer.w13_weight.data[expert_id], quant_dtype=dtypes.fp8
)
if self.need_normalize_e4m3fn_to_e4m3fnuz:
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
)
layer.w13_weight.data[expert_id] = w13_q
layer.w13_weight_scale.data[expert_id] = w13_s.squeeze(-1)

w2_q, w2_s = quant_func(
layer.w2_weight.data[expert_id], quant_dtype=dtypes.fp8
)
w2_weight, w2_weight_scale, w2_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
)
layer.w2_weight.data[expert_id] = w2_q
layer.w2_weight_scale.data[expert_id] = w2_s.squeeze(-1)

shuffle_weights(layer.w13_weight, layer.w2_weight)

def _process_tensor_quant(self, layer: nn.Module) -> None:
if not self.quant_config["is_dynamic"]:
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(
w13_weight_scale, requires_grad=False
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False
)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False
)

self._normalize_weights_and_scales(layer)

assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(
w13_input_scale, requires_grad=False
)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
quant_func = get_hip_quant(self.quant_type)
layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
quant_func(dq_weight, max_w13_scales[expert_id])
)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False
)

# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.local_num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start : start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id],
)
quant_func = get_hip_quant(self.quant_type)
layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
quant_func(dq_weight, max_w13_scales[expert_id])
)
start += shard_size
start += shard_size

shuffle_weights(layer.w13_weight, layer.w2_weight)
shuffle_weights(layer.w13_weight, layer.w2_weight)

layer.w13_weight_scale = torch.nn.Parameter(
max_w13_scales, requires_grad=False
)
return
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)

def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return fp8_w8a8_moe_quant_config(
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=None,
)
if self.channel_quant:
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=True,
)
else:
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=None,
)

def apply(
self,
Expand Down Expand Up @@ -1669,7 +1674,8 @@ def apply(
num_fused_shared_experts=layer.num_fused_shared_experts,
routed_scaling_factor=layer.routed_scaling_factor,
)
# per_Tensor not support num_local_tokens so not use mori
# per_Tensor doesn't support num_local_tokens, so fallback to
# rocm_aiter_fused_moe when using per-tensor or no modular kernel.
if self.quant_type == QuantType.per_Tensor or self.fused_experts is None:
return torch.ops.aiter.rocm_aiter_fused_moe(
x,
Expand Down
Loading