From 0eb6808e9f295daad35c0ecc90a989bb55b98ee1 Mon Sep 17 00:00:00 2001 From: gcanlin Date: Wed, 7 Jan 2026 14:37:53 +0000 Subject: [PATCH] Separate TritonOpBase Signed-off-by: gcanlin --- vllm/model_executor/custom_op.py | 202 +++++------------- vllm/model_executor/layers/fla/ops/chunk.py | 6 +- .../layers/fla/ops/fused_recurrent.py | 6 +- .../layers/mamba/ops/causal_conv1d.py | 10 +- vllm/model_executor/models/qwen3_next.py | 6 +- vllm/model_executor/triton_op.py | 136 ++++++++++++ 6 files changed, 200 insertions(+), 166 deletions(-) create mode 100644 vllm/model_executor/triton_op.py diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index 81363906b0dd..66250f816f45 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -2,30 +2,21 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Any - import torch.nn as nn from vllm.config import get_cached_compilation_config from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.triton_utils import HAS_TRITON logger = init_logger(__name__) -class CustomOpBase: +class CustomOp(nn.Module): """ - Base class for custom operators, including torch and triton operators. - CustomOpBase mainly maintains the basic funtions, including the __new__, - __init__, registry, registry_oot and forward, these 4 basic functions are - not recommanded to be overrided. And the others should be override by the - specific subclasses. + Base class for custom ops. + Dispatches the forward method to the appropriate backend. 
""" - op_registry: dict[str, Any] = {} - op_registry_oot: dict[str, Any] = {} - def __new__(cls, *args, **kwargs): try: op_name = cls.__name__ @@ -33,9 +24,7 @@ def __new__(cls, *args, **kwargs): raise TypeError( f"Cannot instantiate '{cls.__name__}': its 'name' attribute " f"was not set, possibly because it was not decorated with " - f"@CustomOPBase.register (or @CustomOPBase.register, " - "@CustomTritonOp.register), or it's the CustomOPBase base " - "class itself." + f"@CustomOp.register, or it's the CustomOp base class itself." ) from None if op_name not in cls.op_registry_oot: @@ -50,91 +39,13 @@ def __new__(cls, *args, **kwargs): return super().__new__(op_cls_to_instantiate) def __init__(self, enforce_enable: bool = False): + super().__init__() self._enforce_enable = enforce_enable self._forward_method = self.dispatch_forward() def forward(self, *args, **kwargs): return self._forward_method(*args, **kwargs) - def forward_native(self, *args, **kwargs): - raise NotImplementedError - - def forward_cuda(self, *args, **kwargs): - raise NotImplementedError - - def forward_hip(self, *args, **kwargs): - raise NotImplementedError - - def forward_xpu(self, *args, **kwargs): - raise NotImplementedError - - def forward_cpu(self, *args, **kwargs): - raise NotImplementedError - - def forward_tpu(self, *args, **kwargs): - raise NotImplementedError - - def forward_oot(self, *args, **kwargs): - raise NotImplementedError - - def dispatch_forward(self): - raise NotImplementedError - - @classmethod - def enabled(cls) -> bool: - raise NotImplementedError - - # Decorator to register custom ops. 
- @classmethod - def register(cls, name: str): - def decorator(op_cls): - assert name not in cls.op_registry, f"Duplicate op name: {name}" - op_cls.name = name - cls.op_registry[name] = op_cls - return op_cls - - return decorator - - @classmethod - def register_oot(cls, _decorated_op_cls=None, name: str | None = None): - def decorator(op_cls): - reg_name = name if name is not None else cls.__name__ - assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}" - op_cls.name = reg_name - cls.op_registry_oot[reg_name] = op_cls - return op_cls - - if _decorated_op_cls is None: - # Called with parentheses: @CustomOPBase.register_oot() - # or @CustomOPBase.register_oot(name="...") - # CustomOPBase could be replaced by CustomOP and CustomTritonOp - # So, _decorated_op_cls is None. - # We return the actual decorator function. - return decorator - elif isinstance(_decorated_op_cls, type): # Check if it's a class - # Called without parentheses: @CustomOPBase.register_oot - # CustomOPBase could be replaced by CustomOP and CustomTritonOp - # The first argument is the class itself. - # We call the 'decorator' function immediately with the class. - return decorator(_decorated_op_cls) - else: - # Handle other unexpected cases if necessary - raise TypeError("Decorator can only be applied to classes.") - - -class CustomOp(nn.Module, CustomOpBase): - """ - Base class for torch custom ops. - Impletments and dispatches the forward method to the appropriate backend. - """ - - def __init__(self, enforce_enable: bool = False): - nn.Module.__init__(self) - CustomOpBase.__init__(self, enforce_enable=enforce_enable) - - def forward(self, *args, **kwargs): - return self._forward_method(*args, **kwargs) - def forward_native(self, *args, **kwargs): """PyTorch-native implementation of the forward method. This method is optional. 
If implemented, it can be used with compilers @@ -235,67 +146,54 @@ def default_on() -> bool: return not count_none > 0 or count_all > 0 + # Dictionary of all custom ops (classes, indexed by registered name). + # To check if an op with a name is enabled, call .enabled() on the class. + # Examples: + # - MyOp.enabled() + # - op_registry["my_op"].enabled() + op_registry: dict[str, type["CustomOp"]] = {} + op_registry_oot: dict[str, type["CustomOp"]] = {} -class CustomTritonOp(CustomOpBase): - """ - Base class for triton custom ops. - Impletments and dispatches the forward method to the appropriate backend. - """ - - def __init__(self, enforce_enable: bool = False): - super().__init__(enforce_enable=enforce_enable) - - def forward(self, *args, **kwargs): - return self._forward_method(*args, **kwargs) - + # Decorator to register custom ops. @classmethod - def enabled(cls) -> bool: - return HAS_TRITON - - def forward_cuda(self, *args, **kwargs): - raise NotImplementedError - - def forward_hip(self, *args, **kwargs): - # By default, we assume that HIP ops are compatible with CUDA ops. - return self.forward_cuda(*args, **kwargs) - - def forward_xpu(self, *args, **kwargs): - # By default, we assume that XPU ops are compatible with CUDA ops. - # NOTE: This is a placeholder for future extensions. - return self.forward_cuda(*args, **kwargs) - - def forward_cpu(self, *args, **kwargs): - # By default, we assume that CPU ops are compatible with CUDA ops. - # NOTE: This is a placeholder for future extensions. - return self.forward_cuda(*args, **kwargs) - - def forward_tpu(self, *args, **kwargs): - # By default, we assume that CPU ops are compatible with CUDA ops. - # NOTE: This is a placeholder for future extensions. 
- return self.forward_cuda(*args, **kwargs) + def register(cls, name: str): + def decorator(op_cls): + assert name not in cls.op_registry, f"Duplicate op name: {name}" + op_cls.name = name + cls.op_registry[name] = op_cls + return op_cls - def forward_oot(self, *args, **kwargs): - # By default, we assume that CPU ops are compatible with CUDA ops. - # NOTE: This is a placeholder for future extensions. - return self.forward_cuda(*args, **kwargs) + return decorator - def dispatch_forward(self): - enabled = self._enforce_enable or self.enabled() - if not enabled: - raise RuntimeError( - f"TritonOp {self.__class__.__name__} is disabled or " - "Triton not available" - ) + # Decorator to register out-of-tree(oot) custom ops. + # For OOT custom ops: + # if in-tree layer class is registered with an oot_custom_op layer, + # the oot_custom_op layer will be used instead. + # Example: + # - @UnquantizedFusedMoEMethod.register_oot + # class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod) + # or + # - @CustomOP.register_oot(name="UnquantizedFusedMoEMethod") + @classmethod + def register_oot(cls, _decorated_op_cls=None, name: str | None = None): + def decorator(op_cls): + reg_name = name if name is not None else cls.__name__ + assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}" + op_cls.name = reg_name + cls.op_registry_oot[reg_name] = op_cls + return op_cls - if current_platform.is_rocm(): - return self.forward_hip - elif current_platform.is_cpu(): - return self.forward_cpu - elif current_platform.is_tpu(): - return self.forward_tpu - elif current_platform.is_xpu(): - return self.forward_xpu - elif current_platform.is_out_of_tree(): - return self.forward_oot + if _decorated_op_cls is None: + # Called with parentheses: @CustomOP.register_oot() + # or @CustomOP.register_oot(name="...") + # So, _decorated_op_cls is None. + # We return the actual decorator function. 
+ return decorator + elif isinstance(_decorated_op_cls, type): # Check if it's a class + # Called without parentheses: @CustomOP.register_oot + # The first argument is the class itself. + # We call the 'decorator' function immediately with the class. + return decorator(_decorated_op_cls) else: - return self.forward_cuda + # Handle other unexpected cases if necessary + raise TypeError("Decorator can only be applied to classes.") diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py index 321737092194..51bad0b1c5b9 100644 --- a/vllm/model_executor/layers/fla/ops/chunk.py +++ b/vllm/model_executor/layers/fla/ops/chunk.py @@ -12,7 +12,7 @@ import torch from einops import rearrange -from vllm.model_executor.custom_op import CustomTritonOp +from vllm.model_executor.triton_op import TritonOpBase from .chunk_delta_h import chunk_gated_delta_rule_fwd_h from .chunk_o import chunk_fwd_o @@ -242,7 +242,7 @@ def chunk_gated_delta_rule( return o, final_state -@CustomTritonOp.register("chunk_gated_delta_rule") -class ChunkGatedDeltaRule(CustomTritonOp): +@TritonOpBase.register("chunk_gated_delta_rule") +class ChunkGatedDeltaRule(TritonOpBase): def forward_cuda(self, *args, **kwargs): return chunk_gated_delta_rule(*args, **kwargs) diff --git a/vllm/model_executor/layers/fla/ops/fused_recurrent.py b/vllm/model_executor/layers/fla/ops/fused_recurrent.py index b7e84ba81339..a2f50ed0bdc4 100644 --- a/vllm/model_executor/layers/fla/ops/fused_recurrent.py +++ b/vllm/model_executor/layers/fla/ops/fused_recurrent.py @@ -10,7 +10,7 @@ import torch -from vllm.model_executor.custom_op import CustomTritonOp +from vllm.model_executor.triton_op import TritonOpBase from vllm.triton_utils import tl, triton from .op import exp @@ -391,7 +391,7 @@ def fused_recurrent_gated_delta_rule( return o, final_state -@CustomTritonOp.register("fused_recurrent_gated_delta_rule") -class FusedRecurrentGatedDeltaRule(CustomTritonOp): 
+@TritonOpBase.register("fused_recurrent_gated_delta_rule") +class FusedRecurrentGatedDeltaRule(TritonOpBase): def forward_cuda(self, *args, **kwargs): return fused_recurrent_gated_delta_rule(*args, **kwargs) diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py index 1f939664de6d..ba55101e3f9b 100644 --- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -9,7 +9,7 @@ import torch from vllm.attention.backends.utils import PAD_SLOT_ID -from vllm.model_executor.custom_op import CustomTritonOp +from vllm.model_executor.triton_op import TritonOpBase from vllm.triton_utils import tl, triton @@ -746,8 +746,8 @@ def grid(META): return out.to(original_x_dtype) -@CustomTritonOp.register("causal_conv1d") -class CausalConv1d(CustomTritonOp): +@TritonOpBase.register("causal_conv1d") +class CausalConv1d(TritonOpBase): def forward_cuda(self, *args, **kwargs): return causal_conv1d_fn(*args, **kwargs) @@ -1247,7 +1247,7 @@ def grid(META): return out.to(original_x_dtype) -@CustomTritonOp.register("causal_conv1d_update") -class CausalConv1dUpdate(CustomTritonOp): +@TritonOpBase.register("causal_conv1d_update") +class CausalConv1dUpdate(TritonOpBase): def forward_cuda(self, *args, **kwargs): return causal_conv1d_update(*args, **kwargs) diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index 6a45b52705cf..97d351f7a967 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -30,7 +30,6 @@ ) from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import init_logger -from vllm.model_executor.custom_op import CustomTritonOp from vllm.model_executor.layers.fla.ops import ( ChunkGatedDeltaRule, FusedRecurrentGatedDeltaRule, @@ -70,6 +69,7 @@ ) from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP from 
vllm.model_executor.models.utils import sequence_parallel_chunk +from vllm.model_executor.triton_op import TritonOpBase from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -1406,7 +1406,7 @@ def fused_gdn_gating( return g, beta_output -@CustomTritonOp.register("fused_gdn_gating") -class FusedGDNGating(CustomTritonOp): +@TritonOpBase.register("fused_gdn_gating") +class FusedGDNGating(TritonOpBase): def forward_cuda(self, *args, **kwargs): return fused_gdn_gating(*args, **kwargs) diff --git a/vllm/model_executor/triton_op.py b/vllm/model_executor/triton_op.py new file mode 100644 index 000000000000..fb5b529ba212 --- /dev/null +++ b/vllm/model_executor/triton_op.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +class TritonOpBase: + """ + Base class for Triton custom ops. + Dispatches the forward method to the appropriate hardware backend. + Similar to CustomOp but specifically for Triton-based operations. + """ + + def __new__(cls, *args, **kwargs): + try: + op_name = cls.__name__ + except AttributeError: + raise TypeError( + f"Cannot instantiate '{cls.__name__}': its 'name' attribute " + f"was not set, possibly because it was not decorated with " + f"@TritonOpBase.register, or it's the TritonOpBase base class " + f"itself." 
+ ) from None + + if op_name not in cls.op_registry_oot: + op_cls_to_instantiate = cls + else: + op_cls_to_instantiate = cls.op_registry_oot[op_name] + logger.debug( + "Instantiating triton op: %s using %s", + op_name, + str(op_cls_to_instantiate), + ) + return super().__new__(op_cls_to_instantiate) + + def __init__(self): + self._forward_method = self.dispatch_forward() + + def forward(self, *args, **kwargs): + return self._forward_method(*args, **kwargs) + + def forward_cuda(self, *args, **kwargs): + raise NotImplementedError + + def forward_hip(self, *args, **kwargs): + # By default, we assume that HIP ops are compatible with CUDA ops. + return self.forward_cuda(*args, **kwargs) + + def forward_xpu(self, *args, **kwargs): + # By default, we assume that XPU ops are compatible with CUDA ops. + # NOTE: This is a placeholder for future extensions. + return self.forward_cuda(*args, **kwargs) + + def forward_cpu(self, *args, **kwargs): + # By default, we assume that CPU ops are compatible with CUDA ops. + # NOTE: This is a placeholder for future extensions. + return self.forward_cuda(*args, **kwargs) + + def forward_tpu(self, *args, **kwargs): + # By default, we assume that TPU ops are compatible with CUDA ops. + # NOTE: This is a placeholder for future extensions. + return self.forward_cuda(*args, **kwargs) + + def forward_oot(self, *args, **kwargs): + # By default, we assume that OOT ops are compatible with CUDA ops. + # NOTE: This is a placeholder for future extensions. 
+ return self.forward_cuda(*args, **kwargs) + + def dispatch_forward(self): + enabled = self.enabled() + if not enabled: + raise RuntimeError("Triton is not available.") + + if current_platform.is_rocm(): + return self.forward_hip + elif current_platform.is_cpu(): + return self.forward_cpu + elif current_platform.is_tpu(): + return self.forward_tpu + elif current_platform.is_xpu(): + return self.forward_xpu + elif current_platform.is_out_of_tree(): + return self.forward_oot + else: + return self.forward_cuda + + @classmethod + def enabled(cls) -> bool: + """Returns True if Triton is available.""" + # Import here to avoid circular imports + from vllm.triton_utils import HAS_TRITON + + return HAS_TRITON + + # Dictionary of all triton ops (classes, indexed by registered name). + op_registry: dict[str, type["TritonOpBase"]] = {} + op_registry_oot: dict[str, type["TritonOpBase"]] = {} + + # Decorator to register triton ops. + @classmethod + def register(cls, name: str): + def decorator(op_cls): + assert name not in cls.op_registry, f"Duplicate op name: {name}" + op_cls.name = name + cls.op_registry[name] = op_cls + return op_cls + + return decorator + + # Decorator to register out-of-tree(oot) triton ops. 
+ # Example: + # - @TritonOpBase.register_oot + # class OOTTritonOp(TritonOpBase) + # or + # - @TritonOpBase.register_oot(name="TritonOp") + @classmethod + def register_oot(cls, _decorated_op_cls=None, name: str | None = None): + def decorator(op_cls): + reg_name = name if name is not None else cls.__name__ + assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}" + op_cls.name = reg_name + cls.op_registry_oot[reg_name] = op_cls + return op_cls + + if _decorated_op_cls is None: + # Called with parentheses: @TritonOpBase.register_oot() + # or @TritonOpBase.register_oot(name="...") + return decorator + elif isinstance(_decorated_op_cls, type): + # Called without parentheses: @TritonOpBase.register_oot + return decorator(_decorated_op_cls) + else: + raise TypeError("Decorator can only be applied to classes.")