def requantize_with_max_scale_cpu(
    weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: List[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Requantize fused FP8 shards so they share one (max) scale, on CPU.

    Args:
        weight: FP8 (``float8_e4m3fn``) weight of a fused layer, with the
            logical shards stacked along dim 0.
        weight_scale: Per-shard scales, one entry per logical shard.
        logical_widths: Row count of each fused shard along dim 0.

    Returns:
        Tuple of (max weight scale, requantized weight tensor).
    """
    # Max scale to be used for requantization.
    max_w_scale = weight_scale.max()

    # QKV / MLP is fused in the on-disk checkpoint if any of the weight
    # scales are still set to the default: we initialize N weight scales
    # for N shards but only load 1 weight scale from disk in that case.
    # Skip requantization then, since the weight is already quantized
    # with the single scale.
    # * Sample model: nm-testing/Phi-3-mini-128k-instruct-FP8
    unfused_module_in_checkpoint = (
        weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
    )

    # If the checkpoint is unfused, requantize each shard with the single
    # max scale so the whole fused weight shares one scale.
    if unfused_module_in_checkpoint:
        start = 0
        for idx, logical_width in enumerate(logical_widths):
            end = start + logical_width
            weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
            # CPU path: clamp to the FP8 representable range and cast,
            # instead of the CUDA-only ops.scaled_fp8_quant kernel.
            weight[start:end, :] = torch.clamp(
                weight_dq / max_w_scale,
                torch.finfo(torch.float8_e4m3fn).min,
                torch.finfo(torch.float8_e4m3fn).max,
            ).to(torch.float8_e4m3fn)
            start = end

    return max_w_scale, weight
- layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight = Parameter(weight.t().contiguous(), requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) if self.quant_config.activation_scheme == "static": layer.input_scale = Parameter( @@ -422,6 +455,22 @@ def apply( input_scale=None, bias=bias, ) + if self.quant_config.activation_scheme == "static": + q_input = torch.ops.quantized_decomposed.quantize_per_tensor( + input=x, + scale=layer.input_scale, + zero_point=0, + quant_min=int(torch.finfo(torch.float8_e4m3fn).min), + quant_max=int(torch.finfo(torch.float8_e4m3fn).max), + dtype=torch.float8_e4m3fn, + ) + return torch._scaled_mm( + q_input, + layer.weight, + bias = bias, + out_dtype=x.dtype, + scale_a = layer.input_scale, + scale_b = layer.weight_scale,) return apply_fp8_linear( input=x, From 030557c0b344fc0d61df453324f815ac782e28a8 Mon Sep 17 00:00:00 2001 From: "Zheng, Beilei" Date: Tue, 17 Jun 2025 01:42:23 -0700 Subject: [PATCH 2/3] support prompt file: https://intel-extension-for-pytorch.s3.amazonaws.com/miscellaneous/llm/prompt-3.json --- python/sglang/bench_one_batch.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index 0deaa4fe6abd..3044a48bf5d0 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -491,9 +491,29 @@ def latency_test( else: with open(bench_args.prompt_filename, "r") as pf: prompt_pool = json.load(pf) - prompt_dict = prompt_pool[str(bench_args.input_len[0])] - for index in range(bench_args.batch_size[0]): - custom_prompts.append(prompt_dict[str(index)]) + prompt_dict = None + if str(bench_args.input_len[0]) in prompt_pool: + prompt_dict = prompt_pool[str(bench_args.input_len[0])] + else: + for key in prompt_pool.keys(): + if key in server_args.model_path: + prompt_dict = prompt_pool[key][str(bench_args.input_len[0])] + break + if prompt_dict is 
None: + rank_print( + f"Custom prompt file {bench_args.prompt_filename} does not contain prompts for {server_args.model_path} and" + f"input length {bench_args.input_len[0]}. Using dummy data..." + ) + else: + for index in range(bench_args.batch_size[0]): + if isinstance(prompt_dict, str): + custom_prompts.append(prompt_dict) + elif str(index) in prompt_dict: + custom_prompts.append(prompt_dict[str(index)]) + else: + rank_print( + f"Custom prompt file {bench_args.prompt_filename} does not contain prompt for batch {index}. Using dummy data..." + ) custom_inputs = [tokenizer.encode(p.strip()) for p in custom_prompts] From 0dcc42adaf8ac8aa201cf5642f22168d724b493c Mon Sep 17 00:00:00 2001 From: "Zheng, Beilei" Date: Tue, 17 Jun 2025 20:24:23 -0700 Subject: [PATCH 3/3] fix lint --- python/sglang/srt/layers/quantization/fp8.py | 51 ++++++++++---------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 8a089ec030e1..b4352e24f344 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -7,19 +7,6 @@ import torch.nn.functional as F from torch.nn import Module from torch.nn.parameter import Parameter -from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - apply_fp8_marlin_linear, - prepare_fp8_layer_for_marlin, -) -from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - all_close_1d, - convert_to_channelwise, - per_tensor_dequantize, - requantize_with_max_scale, -) from sglang.srt.cpu_utils import _process_weight_after_loading, cpu_has_amx_support from sglang.srt.distributed import get_tensor_model_parallel_world_size @@ -53,6 +40,19 @@ print_warning_once, 
def requantize_with_max_scale_cpu(
    weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: List[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Requantize the shards of a fused FP8 weight to a single shared scale.

    CPU variant: performs the requantization with clamp + dtype cast rather
    than a CUDA quantization kernel.

    Args:
        weight: Fused FP8 (``float8_e4m3fn``) weight; shards stacked on dim 0.
        weight_scale: One scale per logical shard.
        logical_widths: Number of rows each shard occupies along dim 0.

    Returns:
        Tuple of (maximum weight scale, requantized weight).
    """
    max_w_scale = weight_scale.max()
    fp8_info = torch.finfo(torch.float8_e4m3fn)

    # A fused (QKV / MLP) on-disk checkpoint leaves all but one scale at
    # the default init value, because N scales are initialized for N shards
    # while only a single scale is loaded from disk. In that case the
    # weight already uses one scale and no requantization is needed.
    # * Sample model: nm-testing/Phi-3-mini-128k-instruct-FP8
    unfused_module_in_checkpoint = weight_scale[-1] > fp8_info.min

    if unfused_module_in_checkpoint:
        # Unfused checkpoint: re-express every shard in terms of the single
        # max scale so the fused weight shares one scale.
        shard_start = 0
        for shard_idx, shard_width in enumerate(logical_widths):
            shard_end = shard_start + shard_width
            shard = weight[shard_start:shard_end, :]
            dequantized = per_tensor_dequantize(shard, weight_scale[shard_idx])
            weight[shard_start:shard_end, :] = (
                (dequantized / max_w_scale)
                .clamp(fp8_info.min, fp8_info.max)
                .to(torch.float8_e4m3fn)
            )
            shard_start = shard_end

    return max_w_scale, weight