diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index 0deaa4fe6abd..3044a48bf5d0 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -491,9 +491,29 @@ def latency_test( else: with open(bench_args.prompt_filename, "r") as pf: prompt_pool = json.load(pf) - prompt_dict = prompt_pool[str(bench_args.input_len[0])] - for index in range(bench_args.batch_size[0]): - custom_prompts.append(prompt_dict[str(index)]) + prompt_dict = None + if str(bench_args.input_len[0]) in prompt_pool: + prompt_dict = prompt_pool[str(bench_args.input_len[0])] + else: + for key in prompt_pool.keys(): + if key in server_args.model_path: + prompt_dict = prompt_pool[key][str(bench_args.input_len[0])] + break + if prompt_dict is None: + rank_print( + f"Custom prompt file {bench_args.prompt_filename} does not contain prompts for {server_args.model_path} and " + f"input length {bench_args.input_len[0]}. Using dummy data..." + ) + else: + for index in range(bench_args.batch_size[0]): + if isinstance(prompt_dict, str): + custom_prompts.append(prompt_dict) + elif str(index) in prompt_dict: + custom_prompts.append(prompt_dict[str(index)]) + else: + rank_print( + f"Custom prompt file {bench_args.prompt_filename} does not contain prompt for batch {index}. Using dummy data..." 
+ ) custom_inputs = [tokenizer.encode(p.strip()) for p in custom_prompts] diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 8de127c6c1c2..b4352e24f344 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1,25 +1,12 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py import logging -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import torch.nn.functional as F from torch.nn import Module from torch.nn.parameter import Parameter -from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - apply_fp8_marlin_linear, - prepare_fp8_layer_for_marlin, -) -from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - all_close_1d, - convert_to_channelwise, - per_tensor_dequantize, - requantize_with_max_scale, -) from sglang.srt.cpu_utils import _process_weight_after_loading, cpu_has_amx_support from sglang.srt.distributed import get_tensor_model_parallel_world_size @@ -53,6 +40,19 @@ print_warning_once, set_weight_attrs, ) +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, + prepare_fp8_layer_for_marlin, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, + convert_to_channelwise, + per_tensor_dequantize, + requantize_with_max_scale, +) if cpu_has_amx_support(): import sgl_kernel.cpu 
@@ -136,9 +136,8 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional["QuantizeMethodBase"]: - from vllm.attention.layer import Attention # Avoid circular import - from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + from vllm.attention.layer import Attention # Avoid circular import if isinstance(layer, LinearBase): if is_layer_skipped(prefix, self.ignored_layers): @@ -154,6 +153,40 @@ def get_scaled_act_names(self) -> List[str]: return [] +def requantize_with_max_scale_cpu( + weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: List[int] +) -> Tuple[torch.Tensor, torch.Tensor]: + # Max scale to be used for requantization. + max_w_scale = weight_scale.max() + + # QKV / MLP is fused in the on disk checkpoint if any of the + # weight scales are still set to the default since we initialize + # N weight scales for N shards but we only load 1 weight scale + # from disk in this case. Skip requantization in this case (since) + # we already are quantized with the single scale. + # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8 + unfused_module_in_checkpoint = ( + weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min + ) + + # If unfused checkpoint, need to requantize with the single scale. + if unfused_module_in_checkpoint: + start = 0 + for idx, logical_width in enumerate(logical_widths): + end = start + logical_width + weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx]) + # weight[start:end, :], _ = ops.scaled_fp8_quant( + # weight_dq, max_w_scale) + weight[start:end, :] = torch.clamp( + (weight_dq / max_w_scale), + torch.finfo(torch.float8_e4m3fn).min, + torch.finfo(torch.float8_e4m3fn).max, + ).to(torch.float8_e4m3fn) + start = end + + return max_w_scale, weight + + class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. 
Supports loading FP8 checkpoints with static weight scale and @@ -366,14 +399,14 @@ def process_weights_after_loading(self, layer: Module) -> None: if input_scale is not None: layer.input_scale = Parameter(input_scale, requires_grad=False) - weight_scale, weight = requantize_with_max_scale( + weight_scale, weight = requantize_with_max_scale_cpu( weight=weight, weight_scale=weight_scale, logical_widths=layer.logical_widths, ) # Update layer with new values. - layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight = Parameter(weight.t().contiguous(), requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) if self.quant_config.activation_scheme == "static": layer.input_scale = Parameter( @@ -422,6 +455,23 @@ def apply( input_scale=None, bias=bias, ) + if self.quant_config.activation_scheme == "static": + q_input = torch.ops.quantized_decomposed.quantize_per_tensor( + input=x, + scale=layer.input_scale, + zero_point=0, + quant_min=int(torch.finfo(torch.float8_e4m3fn).min), + quant_max=int(torch.finfo(torch.float8_e4m3fn).max), + dtype=torch.float8_e4m3fn, + ) + return torch._scaled_mm( + q_input, + layer.weight, + bias=bias, + out_dtype=x.dtype, + scale_a=layer.input_scale, + scale_b=layer.weight_scale, + ) return apply_fp8_linear( input=x,