Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions python/sglang/bench_one_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,29 @@ def latency_test(
else:
with open(bench_args.prompt_filename, "r") as pf:
prompt_pool = json.load(pf)
prompt_dict = prompt_pool[str(bench_args.input_len[0])]
for index in range(bench_args.batch_size[0]):
custom_prompts.append(prompt_dict[str(index)])
prompt_dict = None
if str(bench_args.input_len[0]) in prompt_pool:
prompt_dict = prompt_pool[str(bench_args.input_len[0])]
else:
for key in prompt_pool.keys():
if key in server_args.model_path:
prompt_dict = prompt_pool[key][str(bench_args.input_len[0])]
break
if prompt_dict is None:
rank_print(
f"Custom prompt file {bench_args.prompt_filename} does not contain prompts for {server_args.model_path} and"
f"input length {bench_args.input_len[0]}. Using dummy data..."
)
else:
for index in range(bench_args.batch_size[0]):
if isinstance(prompt_dict, str):
custom_prompts.append(prompt_dict)
elif str(index) in prompt_dict:
custom_prompts.append(prompt_dict[str(index)])
else:
rank_print(
f"Custom prompt file {bench_args.prompt_filename} does not contain prompt for batch {index}. Using dummy data..."
)

custom_inputs = [tokenizer.encode(p.strip()) for p in custom_prompts]

Expand Down
86 changes: 68 additions & 18 deletions python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,12 @@
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py

import logging
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d,
convert_to_channelwise,
per_tensor_dequantize,
requantize_with_max_scale,
)

from sglang.srt.cpu_utils import _process_weight_after_loading, cpu_has_amx_support
from sglang.srt.distributed import get_tensor_model_parallel_world_size
Expand Down Expand Up @@ -53,6 +40,19 @@
print_warning_once,
set_weight_attrs,
)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d,
convert_to_channelwise,
per_tensor_dequantize,
requantize_with_max_scale,
)

if cpu_has_amx_support():
import sgl_kernel.cpu
Expand Down Expand Up @@ -136,9 +136,8 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import

from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from vllm.attention.layer import Attention # Avoid circular import

if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
Expand All @@ -154,6 +153,40 @@ def get_scaled_act_names(self) -> List[str]:
return []


def requantize_with_max_scale_cpu(
    weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: List[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Requantize fused FP8 weight shards to share a single (max) scale, on CPU.

    CPU-friendly counterpart of vllm's ``requantize_with_max_scale``: instead of
    the CUDA custom op ``ops.scaled_fp8_quant``, each logical shard is
    dequantized and re-quantized in place with a plain clamp + cast.

    Args:
        weight: Fused FP8 weight tensor; shards are stacked along dim 0.
        weight_scale: Per-shard scales, one entry per logical width.
        logical_widths: Row count of each fused shard within ``weight``.

    Returns:
        Tuple of (max scale used for requantization, requantized weight).
        Note: ``weight`` is modified in place.
    """
    # Max scale to be used for requantization.
    max_w_scale = weight_scale.max()

    # QKV / MLP is fused in the on-disk checkpoint if any of the weight
    # scales are still set to the default, since we initialize N weight
    # scales for N shards but only load 1 weight scale from disk in that
    # case. Skip requantization then, because the weights are already
    # quantized with the single scale.
    # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
    unfused_module_in_checkpoint = (
        weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
    )

    # If unfused checkpoint, requantize each shard with the single max scale.
    if unfused_module_in_checkpoint:
        fp8_info = torch.finfo(torch.float8_e4m3fn)  # loop-invariant, hoisted
        start = 0
        for idx, logical_width in enumerate(logical_widths):
            end = start + logical_width
            weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
            # CPU equivalent of ops.scaled_fp8_quant(weight_dq, max_w_scale).
            weight[start:end, :] = torch.clamp(
                weight_dq / max_w_scale, fp8_info.min, fp8_info.max
            ).to(torch.float8_e4m3fn)
            start = end

    return max_w_scale, weight


class Fp8LinearMethod(LinearMethodBase):
"""Linear method for FP8.
Supports loading FP8 checkpoints with static weight scale and
Expand Down Expand Up @@ -366,14 +399,14 @@ def process_weights_after_loading(self, layer: Module) -> None:
if input_scale is not None:
layer.input_scale = Parameter(input_scale, requires_grad=False)

weight_scale, weight = requantize_with_max_scale(
weight_scale, weight = requantize_with_max_scale_cpu(
weight=weight,
weight_scale=weight_scale,
logical_widths=layer.logical_widths,
)

# Update layer with new values.
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight = Parameter(weight.t().contiguous(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
if self.quant_config.activation_scheme == "static":
layer.input_scale = Parameter(
Expand Down Expand Up @@ -422,6 +455,23 @@ def apply(
input_scale=None,
bias=bias,
)
if self.quant_config.activation_scheme == "static":
q_input = torch.ops.quantized_decomposed.quantize_per_tensor(
input=x,
scale=layer.input_scale,
zero_point=0,
quant_min=int(torch.finfo(torch.float8_e4m3fn).min),
quant_max=int(torch.finfo(torch.float8_e4m3fn).max),
dtype=torch.float8_e4m3fn,
)
return torch._scaled_mm(
q_input,
layer.weight,
bias=bias,
out_dtype=x.dtype,
scale_a=layer.input_scale,
scale_b=layer.weight_scale,
)

return apply_fp8_linear(
input=x,
Expand Down
Loading