diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 544a72052442..d5cda0f62900 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -634,3 +634,44 @@
     dispatch_key=current_platform.dispatch_key,
     tags=tag_cudagraph_unsafe,
 )
+
+
+def creat_attn_output_buffer(
+    output_shape: torch.Tensor,
+    device: torch.device,
+    dtype: torch.dtype,
+    need_init: bool,
+) -> torch.Tensor:
+    # Convert tensor to tuple for shape
+    shape = tuple(output_shape.tolist())
+    if need_init:
+        # Avoid output contains NaNs, which causes numerical issue during
+        # profile run. See: https://github.com/vllm-project/vllm/pull/19784
+        return torch.zeros(shape,
+                           dtype=dtype,
+                           device=device)
+    else:
+        return torch.empty(shape,
+                           dtype=dtype,
+                           device=device)
+
+
+def creat_attn_output_buffer_fake(
+    output_shape: torch.Tensor,
+    device: torch.device,
+    dtype: torch.dtype,
+    need_init: bool,
+) -> torch.Tensor:
+    shape = tuple(output_shape.tolist())
+    return torch.empty(shape,
+                       dtype=dtype,
+                       device=device)
+
+
+direct_register_custom_op(
+    op_name="creat_attn_output_buffer",
+    op_func=creat_attn_output_buffer,
+    mutates_args=[],
+    fake_impl=creat_attn_output_buffer_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index 022e35a399c5..49e706d202ec 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -2554,10 +2554,21 @@
     import torch._custom_op.impl
     schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
     my_lib = target_lib or vllm_lib
-    my_lib.define(op_name + schema_str, tags=tags)
-    my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
-    if fake_impl is not None:
-        my_lib._register_fake(op_name, fake_impl)
+
+    try:
+        my_lib.define(op_name + schema_str, tags=tags)
+        my_lib.impl(op_name, op_func, dispatch_key=dispatch_key)
+        if fake_impl is not None:
+            my_lib._register_fake(op_name, fake_impl)
+    except RuntimeError as e:
+        # Handle duplicate registration gracefully
+        if "Tried to register an operator" in str(e) and "multiple times" in str(e):
+            # The operation is already registered, which is fine
+            # This can happen when modules are imported multiple times
+            pass
+        else:
+            # Re-raise other RuntimeErrors as they indicate real problems
+            raise
 
 
 def resolve_obj_by_qualname(qualname: str) -> Any: