202 changes: 50 additions & 152 deletions vllm/model_executor/custom_op.py
@@ -2,40 +2,29 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


from typing import Any

import torch.nn as nn

from vllm.config import get_cached_compilation_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import HAS_TRITON

logger = init_logger(__name__)


class CustomOpBase:
class CustomOp(nn.Module):
"""
Base class for custom operators, including torch and Triton operators.
CustomOpBase mainly maintains the basic functions __new__, __init__,
register, register_oot and forward; these basic functions are not
recommended to be overridden. The others should be overridden by the
specific subclasses.
Base class for custom ops.
Dispatches the forward method to the appropriate backend.
"""

op_registry: dict[str, Any] = {}
op_registry_oot: dict[str, Any] = {}

def __new__(cls, *args, **kwargs):
try:
op_name = cls.__name__
except AttributeError:
raise TypeError(
f"Cannot instantiate '{cls.__name__}': its 'name' attribute "
f"was not set, possibly because it was not decorated with "
f"@CustomOPBase.register (or @CustomOPBase.register, "
"@CustomTritonOp.register), or it's the CustomOPBase base "
"class itself."
f"@CustomOp.register, or it's the CustomOp base class itself."
) from None

if op_name not in cls.op_registry_oot:
@@ -50,91 +39,13 @@ def __new__(cls, *args, **kwargs):
return super().__new__(op_cls_to_instantiate)

def __init__(self, enforce_enable: bool = False):
super().__init__()
self._enforce_enable = enforce_enable
self._forward_method = self.dispatch_forward()

def forward(self, *args, **kwargs):
return self._forward_method(*args, **kwargs)

def forward_native(self, *args, **kwargs):
raise NotImplementedError

def forward_cuda(self, *args, **kwargs):
raise NotImplementedError

def forward_hip(self, *args, **kwargs):
raise NotImplementedError

def forward_xpu(self, *args, **kwargs):
raise NotImplementedError

def forward_cpu(self, *args, **kwargs):
raise NotImplementedError

def forward_tpu(self, *args, **kwargs):
raise NotImplementedError

def forward_oot(self, *args, **kwargs):
raise NotImplementedError

def dispatch_forward(self):
raise NotImplementedError

@classmethod
def enabled(cls) -> bool:
raise NotImplementedError

# Decorator to register custom ops.
@classmethod
def register(cls, name: str):
def decorator(op_cls):
assert name not in cls.op_registry, f"Duplicate op name: {name}"
op_cls.name = name
cls.op_registry[name] = op_cls
return op_cls

return decorator

@classmethod
def register_oot(cls, _decorated_op_cls=None, name: str | None = None):
def decorator(op_cls):
reg_name = name if name is not None else cls.__name__
assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}"
op_cls.name = reg_name
cls.op_registry_oot[reg_name] = op_cls
return op_cls

if _decorated_op_cls is None:
# Called with parentheses: @CustomOpBase.register_oot()
# or @CustomOpBase.register_oot(name="...")
# (CustomOpBase could be replaced by CustomOp or CustomTritonOp.)
# So, _decorated_op_cls is None.
# We return the actual decorator function.
return decorator
elif isinstance(_decorated_op_cls, type): # Check if it's a class
# Called without parentheses: @CustomOpBase.register_oot
# (CustomOpBase could be replaced by CustomOp or CustomTritonOp.)
# The first argument is the class itself.
# We call the 'decorator' function immediately with the class.
return decorator(_decorated_op_cls)
else:
# Handle other unexpected cases if necessary
raise TypeError("Decorator can only be applied to classes.")


class CustomOp(nn.Module, CustomOpBase):
"""
Base class for torch custom ops.
Implements and dispatches the forward method to the appropriate backend.
"""

def __init__(self, enforce_enable: bool = False):
nn.Module.__init__(self)
CustomOpBase.__init__(self, enforce_enable=enforce_enable)

def forward(self, *args, **kwargs):
return self._forward_method(*args, **kwargs)

def forward_native(self, *args, **kwargs):
"""PyTorch-native implementation of the forward method.
This method is optional. If implemented, it can be used with compilers
@@ -235,67 +146,54 @@ def default_on() -> bool:

return not count_none > 0 or count_all > 0

# Dictionary of all custom ops (classes, indexed by registered name).
# To check if an op with a name is enabled, call .enabled() on the class.
# Examples:
# - MyOp.enabled()
# - op_registry["my_op"].enabled()
op_registry: dict[str, type["CustomOp"]] = {}
op_registry_oot: dict[str, type["CustomOp"]] = {}

class CustomTritonOp(CustomOpBase):
"""
Base class for Triton custom ops.
Implements and dispatches the forward method to the appropriate backend.
"""

def __init__(self, enforce_enable: bool = False):
super().__init__(enforce_enable=enforce_enable)

def forward(self, *args, **kwargs):
return self._forward_method(*args, **kwargs)

# Decorator to register custom ops.
@classmethod
def enabled(cls) -> bool:
return HAS_TRITON

def forward_cuda(self, *args, **kwargs):
raise NotImplementedError

def forward_hip(self, *args, **kwargs):
# By default, we assume that HIP ops are compatible with CUDA ops.
return self.forward_cuda(*args, **kwargs)

def forward_xpu(self, *args, **kwargs):
# By default, we assume that XPU ops are compatible with CUDA ops.
# NOTE: This is a placeholder for future extensions.
return self.forward_cuda(*args, **kwargs)

def forward_cpu(self, *args, **kwargs):
# By default, we assume that CPU ops are compatible with CUDA ops.
# NOTE: This is a placeholder for future extensions.
return self.forward_cuda(*args, **kwargs)

def forward_tpu(self, *args, **kwargs):
# By default, we assume that TPU ops are compatible with CUDA ops.
# NOTE: This is a placeholder for future extensions.
return self.forward_cuda(*args, **kwargs)
def register(cls, name: str):
def decorator(op_cls):
assert name not in cls.op_registry, f"Duplicate op name: {name}"
op_cls.name = name
cls.op_registry[name] = op_cls
return op_cls

def forward_oot(self, *args, **kwargs):
# By default, we assume that out-of-tree (OOT) ops are compatible with CUDA ops.
# NOTE: This is a placeholder for future extensions.
return self.forward_cuda(*args, **kwargs)
return decorator

def dispatch_forward(self):
enabled = self._enforce_enable or self.enabled()
if not enabled:
raise RuntimeError(
f"TritonOp {self.__class__.__name__} is disabled or "
"Triton not available"
)
# Decorator to register out-of-tree (OOT) custom ops.
# For OOT custom ops:
# if an in-tree layer class is registered with an OOT custom-op layer,
# the OOT layer will be used instead.
# Example:
# - @UnquantizedFusedMoEMethod.register_oot
#   class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod)
# or
# - @CustomOp.register_oot(name="UnquantizedFusedMoEMethod")
@classmethod
def register_oot(cls, _decorated_op_cls=None, name: str | None = None):
def decorator(op_cls):
reg_name = name if name is not None else cls.__name__
assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}"
op_cls.name = reg_name
cls.op_registry_oot[reg_name] = op_cls
return op_cls

if current_platform.is_rocm():
return self.forward_hip
elif current_platform.is_cpu():
return self.forward_cpu
elif current_platform.is_tpu():
return self.forward_tpu
elif current_platform.is_xpu():
return self.forward_xpu
elif current_platform.is_out_of_tree():
return self.forward_oot
if _decorated_op_cls is None:
# Called with parentheses: @CustomOp.register_oot()
# or @CustomOp.register_oot(name="...")
# So, _decorated_op_cls is None.
# We return the actual decorator function.
return decorator
elif isinstance(_decorated_op_cls, type): # Check if it's a class
# Called without parentheses: @CustomOp.register_oot
# The first argument is the class itself.
# We call the 'decorator' function immediately with the class.
return decorator(_decorated_op_cls)
else:
return self.forward_cuda
# Handle other unexpected cases if necessary
raise TypeError("Decorator can only be applied to classes.")
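For readers skimming this diff, a minimal usage sketch of the two registration paths may help. The op name "my_rms_norm" and both classes below are hypothetical; the sketch only assumes the register / register_oot / __new__ behaviour visible in this file.

```python
from vllm.model_executor.custom_op import CustomOp


# In-tree op: registered under a string name so it can be looked up and
# toggled via CustomOp.op_registry["my_rms_norm"].enabled().
@CustomOp.register("my_rms_norm")
class MyRMSNorm(CustomOp):
    def forward_native(self, x):
        return x  # reference implementation

    def forward_cuda(self, x):
        return x  # CUDA-specific implementation


# Out-of-tree override (e.g. from a platform plugin): registering under the
# in-tree class name ("MyRMSNorm") means CustomOp.__new__ will instantiate
# this subclass whenever MyRMSNorm() is constructed.
@MyRMSNorm.register_oot
class MyPlatformRMSNorm(MyRMSNorm):
    def forward_oot(self, x):
        return x  # platform-specific implementation
```

The same override could also be registered explicitly with `@CustomOp.register_oot(name="MyRMSNorm")`, matching the comment above.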
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/fla/ops/chunk.py
@@ -12,7 +12,7 @@
import torch
from einops import rearrange

from vllm.model_executor.custom_op import CustomTritonOp
from vllm.model_executor.triton_op import TritonOpBase

from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_fwd_o
@@ -242,7 +242,7 @@ def chunk_gated_delta_rule(
return o, final_state


@CustomTritonOp.register("chunk_gated_delta_rule")
class ChunkGatedDeltaRule(CustomTritonOp):
@TritonOpBase.register("chunk_gated_delta_rule")
class ChunkGatedDeltaRule(TritonOpBase):
def forward_cuda(self, *args, **kwargs):
return chunk_gated_delta_rule(*args, **kwargs)
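`TritonOpBase` now lives in `vllm/model_executor/triton_op.py`, which is not part of this diff, so the following is only a sketch of how a registered op such as `ChunkGatedDeltaRule` would be used; it assumes the base class keeps the dispatch behaviour of the old `CustomTritonOp` removed above.

```python
from vllm.model_executor.layers.fla.ops.chunk import ChunkGatedDeltaRule

# Construction resolves a _forward_method via dispatch_forward(); on
# CUDA/ROCm that is forward_cuda(), which wraps chunk_gated_delta_rule().
# If Triton is unavailable and the op is not force-enabled, construction
# is expected to fail rather than silently fall back.
chunk_op = ChunkGatedDeltaRule()

# The op takes the same arguments as the module-level function it wraps:
# o, final_state = chunk_op.forward(q, k, v, g, beta, ...)
```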
6 changes: 3 additions & 3 deletions vllm/model_executor/layers/fla/ops/fused_recurrent.py
@@ -10,7 +10,7 @@

import torch

from vllm.model_executor.custom_op import CustomTritonOp
from vllm.model_executor.triton_op import TritonOpBase
from vllm.triton_utils import tl, triton

from .op import exp
@@ -391,7 +391,7 @@ def fused_recurrent_gated_delta_rule(
return o, final_state


@CustomTritonOp.register("fused_recurrent_gated_delta_rule")
class FusedRecurrentGatedDeltaRule(CustomTritonOp):
@TritonOpBase.register("fused_recurrent_gated_delta_rule")
class FusedRecurrentGatedDeltaRule(TritonOpBase):
def forward_cuda(self, *args, **kwargs):
return fused_recurrent_gated_delta_rule(*args, **kwargs)
10 changes: 5 additions & 5 deletions vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -9,7 +9,7 @@
import torch

from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.custom_op import CustomTritonOp
from vllm.model_executor.triton_op import TritonOpBase
from vllm.triton_utils import tl, triton


@@ -746,8 +746,8 @@ def grid(META):
return out.to(original_x_dtype)


@CustomTritonOp.register("causal_conv1d")
class CausalConv1d(CustomTritonOp):
@TritonOpBase.register("causal_conv1d")
class CausalConv1d(TritonOpBase):
def forward_cuda(self, *args, **kwargs):
return causal_conv1d_fn(*args, **kwargs)

@@ -1247,7 +1247,7 @@ def grid(META):
return out.to(original_x_dtype)


@CustomTritonOp.register("causal_conv1d_update")
class CausalConv1dUpdate(CustomTritonOp):
@TritonOpBase.register("causal_conv1d_update")
class CausalConv1dUpdate(TritonOpBase):
def forward_cuda(self, *args, **kwargs):
return causal_conv1d_update(*args, **kwargs)
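If `TritonOpBase` also carries over the `register_oot` hook (it did when these ops derived from `CustomOpBase`), a platform plugin could override one of these kernels without touching in-tree code. A hypothetical sketch; `MyPluginCausalConv1d` and the plugin kernel are made up:

```python
from vllm.model_executor.layers.mamba.ops.causal_conv1d import CausalConv1d


# Hypothetical out-of-tree override: registering under the in-tree class
# name means __new__ instantiates this subclass whenever CausalConv1d() is
# constructed, and dispatch routes to forward_oot on out-of-tree platforms.
@CausalConv1d.register_oot
class MyPluginCausalConv1d(CausalConv1d):
    def forward_oot(self, *args, **kwargs):
        # Call into the plugin's own conv1d kernel here.
        raise NotImplementedError("plugin kernel goes here")
```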
6 changes: 3 additions & 3 deletions vllm/model_executor/models/qwen3_next.py
@@ -30,7 +30,6 @@
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomTritonOp
from vllm.model_executor.layers.fla.ops import (
ChunkGatedDeltaRule,
FusedRecurrentGatedDeltaRule,
@@ -70,6 +69,7 @@
)
from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.model_executor.triton_op import TritonOpBase
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
@@ -1406,7 +1406,7 @@ def fused_gdn_gating(
return g, beta_output


@CustomTritonOp.register("fused_gdn_gating")
class FusedGDNGating(CustomTritonOp):
@TritonOpBase.register("fused_gdn_gating")
class FusedGDNGating(TritonOpBase):
def forward_cuda(self, *args, **kwargs):
return fused_gdn_gating(*args, **kwargs)