Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,3 +634,44 @@ def unified_attention_with_output_fake(
dispatch_key=current_platform.dispatch_key,
tags=tag_cudagraph_unsafe,
)


def creat_attn_output_buffer(
    output_shape: torch.Tensor,
    device: torch.device,
    dtype: torch.dtype,
    need_init: bool,
) -> torch.Tensor:
    """Allocate an attention output buffer of the requested shape.

    The shape arrives as a 1-D integer tensor (custom-op schemas cannot
    carry a plain tuple) and is converted to a tuple before allocation.

    When ``need_init`` is True the buffer is zero-filled so the profile
    run never observes NaNs, which would cause numerical issues.
    See: https://github.com/vllm-project/vllm/pull/19784
    """
    target_shape = tuple(output_shape.tolist())
    # Pick the allocator once instead of duplicating the call in each branch.
    allocator = torch.zeros if need_init else torch.empty
    return allocator(target_shape, dtype=dtype, device=device)


def creat_attn_output_buffer_fake(
    output_shape: torch.Tensor,
    device: torch.device,
    dtype: torch.dtype,
    need_init: bool,
) -> torch.Tensor:
    """Fake (meta) implementation of ``creat_attn_output_buffer``.

    Only shape/dtype/device propagation matters for tracing, so the
    buffer contents are left uninitialized regardless of ``need_init``.
    """
    return torch.empty(tuple(output_shape.tolist()),
                       dtype=dtype,
                       device=device)


# Register the buffer allocator as a custom op so it stays visible to
# torch.compile / graph capture instead of being traced through.
# NOTE(review): the op name deliberately keeps the "creat" (sic) spelling
# to match the function name — renaming would break the registered schema.
# No fake-dispatch tags are passed here (unlike the attention op above);
# presumably buffer allocation is considered cudagraph-safe — confirm.
direct_register_custom_op(
    op_name="creat_attn_output_buffer",
    op_func=creat_attn_output_buffer,
    mutates_args=[],
    fake_impl=creat_attn_output_buffer_fake,
    dispatch_key=current_platform.dispatch_key,
)
19 changes: 15 additions & 4 deletions vllm/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2554,10 +2554,21 @@ def direct_register_custom_op(
import torch._custom_op.impl
schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
my_lib = target_lib or vllm_lib
my_lib.define(op_name + schema_str, tags=tags)
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl)

try:
my_lib.define(op_name + schema_str, tags=tags)
my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
if fake_impl is not None:
my_lib._register_fake(op_name, fake_impl)
except RuntimeError as e:
# Handle duplicate registration gracefully
if "Tried to register an operator" in str(e) and "multiple times" in str(e):
# The operation is already registered, which is fine
# This can happen when modules are imported multiple times
pass
else:
# Re-raise other RuntimeErrors as they indicate real problems
raise


def resolve_obj_by_qualname(qualname: str) -> Any:
Expand Down