From cc416f72473e832ea02a69c442f6a2be040f3368 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Thu, 24 Jul 2025 22:22:42 +0800 Subject: [PATCH 1/9] [feat] block size selection --- .../backend/cuda/bmm/__init__.py | 2 +- .../tritontemplate/backend/cuda/bmm/bmm.py | 5 +++- .../backend/cuda/gemm/__init__.py | 2 +- .../tritontemplate/backend/cuda/gemm/gemm.py | 9 +++---- .../backend/cuda/layernorm/__init__.py | 2 +- .../backend/cuda/layernorm/layernorm.py | 5 +++- .../backend/cuda/softmax/__init__.py | 2 +- .../backend/cuda/softmax/softmax.py | 5 +++- .../backend/cuda/transpose/__init__.py | 4 +-- .../backend/cuda/transpose/transpose_0213.py | 5 +++- .../backend/cuda/transpose/transpose_10.py | 3 +++ .../backend/cuda/utils/__init__.py | 1 + .../python/tritontemplate/compiler/base.py | 26 +++++++++++++++++- .../python/tritontemplate/compiler/kernel.py | 14 +++------- .../tritontemplate/compiler/ops/bmm/bmm.py | 12 ++++++--- .../tritontemplate/compiler/ops/gemm/gemm.py | 13 +++++---- .../compiler/ops/layernorm/layernorm.py | 14 +++++++--- .../compiler/ops/softmax/softmax.py | 10 ++++--- .../compiler/ops/transpose/transpose.py | 18 ++++++++----- .../python/tritontemplate/compiler/utils.py | 27 ++++++++++--------- .../tritontemplate/testing/cuda/test_gemm.py | 2 ++ .../testing/cuda/test_layernorm.py | 8 +++--- 22 files changed, 123 insertions(+), 66 deletions(-) diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py index 5f8c6587a..0302062fd 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py @@ -1 +1 @@ -from tritontemplate.backend.cuda.bmm.bmm import bmm,bmm_bias, gen_grid_bmm \ No newline at end of file +from tritontemplate.backend.cuda.bmm.bmm import bmm,bmm_bias, gen_grid_bmm, gen_smem_size_bmm \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py index 38a1813da..b8e59ca9e 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py @@ -1,12 +1,15 @@ import triton import triton.language as tl -def gen_grid_bmm(batch_size,M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): +def gen_grid_bmm(batch_size:int,M:int, N:int, BLOCK_SIZE_M:int, BLOCK_SIZE_N:int): """ Generates the grid for a Batch GEMM kernel. 
""" return (batch_size,triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), 1) +def gen_smem_size_bmm(BLOCK_SIZE_M:int, BLOCK_SIZE_K:int, BLOCK_SIZE_N:int,num_stages:int): + return (BLOCK_SIZE_N*BLOCK_SIZE_K+BLOCK_SIZE_K*BLOCK_SIZE_M)*num_stages + @triton.jit def bmm_bias( # Pointers to matrices diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py index ddc868bd3..62314bc6e 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py @@ -1 +1 @@ -from tritontemplate.backend.cuda.gemm.gemm import gemm,gemm_bias,gen_grid_gemm +from tritontemplate.backend.cuda.gemm.gemm import gemm,gemm_bias,gen_grid_gemm,gen_smem_size_gemm diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm.py index f9b3d1b52..f2fd76255 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/gemm.py @@ -2,18 +2,15 @@ import triton.language as tl from tritontemplate.backend.cuda.utils.activation import * -def gen_grid_gemm(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): +def gen_grid_gemm(M:int, N:int, BLOCK_SIZE_M:int, BLOCK_SIZE_N:int): """ Generates the grid for a GEMM kernel. """ return ( triton.cdiv(M, BLOCK_SIZE_M)*triton.cdiv(N, BLOCK_SIZE_N),1,1) -# smem_size=val*dtype_size*num_stage -smem_demand_per_stage ={ - 'gemm_bias': 128*128*2, - 'gemm': 128*128*2, -} +def gen_smem_size_gemm(BLOCK_SIZE_M:int, BLOCK_SIZE_K:int, BLOCK_SIZE_N:int,num_stages:int): + return (BLOCK_SIZE_N*BLOCK_SIZE_K+BLOCK_SIZE_K*BLOCK_SIZE_M)*num_stages @triton.jit def gemm_bias( diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/__init__.py index 9bcd2adac..c745404cb 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/__init__.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/__init__.py @@ -1 +1 @@ -from tritontemplate.backend.cuda.layernorm.layernorm import layernorm,layernorm_weight_bias, gen_grid_layernorm \ No newline at end of file +from tritontemplate.backend.cuda.layernorm.layernorm import layernorm,layernorm_weight_bias, gen_grid_layernorm,gen_smem_size_layernorm \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/layernorm.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/layernorm.py index 9c7d8541b..1ee67f534 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/layernorm.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/layernorm.py @@ -1,10 +1,13 @@ import triton import triton.language as tl -def gen_grid_layernorm(M,BLOCK_SIZE_M): +def gen_grid_layernorm(M:int,BLOCK_SIZE_M:int): grid = (triton.cdiv(M, BLOCK_SIZE_M),1,1) return grid +def gen_smem_size_layernorm(BLOCK_SIZE_M:int, BLOCK_SIZE_N:int,num_stages:int): + return (BLOCK_SIZE_N*BLOCK_SIZE_M)*num_stages + @triton.jit def layernorm(x_ptr,y_ptr,M:tl.constexpr,N:tl.constexpr,stride_x0:tl.constexpr,stride_x1:tl.constexpr,stride_y0:tl.constexpr,stride_y1:tl.constexpr,BLOCK_SIZE_M:tl.constexpr,BLOCK_SIZE_N:tl.constexpr,eps:tl.constexpr=1e-5): ''' diff --git 
a/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/__init__.py index 2fc667265..a636a6c92 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/__init__.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/__init__.py @@ -1 +1 @@ -from tritontemplate.backend.cuda.softmax.softmax import softmax,online_softmax,gen_grid_softmax \ No newline at end of file +from tritontemplate.backend.cuda.softmax.softmax import softmax,online_softmax,gen_grid_softmax,gen_smem_size_softmax \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/softmax.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/softmax.py index 393cbe217..ee88c33ed 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/softmax.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/softmax.py @@ -1,13 +1,16 @@ import triton import triton.language as tl -def gen_grid_softmax(M, BLOCK_SIZE_M): +def gen_grid_softmax(M:int, BLOCK_SIZE_M:int): """ Generates the grid for a softmax kernel. """ return ( triton.cdiv(M, BLOCK_SIZE_M),1,1) +def gen_smem_size_softmax(BLOCK_SIZE_M:int, BLOCK_SIZE_N:int,num_stages:int): + return (BLOCK_SIZE_N*BLOCK_SIZE_M)*num_stages + @triton.jit def softmax(x_ptr, y_ptr, M: tl.constexpr, N: tl.constexpr,stride_x0:tl.constexpr,stride_x1:tl.constexpr, stride_y0:tl.constexpr, stride_y1:tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): ''' diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/__init__.py index da1c07b90..febf65ed2 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/__init__.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/__init__.py @@ -1,2 +1,2 @@ -from tritontemplate.backend.cuda.transpose.transpose_10 import transpose_10,gen_grid_transpose_10 -from tritontemplate.backend.cuda.transpose.transpose_0213 import transpose_0213,gen_grid_transpose_0213 +from tritontemplate.backend.cuda.transpose.transpose_10 import transpose_10,gen_grid_transpose_10, gen_smem_size_transpose_10 +from tritontemplate.backend.cuda.transpose.transpose_0213 import transpose_0213,gen_grid_transpose_0213, gen_smem_size_transpose_0213 diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_0213.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_0213.py index 87cfe56b5..cc6b0e28c 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_0213.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_0213.py @@ -1,12 +1,15 @@ import triton import triton.language as tl -def gen_grid_transpose_0213(D0, D1, D2, D3, BLOCK_SIZE_D1, BLOCK_SIZE_D2): +def gen_grid_transpose_0213(D0:int, D1:int, D2:int, D3:int, BLOCK_SIZE_D1:int, BLOCK_SIZE_D2:int): """ Generates the grid for a transpose kernel. 
""" return (D0 * D3, triton.cdiv(D2, BLOCK_SIZE_D2), triton.cdiv(D1, BLOCK_SIZE_D1)) +def gen_smem_size_transpose_0213(BLOCK_SIZE_D1:int, BLOCK_SIZE_D2:int,num_stages:int): + return (BLOCK_SIZE_D1*BLOCK_SIZE_D2)*num_stages + #TODO: rewrite for contiguous D3 reading and storing @triton.jit def transpose_0213(x, y, diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_10.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_10.py index 44187666b..06f4bb00a 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_10.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/transpose/transpose_10.py @@ -10,6 +10,9 @@ def gen_grid_transpose_10(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N): triton.cdiv(N, BLOCK_SIZE_N), 1) +def gen_smem_size_transpose_10(BLOCK_SIZE_M:int, BLOCK_SIZE_N:int,num_stages:int): + return (BLOCK_SIZE_N*BLOCK_SIZE_M)*num_stages + @triton.jit def transpose_10( x_ptr, y_ptr, diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/utils/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/utils/__init__.py index e69de29bb..ef63624f9 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/utils/__init__.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/utils/__init__.py @@ -0,0 +1 @@ +from tritontemplate.backend.cuda.utils.utils import shape2stride \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/compiler/base.py b/external/TritonTemplate/python/tritontemplate/compiler/base.py index 8a69f30df..3f0a667b1 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/base.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/base.py @@ -1,6 +1,7 @@ from abc import ABC,abstractmethod from pprint import pformat -from typing import Any, Dict, Iterable, List, Optional, Set, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Union, Callable +import inspect from tritontemplate.compiler.dtype import dtype_str_to_triton_signature @@ -141,3 +142,26 @@ def _block_size(x): return 64 else: return 128 + + @staticmethod + def _shrink_shared_mem(func_gen_smem_size:Callable,const_metadata:Dict, dev_smem_size:int,num_stages:int): + + sig = dict(inspect.signature(func_gen_smem_size).parameters) + keys = [key for key in sig if key != "num_stages"] + sig['num_stages'] = num_stages + for key in keys: + sig[key] = const_metadata[key] + + it = 0 + len_keys = len(keys) + while dev_smem_size>=1 + it = (it+1)%len_keys + print(sig) + + for key in keys: + if sig[key] >= 32: + const_metadata[key] = sig[key] + else: + raise ValueError(f'Init num_stages and BLOCK_SIZE ares too large to execute, the exec_params = {sig}') + diff --git a/external/TritonTemplate/python/tritontemplate/compiler/kernel.py b/external/TritonTemplate/python/tritontemplate/compiler/kernel.py index 012c54c19..7d1b5d2ca 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/kernel.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/kernel.py @@ -1,7 +1,7 @@ from typing import Sequence import triton -from tritontemplate.compiler.utils import get_device_max_shared_memory,get_cuda_device_name +from tritontemplate.compiler.utils import get_cuda_device_max_shared_memory,get_cuda_device_name class TritonExecutor: def __init__(self,triton_kernel:triton.compiler.compiler.CompiledKernel,grid_size:Sequence[int],warp_size:int=32,constants:dict=None): @@ -14,15 +14,9 @@ def 
__init__(self,triton_kernel:triton.compiler.compiler.CompiledKernel,grid_siz self.smemsize = triton_kernel.shared self.device_name = get_cuda_device_name() - try: - self.device_name = get_cuda_device_name() - assert self.smemsize <= get_device_max_shared_memory(self.device_name), \ - f'kernel {self.name} smem size {self.smemsize} exceeds device {self.device_name} max smem size {get_device_max_shared_memory(self.device_name)}' - except KeyError as e: - # Log the error and continue with default values - import logging - logging.warning(f"Unsupported device detected: {str(e)}. Continuing with default configuration.") - self.device_name = "unknown" + assert self.smemsize <= get_cuda_device_max_shared_memory(), \ + f'kernel {self.name} smem size {self.smemsize} exceeds device {self.device_name} max smem size {get_cuda_device_max_shared_memory()}' + def __call__(self, *args, **kwds): return self.triton_kernel[self.gridsize](*args, **kwds) diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py index 3563d6e49..02763b7f5 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py @@ -5,7 +5,7 @@ from tritontemplate.compiler.base import IntImm, Tensor, Operation from tritontemplate.compiler.dtype import dtype_str_to_triton_signature from tritontemplate.compiler.kernel import TritonExecutor -from tritontemplate.compiler.utils import get_warpsize +from tritontemplate.compiler.utils import get_warpsize, get_cuda_device_max_shared_memory from tritontemplate.backend.cuda.utils.utils import shape2stride _supported_layouts = ['rcr','rrr','crr','ccr'] @@ -50,7 +50,7 @@ def _deduce_output_shape(self): else: assert self._attrs['outputs'][0].shape == res_shape, f"output shape {self._attrs['outputs'][0].shape} not match {res_shape}" - def _gen_constants(self,enable_tf32): + def _gen_constants(self,enable_tf32,num_stages, func_gen_smem_size): const_metadata={} any_float32=False for input in self._attrs['inputs']: @@ -60,6 +60,8 @@ def _gen_constants(self,enable_tf32): const_metadata['BLOCK_SIZE_M']= self._block_size(self._attrs['M']) const_metadata['BLOCK_SIZE_N']= self._block_size(self._attrs['N']) const_metadata['BLOCK_SIZE_K']= self._block_size(self._attrs['K']) + + self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages) const_metadata['enable_tf32'] = True if (enable_tf32 and any_float32) else False input=self._attrs['inputs'] @@ -88,13 +90,15 @@ def compile(self, target_name, workdir,enable_tf32: bool = False,)->TritonExecut triton_kernel_name=f'bmm'+ ('' if not self.is_bias else '_bias') triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.bmm'),triton_kernel_name) gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.bmm'),f'gen_grid_bmm') + func_gen_smem_size=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.bmm'),f'gen_smem_size_bmm') - signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs']) - constants=self._gen_constants(enable_tf32) exec_metadata=self._gen_exec_metadata() num_warps=exec_metadata['num_warps'] num_stages=exec_metadata['num_stages'] + + signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs']) + constants=self._gen_constants(enable_tf32,num_stages,func_gen_smem_size) config = 
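# ---- sketch (editorial, not part of the patch series) ----------------------
# With block sizes shrunk up front, TritonExecutor's try/except around the
# device lookup becomes a plain assert. A minimal form of that guard; the
# patch reads the kernel's static shared-memory byte count from
# CompiledKernel.shared as recorded by Triton 2.1:
def check_smem(compiled, device_max: int, name: str = "kernel") -> None:
    used = compiled.shared  # bytes of shared memory the compiled kernel requests
    assert used <= device_max, (
        f"kernel {name} smem size {used} exceeds device max smem size {device_max}")
# -----------------------------------------------------------------------------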
triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1]) triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False) diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py index 6b88d9c48..824a2a573 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py @@ -6,8 +6,8 @@ from tritontemplate.compiler.base import IntImm, Tensor, Operation from tritontemplate.compiler.dtype import dtype_str_to_triton_signature from tritontemplate.compiler.kernel import TritonExecutor -from tritontemplate.compiler.utils import get_warpsize -from tritontemplate.backend.cuda.utils.utils import shape2stride +from tritontemplate.compiler.utils import get_warpsize,get_cuda_device_max_shared_memory +from tritontemplate.backend.cuda.utils import shape2stride _supported_layouts = ['rcr','rrr','ccr','crr'] _supported_activations = ['relu',None] @@ -59,7 +59,7 @@ def _deduce_output_shape(self): else: assert self._attrs['outputs'][0].shape == res_shape, f"output shape {self._attrs['outputs'][0].shape} not match {res_shape}" - def _gen_constants(self,enable_tf32): + def _gen_constants(self,enable_tf32,num_stages, func_gen_smem_size): const_metadata={} const_metadata['ACTIVATION'] = self._attrs['activation'] @@ -74,7 +74,8 @@ def _gen_constants(self,enable_tf32): const_metadata['BLOCK_SIZE_M']= self._block_size(self._attrs['M']) const_metadata['BLOCK_SIZE_N']= self._block_size(self._attrs['N']) const_metadata['BLOCK_SIZE_K']= self._block_size(self._attrs['K']) - + self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages) + input=self._attrs['inputs'] output=self._attrs['outputs'] const_metadata['M']=self._attrs['M'] @@ -98,15 +99,17 @@ def _gen_exec_metadata(self): #TODO:enable_tf32 https://github.com/triton-lang/triton/issues/4574 def compile(self,target_name,workdir,enable_tf32: bool = False,)->TritonExecutor: triton_kernel_name=f'gemm'+ ('' if not self.is_bias else '_bias') + func_gen_smem_size=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.gemm'),f'gen_smem_size_gemm') triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.gemm'),triton_kernel_name) gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.gemm'),f'gen_grid_gemm') signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs']) - constants=self._gen_constants(enable_tf32) exec_metadata=self._gen_exec_metadata() num_warps=exec_metadata['num_warps'] num_stages=exec_metadata['num_stages'] + + constants=self._gen_constants(enable_tf32,num_stages,func_gen_smem_size) config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1]) triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False) diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py index 4e2641884..00b1c3633 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py +++ 
b/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py @@ -7,7 +7,7 @@ from tritontemplate.compiler.base import IntImm, Tensor, Operation from tritontemplate.compiler.dtype import dtype_str_to_triton_signature from tritontemplate.compiler.kernel import TritonExecutor -from tritontemplate.compiler.utils import get_warpsize +from tritontemplate.compiler.utils import get_warpsize, get_cuda_device_max_shared_memory from tritontemplate.backend.cuda.utils.utils import shape2stride _exec_metadata = { @@ -25,7 +25,7 @@ def __init__( name: Optional[str] = None, ) -> None: super().__init__(inputs, outputs, name) - assert len(axises)==1 and axises[0] == len(inputs[0].shape)-1, f'Only last axis normalization is supported (axis={axis}, input shape={inputs[0].shape})' + assert len(axises)==1 and axises[0] == len(inputs[0].shape)-1, f'Only last axis normalization is supported (axis={axises}, input shape={inputs[0].shape})' self._attrs['axis'] = axises[0] self._attrs['eps'] = eps @@ -40,7 +40,7 @@ def _deduce_output_shape(self): if self._attrs['outputs'] is None: self._attrs['outputs'] = [Tensor(shape=self._attrs['inputs'][0].shape,dtype=self._attrs['inputs'][0].dtype)] - def _gen_constants(self): + def _gen_constants(self,num_stages,func_gen_smem_size): const_metadata={} const_metadata['M']= self._attrs['M'] const_metadata['N']= self._attrs['N'] @@ -53,6 +53,8 @@ def _gen_constants(self): const_metadata['BLOCK_SIZE_M'] = self._block_size(self._attrs['M']) const_metadata['BLOCK_SIZE_N'] = self._block_size(self._attrs['N']) + self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages) + if self._attrs['with_weight_bias']: const_metadata['stride_weight'] = 1 const_metadata['stride_bias'] = 1 @@ -66,12 +68,16 @@ def compile(self, target_name, workdir,enable_tf32)->TritonExecutor: triton_kernel_name= 'layernorm_weight_bias' if self._attrs['with_weight_bias'] else 'layernorm' triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.layernorm'),triton_kernel_name) gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.layernorm'),f'gen_grid_layernorm') + func_gen_smem_size=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.layernorm'),f'gen_smem_size_layernorm') + signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs']) - constants=self._gen_constants() + exec_metadata=self._gen_exec_metadata() num_warps=exec_metadata['num_warps'] num_stages=exec_metadata['num_stages'] + + constants=self._gen_constants(num_stages,func_gen_smem_size) config = config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1]) triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False) diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py index fb0373f82..08fa437ee 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py @@ -7,7 +7,7 @@ from tritontemplate.compiler.base import IntImm, Tensor, Operation from tritontemplate.compiler.dtype import dtype_str_to_triton_signature from tritontemplate.compiler.kernel import TritonExecutor -from tritontemplate.compiler.utils import 
get_warpsize +from tritontemplate.compiler.utils import get_warpsize,get_cuda_device_max_shared_memory from tritontemplate.backend.cuda.utils.utils import shape2stride _exec_metadata = { @@ -35,7 +35,7 @@ def _deduce_output_shape(self): self._attrs['outputs'] = [Tensor(shape=self._attrs['inputs'][0].shape,dtype='float32')] # self._attrs['outputs'] = [Tensor(shape=self._attrs['inputs'][0].shape,dtype=self._attrs['inputs'][0].dtype)] - def _gen_constants(self): + def _gen_constants(self,num_stages,func_gen_smem_size): const_metadata={} const_metadata['M']= self._attrs['M'] const_metadata['N']= self._attrs['N'] @@ -46,6 +46,7 @@ def _gen_constants(self): const_metadata['BLOCK_SIZE_M'] = self._block_size(self._attrs['M']) const_metadata['BLOCK_SIZE_N'] = self._block_size(self._attrs['N']) + self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages) return const_metadata @@ -56,12 +57,15 @@ def compile(self, target_name, workdir, enable_tf32)->TritonExecutor: triton_kernel_name= 'online_softmax' if self._attrs['enable_online'] else 'softmax' triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.softmax'),triton_kernel_name) gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.softmax'),f'gen_grid_softmax') + func_gen_smem_size=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.softmax'),f'gen_smem_size_softmax') signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs']) - constants=self._gen_constants() + exec_metadata=self._gen_exec_metadata() num_warps=exec_metadata['num_warps'] num_stages=exec_metadata['num_stages'] + + constants=self._gen_constants(num_stages,func_gen_smem_size) config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1]) triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False) diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py index 1a37495c5..28f155fc5 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py @@ -6,7 +6,7 @@ from tritontemplate.compiler.base import IntImm, Tensor, Operation from tritontemplate.compiler.dtype import dtype_str_to_triton_signature from tritontemplate.compiler.kernel import TritonExecutor -from tritontemplate.compiler.utils import get_warpsize +from tritontemplate.compiler.utils import get_warpsize,get_cuda_device_max_shared_memory from tritontemplate.backend.cuda.utils.utils import shape2stride _supported_permutations = ['10','0213'] @@ -38,7 +38,7 @@ def _deduce_output_shape(self): else: assert self._attrs['outputs'][0].shape == output_shape, f"Transpose op output shape {self._attrs['outputs'][0].shape} does not match expected shape {output_shape}" - def _gen_constants_10(self): + def _gen_constants_10(self,num_stages,func_gen_smem_size): const_metadata={} M,N=self._attrs['inputs'][0].shape @@ -51,6 +51,7 @@ def _gen_constants_10(self): const_metadata['BLOCK_SIZE_M'] = self._block_size(M) const_metadata['BLOCK_SIZE_N'] = self._block_size(N) + self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages) return const_metadata @@ -58,7 +59,7 @@ 
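# ---- sketch (editorial, not part of the patch series) ----------------------
# Every op's compile() in this series now follows the same ahead-of-time
# shape: fetch num_warps/num_stages first, size the constants (running the
# shared-memory shrink), then compile with divisibility hints. A condensed
# sketch against the Triton 2.1 AOT API this repo pins; kernel_fn, signature,
# constants and divisibility stand in for the op-specific values:
import triton

def compile_op(kernel_fn, signature, constants, divisibility,
               num_warps=4, num_stages=2):
    # divisible_by_16 / equal_to_1 let Triton specialize pointer arithmetic.
    config = triton.compiler.instance_descriptor(
        divisible_by_16=divisibility[16], equal_to_1=divisibility[1])
    return triton.compile(fn=kernel_fn, signature=signature,
                          constants=constants, num_warps=num_warps,
                          num_stages=num_stages, configs=[config], debug=False)
# -----------------------------------------------------------------------------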
def _gen_grid_10(self,target_name,const_metadata): gen_grid = getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.transpose'),'gen_grid_transpose_10') return gen_grid(const_metadata['M'],const_metadata['N'],const_metadata['BLOCK_SIZE_M'],const_metadata['BLOCK_SIZE_N']) - def _gen_constants_0213(self): + def _gen_constants_0213(self,num_stages,func_gen_smem_size): const_metadata={} D0,D1,D2,D3=self._attrs['inputs'][0].shape const_metadata['D0'] = D0 @@ -72,6 +73,8 @@ def _gen_constants_0213(self): const_metadata['BLOCK_SIZE_D1'] = self._block_size(D1) const_metadata['BLOCK_SIZE_D2'] = self._block_size(D2) + self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages) + return const_metadata def _gen_grid_0213(self,target_name,const_metadata): @@ -86,20 +89,21 @@ def _gen_exec_metadata(self): def compile(self, target_name, workdir, enable_tf32)->TritonExecutor: triton_kernel_name= 'transpose_' + self._attrs['permutation'] triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.transpose'),triton_kernel_name) + func_gen_smem_size=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.transpose'),f'gen_smem_size_transpose_{self._attrs["permutation"]}') signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs']) exec_metadata=self._gen_exec_metadata() + num_warps=exec_metadata['num_warps'] + num_stages=exec_metadata['num_stages'] if self._attrs['permutation'] == '10': - constants=self._gen_constants_10() + constants=self._gen_constants_10(num_stages,func_gen_smem_size) exec_grid = self._gen_grid_10(target_name,constants) elif self._attrs['permutation'] == '0213': - constants=self._gen_constants_0213() + constants=self._gen_constants_0213(num_stages,func_gen_smem_size) exec_grid = self._gen_grid_0213(target_name,constants) else: raise ValueError(f"Unsupported permutation {self._attrs['permutation']}") - num_warps=exec_metadata['num_warps'] - num_stages=exec_metadata['num_stages'] config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1]) triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False) diff --git a/external/TritonTemplate/python/tritontemplate/compiler/utils.py b/external/TritonTemplate/python/tritontemplate/compiler/utils.py index 28cb397eb..df1d98403 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/utils.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/utils.py @@ -1,16 +1,18 @@ import subprocess +import torch _TARGET2WARPSIZE={ 'cuda':32, } _DEVICE_MAX_SHARED_MEMORY={ - "NVIDIA H800": 227 * 1024, - "NVIDIA H100": 227 * 1024, - "NVIDIA A100": 164 * 1024, - "NVIDIA A800": 164 * 1024, - "NVIDIA V100": 96 * 1024, - "NVIDIA T4": 64 * 1024, + "8.0" : 163*1024, + "8.6" : 99*1024, + "8.7" : 163*1024, + "8.9" : 99*1024, + "9." : 227*1024, + "10." : 227*1024, + "12." 
: 99*1024, } def get_cuda_device_name(idx=0): @@ -25,9 +27,10 @@ def get_warpsize(target_name): except KeyError: raise KeyError(f'target {target_name} not supported') -def get_device_max_shared_memory(target_name): - try: - return _DEVICE_MAX_SHARED_MEMORY[target_name] - except KeyError: - raise KeyError(f'target {target_name} not supported, please add max smem size info') - \ No newline at end of file +def get_cuda_device_max_shared_memory(): + compute_capability = torch.cuda.get_device_capability() + if compute_capability[0] == 8: + return _DEVICE_MAX_SHARED_MEMORY[str(compute_capability[0])+"."+str(compute_capability[1])] + elif compute_capability[0] < 8: + raise KeyError(f'cuda compute capability {compute_capability} does not support triton') + return _DEVICE_MAX_SHARED_MEMORY[str(compute_capability[0])+"."] diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py index b1fdd93d9..71ac3fa6a 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py @@ -186,3 +186,5 @@ def test_gemm_relu(format, M, N, K, stype): assert torch.allclose(c_triton_aot, c_triton_jit, atol=atol, rtol=rtol) assert torch.allclose(pytorch_result, c_triton_jit, atol=atol, rtol=rtol) + +test_gemm_relu('rcr', 128, 128, 128, 'float32') \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_layernorm.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_layernorm.py index aba410aae..0fc96eb6d 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_layernorm.py +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_layernorm.py @@ -17,13 +17,13 @@ def gen_layernorm(with_weight_bias,batch,seq_len,hidden_size,stype): op=Layernorm( inputs=[X,W,B], outputs=None, - axis=2, + axises=[2], eps=1e-5) else: op=Layernorm( inputs=[X], outputs=None, - axis=2, + axises=[2], eps=1e-5) return op @@ -72,10 +72,10 @@ def test_layernorm(batch,seq_len,hidden_size,stype,format): weight = torch.randn(hidden_size,dtype=dtype,device='cuda') bias = torch.randn(hidden_size,dtype=dtype,device='cuda') y_torch = torch.nn.functional.layer_norm(x,(N,),weight,bias,eps=1e-5) - kernel_layernorm_weight_bias[grid](x,bias,weight,y_triton_jit,M,N,N,1,N,1,1,1,64,64,1e-5) + kernel_layernorm_weight_bias[grid](x,weight,bias,y_triton_jit,M,N,N,1,N,1,1,1,64,64,1e-5) kernel = gen_layernorm(True,batch,seq_len,hidden_size,stype) kernel_aot = compile_kernel(kernel) - kernel_aot(x,bias,weight,y_triton_aot) + kernel_aot(x,weight,bias,y_triton_aot) atol = 1e-2 if dtype == torch.float16 else 1e-4 rtol = 1e-2 if dtype == torch.float16 else 1e-4 From 60a85fd3e3c320ab37b385a610f1bd22ed2677d9 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Fri, 25 Jul 2025 22:53:21 +0800 Subject: [PATCH 2/9] [feat] auto block size --- .github/workflows/compiler-ci.yaml | 3 ++ .../tritontemplate/backend/cuda/bmm/bmm.py | 32 ++++++------- .../tritontemplate/backend/cuda/gemm/gemm.py | 46 +++++++++++-------- .../backend/cuda/layernorm/layernorm.py | 12 ++--- .../backend/cuda/softmax/softmax.py | 4 +- .../backend/cuda/transpose/transpose_0213.py | 4 +- .../backend/cuda/transpose/transpose_10.py | 4 +- .../python/tritontemplate/compiler/base.py | 26 ++++++----- .../tritontemplate/compiler/ops/bmm/bmm.py | 6 +-- .../tritontemplate/compiler/ops/gemm/gemm.py | 6 +-- 
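# ---- sketch (editorial, not part of the patch series) ----------------------
# The utils change above replaces the device-name table with one keyed by
# CUDA compute capability: exact major.minor keys for the 8.x parts (163 KiB
# on 8.0/8.7 vs 99 KiB on 8.6/8.9) and a "major." prefix for 9.x and newer.
# A standalone restatement of that lookup; the patch itself reads the
# capability via torch.cuda.get_device_capability():
_SMEM_BY_CC = {"8.0": 163 * 1024, "8.6": 99 * 1024, "8.7": 163 * 1024,
               "8.9": 99 * 1024, "9.": 227 * 1024, "10.": 227 * 1024,
               "12.": 99 * 1024}

def max_shared_memory(major: int, minor: int) -> int:
    if major < 8:
        raise KeyError(f"cuda compute capability {major}.{minor} does not support triton")
    if major == 8:
        return _SMEM_BY_CC[f"{major}.{minor}"]
    return _SMEM_BY_CC[f"{major}."]

assert max_shared_memory(8, 0) == 163 * 1024  # A100/A800
assert max_shared_memory(9, 0) == 227 * 1024  # H100/H800
# -----------------------------------------------------------------------------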
.../compiler/ops/layernorm/layernorm.py | 4 +- .../compiler/ops/softmax/softmax.py | 4 +- .../compiler/ops/transpose/transpose.py | 6 +-- .../tritontemplate/testing/cuda/test_gemm.py | 2 - 14 files changed, 82 insertions(+), 77 deletions(-) diff --git a/.github/workflows/compiler-ci.yaml b/.github/workflows/compiler-ci.yaml index 50bf4eea6..d51c15f20 100644 --- a/.github/workflows/compiler-ci.yaml +++ b/.github/workflows/compiler-ci.yaml @@ -43,6 +43,9 @@ jobs: - name: Run build and test run: ./scripts/compiler/build_and_test.sh shell: bash + - name: Run build and test TritonTemplate + run: ./external/TritonTemplate/scripts/build_and_test.sh + shell: bash - name: Run E2E testcases and check different run: ./compiler/scripts/gen_testcases_and_check_diff.sh shell: bash diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py index b8e59ca9e..ccbca666b 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py @@ -7,8 +7,8 @@ def gen_grid_bmm(batch_size:int,M:int, N:int, BLOCK_SIZE_M:int, BLOCK_SIZE_N:int """ return (batch_size,triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), 1) -def gen_smem_size_bmm(BLOCK_SIZE_M:int, BLOCK_SIZE_K:int, BLOCK_SIZE_N:int,num_stages:int): - return (BLOCK_SIZE_N*BLOCK_SIZE_K+BLOCK_SIZE_K*BLOCK_SIZE_M)*num_stages +def gen_smem_size_bmm(BLOCK_SIZE_M:int, BLOCK_SIZE_K:int, BLOCK_SIZE_N:int,num_stages:int,size_dtype:int): + return max((BLOCK_SIZE_N*BLOCK_SIZE_K+BLOCK_SIZE_K*BLOCK_SIZE_M)*num_stages*size_dtype,(BLOCK_SIZE_M*BLOCK_SIZE_N)*4) @triton.jit def bmm_bias( @@ -43,7 +43,7 @@ def bmm_bias( if is_transpose_a: # A(B,K,M) - a_ptrs = a_ptr+stride_a0*batch_id+stride_a1*offs_k[:,None]+stride_a2*offs_am[None,:] + a_ptrs = a_ptr+stride_a0*batch_id+stride_a1*offs_k[None,:]+stride_a2*offs_am[:,None] stride_ak=stride_a1 else: # A(B,M,K) @@ -52,7 +52,7 @@ def bmm_bias( if is_transpose_b: # B(B,N,K) - b_ptrs = b_ptr+stride_b0*batch_id+stride_b1*offs_bn[:,None]+stride_b2*offs_k[None,:] + b_ptrs = b_ptr+stride_b0*batch_id+stride_b1*offs_bn[None,:]+stride_b2*offs_k[:,None] stride_bk=stride_b2 else: # B(B,K,N) @@ -62,20 +62,16 @@ def bmm_bias( accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): if is_transpose_a: - a_mask= (offs_k[:,None]+BLOCK_SIZE_K*k>=1 + tolerance = len_keys + while tolerance and dev_smem_size32: + sig[keys[it]]//=2 + else: + tolerance-=1 it = (it+1)%len_keys - print(sig) - for key in keys: - if sig[key] >= 32: - const_metadata[key] = sig[key] - else: - raise ValueError(f'Init num_stages and BLOCK_SIZE ares too large to execute, the exec_params = {sig}') - + val = sig[key] + if val < 32: + raise ValueError(f'Shrinking resulted in block size < 32. 
The exec_params = {sig}') + const_metadata[key] = val diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py index 02763b7f5..c51ec1bb2 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py @@ -3,7 +3,7 @@ import triton from tritontemplate.compiler.base import IntImm, Tensor, Operation -from tritontemplate.compiler.dtype import dtype_str_to_triton_signature +from tritontemplate.compiler.dtype import get_dtype_size from tritontemplate.compiler.kernel import TritonExecutor from tritontemplate.compiler.utils import get_warpsize, get_cuda_device_max_shared_memory from tritontemplate.backend.cuda.utils.utils import shape2stride @@ -12,7 +12,7 @@ _exec_metadata = { 'num_warps': 4, - 'num_stages': 1, + 'num_stages': 3, } class Bmm(Operation): @@ -61,7 +61,7 @@ def _gen_constants(self,enable_tf32,num_stages, func_gen_smem_size): const_metadata['BLOCK_SIZE_N']= self._block_size(self._attrs['N']) const_metadata['BLOCK_SIZE_K']= self._block_size(self._attrs['K']) - self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages) + self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages,get_dtype_size(self._attrs['inputs'][0].dtype)) const_metadata['enable_tf32'] = True if (enable_tf32 and any_float32) else False input=self._attrs['inputs'] diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py index 824a2a573..6a2fff0aa 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py @@ -4,7 +4,7 @@ import triton from tritontemplate.compiler.base import IntImm, Tensor, Operation -from tritontemplate.compiler.dtype import dtype_str_to_triton_signature +from tritontemplate.compiler.dtype import get_dtype_size from tritontemplate.compiler.kernel import TritonExecutor from tritontemplate.compiler.utils import get_warpsize,get_cuda_device_max_shared_memory from tritontemplate.backend.cuda.utils import shape2stride @@ -15,7 +15,7 @@ _exec_metadata = { 'num_warps': 4, - 'num_stages': 1, + 'num_stages': 3, } class Gemm(Operation): @@ -74,7 +74,7 @@ def _gen_constants(self,enable_tf32,num_stages, func_gen_smem_size): const_metadata['BLOCK_SIZE_M']= self._block_size(self._attrs['M']) const_metadata['BLOCK_SIZE_N']= self._block_size(self._attrs['N']) const_metadata['BLOCK_SIZE_K']= self._block_size(self._attrs['K']) - self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages) + self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages,get_dtype_size(self._attrs['inputs'][0].dtype)) input=self._attrs['inputs'] output=self._attrs['outputs'] diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py index 00b1c3633..b80d25b1c 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py @@ -5,7 +5,7 @@ import triton from tritontemplate.compiler.base import IntImm, Tensor, Operation -from 
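# ---- sketch (editorial, not part of the patch series) ----------------------
# A hedged reconstruction of the rewritten shrink loop above, read from the
# surviving fragments: instead of halving unconditionally, it keeps a
# tolerance budget of len(keys); a dimension already at 32 costs one unit of
# tolerance rather than shrinking further, so the loop terminates even when
# the budget cannot be met. The bmm/gemm estimators also become byte-accurate,
# flooring at the float32 accumulator tile:
def shrink_with_tolerance(gen_smem, sig: dict, keys: list, budget: int) -> None:
    it, tolerance = 0, len(keys)
    while tolerance and budget < gen_smem(**sig):
        if sig[keys[it]] > 32:
            sig[keys[it]] //= 2       # halve, but never below 32
        else:
            tolerance -= 1            # this dimension is exhausted
        it = (it + 1) % len(keys)

def bmm_smem_bytes(bm, bk, bn, num_stages, size_dtype):
    # A/B tiles per stage in input-dtype bytes, never less than the
    # BLOCK_M x BLOCK_N float32 accumulator (4 bytes per element).
    return max((bn * bk + bk * bm) * num_stages * size_dtype, bm * bn * 4)
# -----------------------------------------------------------------------------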
tritontemplate.compiler.dtype import dtype_str_to_triton_signature +from tritontemplate.compiler.dtype import get_dtype_size from tritontemplate.compiler.kernel import TritonExecutor from tritontemplate.compiler.utils import get_warpsize, get_cuda_device_max_shared_memory from tritontemplate.backend.cuda.utils.utils import shape2stride @@ -53,7 +53,7 @@ def _gen_constants(self,num_stages,func_gen_smem_size): const_metadata['BLOCK_SIZE_M'] = self._block_size(self._attrs['M']) const_metadata['BLOCK_SIZE_N'] = self._block_size(self._attrs['N']) - self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages) + self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages,get_dtype_size(self._attrs['inputs'][0].dtype)) if self._attrs['with_weight_bias']: const_metadata['stride_weight'] = 1 diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py index 08fa437ee..f5868cd3f 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py @@ -5,7 +5,7 @@ import triton from tritontemplate.compiler.base import IntImm, Tensor, Operation -from tritontemplate.compiler.dtype import dtype_str_to_triton_signature +from tritontemplate.compiler.dtype import get_dtype_size from tritontemplate.compiler.kernel import TritonExecutor from tritontemplate.compiler.utils import get_warpsize,get_cuda_device_max_shared_memory from tritontemplate.backend.cuda.utils.utils import shape2stride @@ -46,7 +46,7 @@ def _gen_constants(self,num_stages,func_gen_smem_size): const_metadata['BLOCK_SIZE_M'] = self._block_size(self._attrs['M']) const_metadata['BLOCK_SIZE_N'] = self._block_size(self._attrs['N']) - self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages) + self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages,get_dtype_size(self._attrs['inputs'][0].dtype)) return const_metadata diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py index 28f155fc5..6464dd14c 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py @@ -4,7 +4,7 @@ import triton from tritontemplate.compiler.base import IntImm, Tensor, Operation -from tritontemplate.compiler.dtype import dtype_str_to_triton_signature +from tritontemplate.compiler.dtype import get_dtype_size from tritontemplate.compiler.kernel import TritonExecutor from tritontemplate.compiler.utils import get_warpsize,get_cuda_device_max_shared_memory from tritontemplate.backend.cuda.utils.utils import shape2stride @@ -51,7 +51,7 @@ def _gen_constants_10(self,num_stages,func_gen_smem_size): const_metadata['BLOCK_SIZE_M'] = self._block_size(M) const_metadata['BLOCK_SIZE_N'] = self._block_size(N) - self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages) + self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages,get_dtype_size(self._attrs['inputs'][0].dtype)) return const_metadata @@ -73,7 +73,7 @@ def 
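# ---- sketch (editorial, not part of the patch series) ----------------------
# Patch 2 swaps dtype_str_to_triton_signature for get_dtype_size in every op
# so the smem estimate can be computed in bytes. The helper's definition is
# not included in this diff; a plausible shape, with hypothetical names,
# would be:
_DTYPE_NBYTES = {"float16": 2, "bfloat16": 2, "float32": 4}

def get_dtype_size(dtype: str) -> int:
    """Bytes per element for a dtype string such as 'float16' (assumed API)."""
    try:
        return _DTYPE_NBYTES[dtype]
    except KeyError:
        raise KeyError(f"dtype {dtype} not supported") from None
# -----------------------------------------------------------------------------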
_gen_constants_0213(self,num_stages,func_gen_smem_size): const_metadata['BLOCK_SIZE_D1'] = self._block_size(D1) const_metadata['BLOCK_SIZE_D2'] = self._block_size(D2) - self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages) + self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages,get_dtype_size(self._attrs['inputs'][0].dtype)) return const_metadata diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py index 71ac3fa6a..b1fdd93d9 100644 --- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py +++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py @@ -186,5 +186,3 @@ def test_gemm_relu(format, M, N, K, stype): assert torch.allclose(c_triton_aot, c_triton_jit, atol=atol, rtol=rtol) assert torch.allclose(pytorch_result, c_triton_jit, atol=atol, rtol=rtol) - -test_gemm_relu('rcr', 128, 128, 128, 'float32') \ No newline at end of file From 2a3a8b275751fbfa47c4c266a6f6792f60897c92 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Fri, 25 Jul 2025 23:06:04 +0800 Subject: [PATCH 3/9] [bug] code clean --- external/TritonTemplate/python/pyproject.toml | 33 ++++++++++ external/TritonTemplate/python/setup.py | 59 ----------------- .../tritontemplate/backend/cuda/bmm/bmm.py | 5 +- .../tritontemplate/backend/cuda/gemm/gemm.py | 10 +-- .../python/tritontemplate/compiler/base.py | 1 - .../tritontemplate/testing/cuda/test.sh | 9 +++ .../tritontemplate/testing/cuda/test_bmm.py | 33 ++++++---- .../tritontemplate/testing/cuda/test_gemm.py | 66 +++++++++++-------- .../testing/cuda/test_layernorm.py | 33 ++++++---- .../testing/cuda/test_softmax.py | 29 ++++---- .../testing/cuda/test_transpose.py | 37 +++++++---- .../TritonTemplate/scripts/build_and_test.sh | 47 +++++++++++++ 12 files changed, 214 insertions(+), 148 deletions(-) create mode 100644 external/TritonTemplate/python/pyproject.toml delete mode 100644 external/TritonTemplate/python/setup.py create mode 100755 external/TritonTemplate/python/tritontemplate/testing/cuda/test.sh create mode 100755 external/TritonTemplate/scripts/build_and_test.sh diff --git a/external/TritonTemplate/python/pyproject.toml b/external/TritonTemplate/python/pyproject.toml new file mode 100644 index 000000000..d72a6fd76 --- /dev/null +++ b/external/TritonTemplate/python/pyproject.toml @@ -0,0 +1,33 @@ +[build-system] +requires = ["setuptools>=42"] +build-backend = "setuptools.build_meta" + +[project] +name = "tritontemplate" +version = "0.0.1" # You'll need to get this from _libinfo.py differently +description = "TritonTemplate: Make Flex Triton Templates for AI" +requires-python = ">=3.7,<4" +dependencies = [ + "triton==2.1.0", + "torch>=2.1.0", + "numpy<2", + "pytest", +] + +[project.optional-dependencies] +# Add any optional dependencies here if needed + +[tool.setuptools] +zip-safe = true +packages = ["tritontemplate"] + +# Note: For package_data, you'll need to either: +# 1. List all files explicitly here, or +# 2. 
Keep a minimal setup.py just for the file listing logic +# Here's an example of explicit package data: +[tool.setuptools.package-data] +tritontemplate = [ + "backend/*.py", + "utils/*.py", + "compiler/*.py" +] diff --git a/external/TritonTemplate/python/setup.py b/external/TritonTemplate/python/setup.py deleted file mode 100644 index 8b369e707..000000000 --- a/external/TritonTemplate/python/setup.py +++ /dev/null @@ -1,59 +0,0 @@ -import os -import shutil - -from setuptools import find_packages, setup - -CURRENT_DIR = os.path.dirname(__file__) -libinfo_py = os.path.join(CURRENT_DIR, "tritontemplate", "_libinfo.py") -libinfo = {} -with open(libinfo_py, "r") as f: - exec(f.read(), libinfo) -__version__ = libinfo["__version__"] - -def gen_file_list(srcs, f_cond): - file_list = [] - for src in srcs: - for root, _, files in os.walk(src): - value = [] - for file in files: - if f_cond(file): - path = os.path.join(root, file) - value.append(path.replace("tritontemplate/", "")) - file_list.extend(value) - return file_list - -def gen_backend_common_file_list(): - srcs = ["tritontemplate/backend"] - f_cond = lambda x: True if x.endswith(".py") else False - return gen_file_list(srcs, f_cond) - -def gen_utils_file_list(): - srcs = ["tritontemplate/utils"] - f_cond = lambda x: True if x.endswith(".py") else False - return gen_file_list(srcs, f_cond) - -def gen_compiler_file_list(): - srcs = ["tritontemplate/compiler"] - f_cond = lambda x: True if x.endswith(".py") else False - return gen_file_list(srcs, f_cond) - -setup_kwargs = {} -include_libs = True -wheel_include_libs = True - -setup( - name="tritontemplate", - version=__version__, - description="TritonTemplate: Make Flex Triton Templates for AI", - zip_safe=True, - install_requires=["torch>=2.1.0","triton"], - packages=find_packages(), - package_data={ - "tritontemplate": [] - + gen_utils_file_list() - + gen_backend_common_file_list() - + gen_compiler_file_list() - }, - python_requires=">=3.7, <4", - **setup_kwargs -) \ No newline at end of file diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py index ccbca666b..72c8e8b5e 100644 --- a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py +++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py @@ -149,10 +149,7 @@ def bmm( b_mask= (offs_k[:,None]+BLOCK_SIZE_K*k /dev/null && pwd )" +# path to TritonTemplate root +ROOT_PROJ_DIR="$CUR_DIR/.." + +# path to python +PYTHON_PATH="$ROOT_PROJ_DIR/python" + +pushd "$PYTHON_PATH" + # install tritontemplate + python3 -m pip install -e . 
+popd + +# path to python tests +TEST_DIR="$ROOT_PROJ_DIR/python/tritontemplate/testing/cuda" + +if [[ $TRITON_TEST == "ON" ]]; then + # Run pytest for all test files in cuda directory + pushd "$TEST_DIR" + for file in *.py; do + if [ -f "$file" ]; then + echo "Running pytest on $file" + GITHUB_CI_TEST=true pytest "$file" + fi + done + popd +fi From e714147f2f571934bedad8f7c7afd0e4a6ec3105 Mon Sep 17 00:00:00 2001 From: liushanghao Date: Mon, 28 Jul 2025 12:47:09 +0800 Subject: [PATCH 4/9] [feat] add numerical test --- .../tritontemplate/compiler/ops/bmm/bmm.py | 2 +- .../tritontemplate/compiler/ops/gemm/gemm.py | 2 +- .../compiler/ops/layernorm/layernorm.py | 2 +- .../compiler/ops/softmax/softmax.py | 2 +- .../compiler/ops/transpose/transpose.py | 2 +- .../TritonTemplate/scripts/build_and_test.sh | 18 ++++++++++++++---- tests/build_and_test_e2e.sh | 2 ++ tests/numerical_test/main.py | 6 ++++-- tests/numerical_test/testset.py | 16 ++++++++++++++++ 9 files changed, 41 insertions(+), 11 deletions(-) diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py index c51ec1bb2..43bfa7903 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py @@ -12,7 +12,7 @@ _exec_metadata = { 'num_warps': 4, - 'num_stages': 3, + 'num_stages': 2, } class Bmm(Operation): diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py index 6a2fff0aa..6bd2f6390 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py @@ -15,7 +15,7 @@ _exec_metadata = { 'num_warps': 4, - 'num_stages': 3, + 'num_stages': 2, } class Gemm(Operation): diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py index b80d25b1c..0f1f17eb8 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py @@ -12,7 +12,7 @@ _exec_metadata = { 'num_warps': 4, - 'num_stages': 1, + 'num_stages': 2, } class Layernorm(Operation): diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py index f5868cd3f..530a6d35c 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py @@ -12,7 +12,7 @@ _exec_metadata = { 'num_warps': 4, - 'num_stages': 3, + 'num_stages': 2, } class Softmax(Operation): diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py index 6464dd14c..a3ec52360 100644 --- a/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py +++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py @@ -13,7 +13,7 @@ _exec_metadata = { 'num_warps': 4, - 'num_stages': 1, + 'num_stages': 2, } class Transpose(Operation): diff --git a/external/TritonTemplate/scripts/build_and_test.sh 
b/external/TritonTemplate/scripts/build_and_test.sh
index 1fa2afa8f..a196b34c3 100755
--- a/external/TritonTemplate/scripts/build_and_test.sh
+++ b/external/TritonTemplate/scripts/build_and_test.sh
@@ -3,12 +3,19 @@
 set -e
 set -x
 
+TRITON_BUILD=ON
+TRITON_TEST=ON
+
 while [[ $# -gt 0 ]]; do
     case $1 in
         --no-test)
             TRITON_TEST=OFF
             shift
             ;;
+        --no-build)
+            TRITON_BUILD=OFF
+            shift
+            ;;
         *)
             echo "Invalid option: $1"
             exit 1
@@ -16,6 +23,7 @@ while [[ $# -gt 0 ]]; do
     esac
 done
 
+
 TRITON_TEST=${TRITON_TEST:-ON}
 
 # path to script
@@ -26,10 +34,12 @@ ROOT_PROJ_DIR="$CUR_DIR/.."
 # path to python
 PYTHON_PATH="$ROOT_PROJ_DIR/python"
 
-pushd "$PYTHON_PATH"
-  # install tritontemplate
-  python3 -m pip install -e .
-popd
+if [[ $TRITON_BUILD == "ON" ]]; then
+  pushd "$PYTHON_PATH"
+  # install tritontemplate
+  python3 -m pip install -e .
+  popd
+fi
 
 # path to python tests
 TEST_DIR="$ROOT_PROJ_DIR/python/tritontemplate/testing/cuda"
diff --git a/tests/build_and_test_e2e.sh b/tests/build_and_test_e2e.sh
index ab98d1022..8be3c8618 100755
--- a/tests/build_and_test_e2e.sh
+++ b/tests/build_and_test_e2e.sh
@@ -17,9 +17,11 @@ pip3 install -r $ROOT_PROJ_DIR/frontends/torch-frontend/torch-cuda-requirements.
 bash frontends/torch-frontend/scripts/build_and_test.sh --no-test
 
 pip3 install $ROOT_PROJ_DIR/external/AITemplate/python/dist/*.whl
+pip3 install $ROOT_PROJ_DIR/external/TritonTemplate
 pip3 install $ROOT_PROJ_DIR/compiler/build/python/dist/*.whl
 pip3 install $ROOT_PROJ_DIR/runtime/python/dist/*.whl
 pip3 install $ROOT_PROJ_DIR/frontends/torch-frontend/build/torch-frontend/python/dist/*.whl
+
 source scripts/prepare.sh
 install_mhlo_tools
diff --git a/tests/numerical_test/main.py b/tests/numerical_test/main.py
index a5103459e..90dbdbda3 100644
--- a/tests/numerical_test/main.py
+++ b/tests/numerical_test/main.py
@@ -25,13 +25,14 @@
     GLOBAL_TORCH_TEST_REGISTRY_NAMES,
 )
 from testset import CPU_MLIR_TEST_DIR, CUDA_MLIR_TEST_DIR
-from testset import CPU_ALL_SET, CUDA_ALL_SET, CUDA_AIT_ALL_SET, CUDA_AIT_SM80PLUS_SET
+from testset import CPU_ALL_SET, CUDA_ALL_SET, CUDA_AIT_ALL_SET, CUDA_AIT_SM80PLUS_SET, CUDA_TIT_MLIR_TEST_SET
 
 ##### TEST SET CONFIG #######
 TEST_SET = {
     "cpu": CPU_ALL_SET,
     "cuda": CUDA_ALL_SET,
     "cuda_with_ait": CUDA_AIT_ALL_SET,
+    "cuda_with_triton": CUDA_TIT_MLIR_TEST_SET,
 }
 
 def get_local_gpu_arch():
@@ -93,6 +94,7 @@ def parse_args():
             "cpu",
             "cuda",
             "cuda_with_ait",
+            "cuda_with_triton",
             "dynamo",
             "native_torch",
         ],
@@ -141,7 +143,7 @@ def main():
     results = []
 
     if args.target == "all":
-        for target in ["cpu", "cuda", "cuda_with_ait", "dynamo"]:
+        for target in ["cpu", "cuda", "cuda_with_ait", "cuda_with_triton", "dynamo"]:
             results += run(target, args.filter, args.workdir)
     else:
         results += run(args.target, args.filter, args.workdir, mode=args.mode, verbose=args.verbose)
diff --git a/tests/numerical_test/testset.py b/tests/numerical_test/testset.py
index 02585154e..c7888bded 100755
--- a/tests/numerical_test/testset.py
+++ b/tests/numerical_test/testset.py
@@ -91,3 +91,19 @@ def _get_test_files_from_dir(directory):
 }
 
 CUDA_AIT_ALL_SET = CUDA_AIT_MLIR_TEST_SET | CUDA_AIT_TORCH_TEST_SET
+
+##### CUDA TRITONTEMPLATE TEST SET #######
+CUDA_TIT_MLIR_TEST_SET= {
+    "bmm_rcr.mlir",
+    "bmm_rrr_add_f16.mlir",
+    "bmm_rrr_f16.mlir",
+    "bmm_rrr_permute_f16.mlir",
+    "bmm_rrr_permute_f32.mlir",
+    "gemm_crr_f16.mlir",
+    "gemm_rrr_f16.mlir",
+    "gemm_rrr_f32.mlir",
+    "layernorm.mlir",
+    "softmax.mlir",
+    "transpose2d.mlir",
+    "transpose2013.mlir",
+}
\ No newline at end of file

From a035d108ca44420ece37dada38402798f0285fd5 Mon Sep 17 00:00:00 2001
From: liushanghao
Date: Mon, 28 Jul 2025 14:57:08 +0800
Subject: [PATCH 5/9] [feat] add numerical test

---
 tests/build_and_test_e2e.sh     | 2 +-
 tests/numerical_test/testset.py | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/build_and_test_e2e.sh b/tests/build_and_test_e2e.sh
index 8be3c8618..2c3dad53d 100755
--- a/tests/build_and_test_e2e.sh
+++ b/tests/build_and_test_e2e.sh
@@ -17,7 +17,7 @@ pip3 install -r $ROOT_PROJ_DIR/frontends/torch-frontend/torch-cuda-requirements.
 bash frontends/torch-frontend/scripts/build_and_test.sh --no-test
 
 pip3 install $ROOT_PROJ_DIR/external/AITemplate/python/dist/*.whl
-pip3 install $ROOT_PROJ_DIR/external/TritonTemplate
+pip3 install $ROOT_PROJ_DIR/external/TritonTemplate/python
 pip3 install $ROOT_PROJ_DIR/compiler/build/python/dist/*.whl
 pip3 install $ROOT_PROJ_DIR/runtime/python/dist/*.whl
 pip3 install $ROOT_PROJ_DIR/frontends/torch-frontend/build/torch-frontend/python/dist/*.whl
diff --git a/tests/numerical_test/testset.py b/tests/numerical_test/testset.py
index c7888bded..e9eee3247 100755
--- a/tests/numerical_test/testset.py
+++ b/tests/numerical_test/testset.py
@@ -97,8 +97,6 @@ def _get_test_files_from_dir(directory):
     "bmm_rcr.mlir",
     "bmm_rrr_add_f16.mlir",
     "bmm_rrr_f16.mlir",
-    "bmm_rrr_permute_f16.mlir",
-    "bmm_rrr_permute_f32.mlir",
     "gemm_crr_f16.mlir",
     "gemm_rrr_f16.mlir",
     "gemm_rrr_f32.mlir",

From dbf83867407152a1e72a983909859dd48c1ba12e Mon Sep 17 00:00:00 2001
From: liushanghao
Date: Mon, 28 Jul 2025 16:03:48 +0800
Subject: [PATCH 6/9] [bug] CI fix

---
 tests/build_and_test_e2e.sh     |  1 -
 tests/numerical_test/main.py    |  6 ++----
 tests/numerical_test/testset.py | 14 --------------
 3 files changed, 2 insertions(+), 19 deletions(-)

diff --git a/tests/build_and_test_e2e.sh b/tests/build_and_test_e2e.sh
index 2c3dad53d..43a3f1e86 100755
--- a/tests/build_and_test_e2e.sh
+++ b/tests/build_and_test_e2e.sh
@@ -17,7 +17,6 @@ pip3 install -r $ROOT_PROJ_DIR/frontends/torch-frontend/torch-cuda-requirements.
 bash frontends/torch-frontend/scripts/build_and_test.sh --no-test
 
 pip3 install $ROOT_PROJ_DIR/external/AITemplate/python/dist/*.whl
-pip3 install $ROOT_PROJ_DIR/external/TritonTemplate/python
 pip3 install $ROOT_PROJ_DIR/compiler/build/python/dist/*.whl
 pip3 install $ROOT_PROJ_DIR/runtime/python/dist/*.whl
 pip3 install $ROOT_PROJ_DIR/frontends/torch-frontend/build/torch-frontend/python/dist/*.whl
diff --git a/tests/numerical_test/main.py b/tests/numerical_test/main.py
index 90dbdbda3..a5103459e 100644
--- a/tests/numerical_test/main.py
+++ b/tests/numerical_test/main.py
@@ -25,14 +25,13 @@
     GLOBAL_TORCH_TEST_REGISTRY_NAMES,
 )
 from testset import CPU_MLIR_TEST_DIR, CUDA_MLIR_TEST_DIR
-from testset import CPU_ALL_SET, CUDA_ALL_SET, CUDA_AIT_ALL_SET, CUDA_AIT_SM80PLUS_SET, CUDA_TIT_MLIR_TEST_SET
+from testset import CPU_ALL_SET, CUDA_ALL_SET, CUDA_AIT_ALL_SET, CUDA_AIT_SM80PLUS_SET
 
 ##### TEST SET CONFIG #######
 TEST_SET = {
     "cpu": CPU_ALL_SET,
     "cuda": CUDA_ALL_SET,
     "cuda_with_ait": CUDA_AIT_ALL_SET,
-    "cuda_with_triton": CUDA_TIT_MLIR_TEST_SET,
 }
 
 def get_local_gpu_arch():
@@ -94,7 +93,6 @@ def parse_args():
             "cpu",
             "cuda",
             "cuda_with_ait",
-            "cuda_with_triton",
             "dynamo",
             "native_torch",
         ],
@@ -143,7 +141,7 @@ def main():
     results = []
 
     if args.target == "all":
-        for target in ["cpu", "cuda", "cuda_with_ait", "cuda_with_triton", "dynamo"]:
+        for target in ["cpu", "cuda", "cuda_with_ait", "dynamo"]:
             results += run(target, args.filter, args.workdir)
     else:
         results += run(args.target, args.filter, args.workdir, mode=args.mode, verbose=args.verbose)
diff --git a/tests/numerical_test/testset.py b/tests/numerical_test/testset.py
index e9eee3247..02585154e 100755
--- a/tests/numerical_test/testset.py
+++ b/tests/numerical_test/testset.py
@@ -91,17 +91,3 @@ def _get_test_files_from_dir(directory):
 }
 
 CUDA_AIT_ALL_SET = CUDA_AIT_MLIR_TEST_SET | CUDA_AIT_TORCH_TEST_SET
-
-##### CUDA TRITONTEMPLATE TEST SET #######
-CUDA_TIT_MLIR_TEST_SET= {
-    "bmm_rcr.mlir",
-    "bmm_rrr_add_f16.mlir",
-    "bmm_rrr_f16.mlir",
-    "gemm_crr_f16.mlir",
-    "gemm_rrr_f16.mlir",
-    "gemm_rrr_f32.mlir",
-    "layernorm.mlir",
-    "softmax.mlir",
-    "transpose2d.mlir",
-    "transpose2013.mlir",
-}
\ No newline at end of file

From f9258bfbefd1fceac0c8f7a8690eeba6f3f7b721 Mon Sep 17 00:00:00 2001
From: liushanghao
Date: Mon, 28 Jul 2025 20:30:38 +0800
Subject: [PATCH 7/9] [feat] refactor tit compile

---
 .../backend/cuda/bmm/__init__.py              |  4 +-
 .../tritontemplate/backend/cuda/bmm/bmm.py    |  4 +-
 .../backend/cuda/gemm/__init__.py             |  2 +
 .../backend/cuda/layernorm/__init__.py        |  4 +-
 .../backend/cuda/softmax/__init__.py          |  4 +-
 .../python/tritontemplate/compiler/base.py    | 43 ++++++++-
 .../tritontemplate/compiler/ops/bmm/bmm.py    | 36 ++-----
 .../tritontemplate/compiler/ops/gemm/gemm.py  | 34 ++-----
 .../compiler/ops/layernorm/layernorm.py       | 37 ++------
 .../compiler/ops/softmax/softmax.py           | 34 ++-----
 .../compiler/ops/transpose/transpose.py       | 94 ++++++-------
 11 files changed, 119 insertions(+), 177 deletions(-)

diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py
index 0302062fd..26bea6627 100644
--- a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py
+++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py
@@ -1 +1,3 @@
-from tritontemplate.backend.cuda.bmm.bmm import bmm,bmm_bias, gen_grid_bmm, gen_smem_size_bmm
\ No newline at end of file
+from tritontemplate.backend.cuda.bmm.bmm import bmm,bmm_bias, gen_grid_bmm, gen_smem_size_bmm
+gen_grid_bmm_bias = gen_grid_bmm
+gen_smem_size_bmm_bias = gen_smem_size_bmm
diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py
index 72c8e8b5e..8156ba7be 100644
--- a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py
+++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/bmm.py
@@ -1,11 +1,11 @@
 import triton
 import triton.language as tl
 
-def gen_grid_bmm(batch_size:int,M:int, N:int, BLOCK_SIZE_M:int, BLOCK_SIZE_N:int):
+def gen_grid_bmm(BATCH_SIZE:int,M:int, N:int, BLOCK_SIZE_M:int, BLOCK_SIZE_N:int):
     """
     Generates the grid for a Batch GEMM kernel.
     """
-    return (batch_size,triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), 1)
+    return (BATCH_SIZE,triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), 1)
 
 def gen_smem_size_bmm(BLOCK_SIZE_M:int, BLOCK_SIZE_K:int, BLOCK_SIZE_N:int,num_stages:int,size_dtype:int):
     return max((BLOCK_SIZE_N*BLOCK_SIZE_K+BLOCK_SIZE_K*BLOCK_SIZE_M)*num_stages*size_dtype,(BLOCK_SIZE_M*BLOCK_SIZE_N)*4)
diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py
index 62314bc6e..f260c618c 100644
--- a/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py
+++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/gemm/__init__.py
@@ -1 +1,3 @@
 from tritontemplate.backend.cuda.gemm.gemm import gemm,gemm_bias,gen_grid_gemm,gen_smem_size_gemm
+gen_grid_gemm_bias = gen_grid_gemm
+gen_smem_size_gemm_bias = gen_smem_size_gemm
\ No newline at end of file
diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/__init__.py
index c745404cb..962d33c76 100644
--- a/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/__init__.py
+++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/layernorm/__init__.py
@@ -1 +1,3 @@
-from tritontemplate.backend.cuda.layernorm.layernorm import layernorm,layernorm_weight_bias, gen_grid_layernorm,gen_smem_size_layernorm
\ No newline at end of file
+from tritontemplate.backend.cuda.layernorm.layernorm import layernorm,layernorm_weight_bias, gen_grid_layernorm,gen_smem_size_layernorm
+gen_grid_layernorm_weight_bias = gen_grid_layernorm
+gen_smem_size_layernorm_weight_bias = gen_smem_size_layernorm
diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/__init__.py
index a636a6c92..e1543eb6c 100644
--- a/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/__init__.py
+++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/softmax/__init__.py
@@ -1 +1,3 @@
-from tritontemplate.backend.cuda.softmax.softmax import softmax,online_softmax,gen_grid_softmax,gen_smem_size_softmax
\ No newline at end of file
+from tritontemplate.backend.cuda.softmax.softmax import softmax,online_softmax,gen_grid_softmax,gen_smem_size_softmax
+gen_grid_online_softmax = gen_grid_softmax
+gen_smem_size_online_softmax = gen_smem_size_softmax
diff --git a/external/TritonTemplate/python/tritontemplate/compiler/base.py b/external/TritonTemplate/python/tritontemplate/compiler/base.py
index 09af873cc..7348fe808 100644
--- a/external/TritonTemplate/python/tritontemplate/compiler/base.py
+++ b/external/TritonTemplate/python/tritontemplate/compiler/base.py
@@ -2,7 +2,11 @@
 from pprint import pformat
 from typing import Any, Dict, Iterable, List, Optional, Set, Union, Callable
 import inspect
+import importlib
 
+from tritontemplate.compiler.utils import get_warpsize,get_cuda_device_max_shared_memory
+from tritontemplate.compiler.dtype import get_dtype_size
+from tritontemplate.compiler.kernel import TritonExecutor
 from tritontemplate.compiler.dtype import dtype_str_to_triton_signature
 
 class BaseType(ABC):
@@ -105,10 +109,43 @@ def inputs(self) -> List[Tensor]:
     @property
     def outputs(self) -> Optional[List[Tensor]]:
         return self._attrs['outputs']
+
+
+    def _gen_exec_grid(self,gen_grid,constants):
+        sig = dict(inspect.signature(gen_grid).parameters)
+        sig = {k:constants[k] for k in sig.keys()}
+        return gen_grid(**sig)
 
-    @abstractmethod
-    def compile(self,target_name,workdir):
-        raise NotImplementedError
+    def compile(self, target_name, workdir, enable_tf32: bool = False) -> TritonExecutor:
+
+        kernel_name = self._kernel_name
+        backend_module = self._backend_module_name
+        triton_kernel = getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.{backend_module}'), kernel_name)
+        gen_grid = getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.{backend_module}'), f'gen_grid_{kernel_name}')
+        func_gen_smem_size = getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.{backend_module}'), f'gen_smem_size_{kernel_name}')
+
+        exec_metadata = self._gen_exec_metadata()
+        num_warps = exec_metadata['num_warps']
+        num_stages = exec_metadata['num_stages']
+
+        signature, divisiability = self._gen_tensor_signature_divisiability(['inputs', 'outputs'])
+        constants = self._gen_constants(enable_tf32, num_stages, func_gen_smem_size)
+
+        import triton
+        config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1])
+        triton_compiled_kernel = triton.compile(
+            fn=triton_kernel,
+            signature=signature,
+            constants=constants,
+            num_warps=num_warps,
+            num_stages=num_stages,
+            configs=[config],
+            debug=False
+        )
+
+        exec_grid = self._gen_exec_grid(gen_grid, constants)
+        return TritonExecutor(triton_compiled_kernel, exec_grid, get_warpsize(target_name), constants)
+
 
     def _gen_tensor_signature_divisiability(self,tensors_names:List[str]):
diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py
index 43bfa7903..ba78431fc 100644
--- a/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py
+++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py
@@ -1,19 +1,14 @@
 from typing import List,Optional
 import importlib
 
-import triton
-from tritontemplate.compiler.base import IntImm, Tensor, Operation
+from tritontemplate.compiler.base import Tensor, Operation
 from tritontemplate.compiler.dtype import get_dtype_size
 from tritontemplate.compiler.kernel import TritonExecutor
-from tritontemplate.compiler.utils import get_warpsize, get_cuda_device_max_shared_memory
+from tritontemplate.compiler.utils import get_cuda_device_max_shared_memory
 from tritontemplate.backend.cuda.utils.utils import shape2stride
 
 _supported_layouts = ['rcr','rrr','crr','ccr']
 
-_exec_metadata = {
-    'num_warps': 4,
-    'num_stages': 2,
-}
 
 class Bmm(Operation):
     def __init__(
@@ -27,7 +22,10 @@ def __init__(
         super().__init__(inputs, outputs,name)
         self.layout = layout
         self.is_bias = is_bias
+        self._backend_module_name = 'bmm'
+        self._kernel_name = self._backend_module_name + ('' if not self.is_bias else '_bias')
         self._deduce_output_shape()
+
 
     def _deduce_output_shape(self):
         BATCH_SIZE = self._attrs['inputs'][0].shape[0]
@@ -84,24 +82,10 @@ def _gen_constants(self,enable_tf32,num_stages, func_gen_smem_size):
         return const_metadata
 
     def _gen_exec_metadata(self):
-        return _exec_metadata.copy()
+        return {
+            'num_warps': 4,
+            'num_stages': 2,
+        }
 
     def compile(self, target_name, workdir,enable_tf32: bool = False,)->TritonExecutor:
-        triton_kernel_name=f'bmm'+ ('' if not self.is_bias else '_bias')
-        triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.bmm'),triton_kernel_name)
-        gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.bmm'),f'gen_grid_bmm')
-        func_gen_smem_size=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.bmm'),f'gen_smem_size_bmm')
-
-        exec_metadata=self._gen_exec_metadata()
-
-        num_warps=exec_metadata['num_warps']
-        num_stages=exec_metadata['num_stages']
-
-        signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs'])
-        constants=self._gen_constants(enable_tf32,num_stages,func_gen_smem_size)
-        config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1])
-
-        triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False)
-
-        exec_grid=gen_grid(constants['BATCH_SIZE'],constants['M'],constants['N'],constants['BLOCK_SIZE_M'],constants['BLOCK_SIZE_N'])
-        return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name),constants)
\ No newline at end of file
+        return super().compile(target_name,workdir,enable_tf32)
\ No newline at end of file
diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py
index 6bd2f6390..9dc40c60d 100644
--- a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py
+++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py
@@ -13,11 +13,6 @@
 
 _supported_activations = ['relu',None]
 
-_exec_metadata = {
-    'num_warps': 4,
-    'num_stages': 2,
-}
-
 class Gemm(Operation):
     def __init__(
         self,
@@ -36,7 +31,9 @@ def __init__(
         self.is_bias= is_bias
         self._attrs['activation'] = activation
         self._deduce_output_shape()
-
+        self._backend_module_name = 'gemm'
+        self._kernel_name = self._backend_module_name + ('' if not self.is_bias else '_bias')
+
 
     def _deduce_output_shape(self):
@@ -94,26 +91,11 @@ def _gen_constants(self,enable_tf32,num_stages, func_gen_smem_size):
         return const_metadata
 
     def _gen_exec_metadata(self):
-        return _exec_metadata.copy()
+        return {
+            'num_warps': 4,
+            'num_stages': 2,
+        }
 
     #TODO:enable_tf32 https://github.com/triton-lang/triton/issues/4574
     def compile(self,target_name,workdir,enable_tf32: bool = False,)->TritonExecutor:
-        triton_kernel_name=f'gemm'+ ('' if not self.is_bias else '_bias')
-        func_gen_smem_size=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.gemm'),f'gen_smem_size_gemm')
-        triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.gemm'),triton_kernel_name)
-        gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.gemm'),f'gen_grid_gemm')
-
-        signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs'])
-        exec_metadata=self._gen_exec_metadata()
-
-        num_warps=exec_metadata['num_warps']
-        num_stages=exec_metadata['num_stages']
-
-        constants=self._gen_constants(enable_tf32,num_stages,func_gen_smem_size)
-        config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1])
-        triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False)
-
-        exec_grid=gen_grid(constants['M'],constants['N'],constants['BLOCK_SIZE_M'],constants['BLOCK_SIZE_N'])
-        return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name),constants)
-
-
+        return super().compile(target_name,workdir,enable_tf32)
diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py
index 0f1f17eb8..3958ed1a0 100644
--- a/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py
+++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py
@@ -7,13 +7,7 @@
 from tritontemplate.compiler.base import IntImm, Tensor, Operation
 from tritontemplate.compiler.dtype import get_dtype_size
 from tritontemplate.compiler.kernel import TritonExecutor
-from tritontemplate.compiler.utils import get_warpsize, get_cuda_device_max_shared_memory
-from tritontemplate.backend.cuda.utils.utils import shape2stride
-
-_exec_metadata = {
-    'num_warps': 4,
-    'num_stages': 2,
-}
+from tritontemplate.compiler.utils import get_cuda_device_max_shared_memory
 
 class Layernorm(Operation):
     def __init__(
@@ -30,6 +24,8 @@ def __init__(
         self._attrs['eps'] = eps
 
         self._deduce_output_shape()
+        self._backend_module_name = 'layernorm'
+        self._kernel_name = self._backend_module_name + ('_weight_bias' if self._attrs['with_weight_bias'] else '')
 
     def _deduce_output_shape(self):
         M = prod(self._attrs['inputs'][0].shape[:-1])
@@ -40,7 +36,7 @@
         if self._attrs['outputs'] is None:
             self._attrs['outputs'] = [Tensor(shape=self._attrs['inputs'][0].shape,dtype=self._attrs['inputs'][0].dtype)]
 
-    def _gen_constants(self,num_stages,func_gen_smem_size):
+    def _gen_constants(self,enable_tf32,num_stages,func_gen_smem_size):
         const_metadata={}
         const_metadata['M']= self._attrs['M']
         const_metadata['N']= self._attrs['N']
@@ -62,26 +58,11 @@ def _gen_constants(self,enable_tf32,num_stages,func_gen_smem_size):
         return const_metadata
 
     def _gen_exec_metadata(self):
-        return _exec_metadata.copy()
+        return {
+            'num_warps': 4,
+            'num_stages': 2,
+        }
 
     def compile(self, target_name, workdir,enable_tf32)->TritonExecutor:
-        triton_kernel_name= 'layernorm_weight_bias' if self._attrs['with_weight_bias'] else 'layernorm'
-        triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.layernorm'),triton_kernel_name)
-        gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.layernorm'),f'gen_grid_layernorm')
-        func_gen_smem_size=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.layernorm'),f'gen_smem_size_layernorm')
-
-        signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs'])
-
-        exec_metadata=self._gen_exec_metadata()
-
-        num_warps=exec_metadata['num_warps']
-        num_stages=exec_metadata['num_stages']
-
-        constants=self._gen_constants(num_stages,func_gen_smem_size)
-        config = config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1])
-        triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False)
-
-        exec_grid=gen_grid(constants['M'],constants['BLOCK_SIZE_M'])
-        return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name),constants)
-
+        return super().compile(target_name,workdir,enable_tf32)
\ No newline at end of file
diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py
index 530a6d35c..9365b21c1 100644
--- a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py
+++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py
@@ -7,13 +7,8 @@
 from tritontemplate.compiler.base import IntImm, Tensor, Operation
 from tritontemplate.compiler.dtype import get_dtype_size
 from tritontemplate.compiler.kernel import TritonExecutor
-from tritontemplate.compiler.utils import get_warpsize,get_cuda_device_max_shared_memory
-from tritontemplate.backend.cuda.utils.utils import shape2stride
+from tritontemplate.compiler.utils import get_cuda_device_max_shared_memory
 
-_exec_metadata = {
-    'num_warps': 4,
-    'num_stages': 2,
-}
 
 class Softmax(Operation):
     def __init__(self, inputs: List[Tensor], dim: int,enable_online:bool=True, outputs: Optional[List[Tensor]] = None, name: Optional[str] = None):
@@ -22,6 +17,8 @@ def __init__(self, inputs: List[Tensor], dim: int,enable_online:bool=True, outpu
         self._attrs['dim'] = dim
         self._attrs['enable_online'] = enable_online
         self._deduce_output_shape()
+        self._backend_module_name = 'softmax'
+        self._kernel_name = 'online_softmax' if self._attrs['enable_online'] else 'softmax'
 
     def _deduce_output_shape(self):
         M = prod(self._attrs['inputs'][0].shape[:-1])
@@ -35,7 +32,7 @@ def _deduce_output_shape(self):
             self._attrs['outputs'] = [Tensor(shape=self._attrs['inputs'][0].shape,dtype='float32')]
             # self._attrs['outputs'] = [Tensor(shape=self._attrs['inputs'][0].shape,dtype=self._attrs['inputs'][0].dtype)]
 
-    def _gen_constants(self,num_stages,func_gen_smem_size):
+    def _gen_constants(self,enable_tf32, num_stages,func_gen_smem_size):
         const_metadata={}
         const_metadata['M']= self._attrs['M']
         const_metadata['N']= self._attrs['N']
@@ -51,23 +48,10 @@ def _gen_constants(self,enable_tf32, num_stages,func_gen_smem_size):
         return const_metadata
 
     def _gen_exec_metadata(self):
-        return _exec_metadata.copy()
+        return {
+            'num_warps': 4,
+            'num_stages': 2,
+        }
 
     def compile(self, target_name, workdir, enable_tf32)->TritonExecutor:
-        triton_kernel_name= 'online_softmax' if self._attrs['enable_online'] else 'softmax'
-        triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.softmax'),triton_kernel_name)
-        gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.softmax'),f'gen_grid_softmax')
-        func_gen_smem_size=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.softmax'),f'gen_smem_size_softmax')
-        signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs'])
-
-        exec_metadata=self._gen_exec_metadata()
-
-        num_warps=exec_metadata['num_warps']
-        num_stages=exec_metadata['num_stages']
-
-        constants=self._gen_constants(num_stages,func_gen_smem_size)
-        config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1])
-        triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False)
-
-        exec_grid=gen_grid(constants['M'],constants['BLOCK_SIZE_M'])
-        return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name),constants)
\ No newline at end of file
+        return super().compile(target_name,workdir,enable_tf32)
\ No newline at end of file
diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py
index a3ec52360..c72228cd9 100644
--- a/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py
+++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py
@@ -11,10 +11,6 @@
 
 _supported_permutations = ['10','0213']
 
-_exec_metadata = {
-    'num_warps': 4,
-    'num_stages': 2,
-}
 
 class Transpose(Operation):
     def __init__(self,
@@ -25,6 +21,8 @@ def __init__(self,
         super().__init__(inputs, outputs, name)
         assert permutation in _supported_permutations, f"Unsupported permutation {permutation}"
         self._attrs['permutation'] = permutation
+        self._backend_module_name = 'transpose'
+        self._kernel_name = self._backend_module_name + '_' + self._attrs['permutation']
 
         self._deduce_output_shape()
 
@@ -38,75 +36,43 @@ def _deduce_output_shape(self):
         else:
             assert self._attrs['outputs'][0].shape == output_shape, f"Transpose op output shape {self._attrs['outputs'][0].shape} does not match expected shape {output_shape}"
 
-    def _gen_constants_10(self,num_stages,func_gen_smem_size):
+    def _gen_constants(self, enable_tf32, num_stages, func_gen_smem_size):
         const_metadata={}
-        M,N=self._attrs['inputs'][0].shape
-
-        const_metadata['M'] = M
-        const_metadata['N'] = N
-        const_metadata['stride_x0'] = N
-        const_metadata['stride_x1'] = 1
-        const_metadata['stride_y0'] = M
-        const_metadata['stride_y1'] = 1
+        if self._attrs['permutation'] == '10':
+            M,N=self._attrs['inputs'][0].shape
 
-        const_metadata['BLOCK_SIZE_M'] = self._block_size(M)
-        const_metadata['BLOCK_SIZE_N'] = self._block_size(N)
-        self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages,get_dtype_size(self._attrs['inputs'][0].dtype))
+            const_metadata['M'] = M
+            const_metadata['N'] = N
+            const_metadata['stride_x0'] = N
+            const_metadata['stride_x1'] = 1
+            const_metadata['stride_y0'] = M
+            const_metadata['stride_y1'] = 1
 
-        return const_metadata
-
-    def _gen_grid_10(self,target_name,const_metadata):
-        gen_grid = getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.transpose'),'gen_grid_transpose_10')
-        return gen_grid(const_metadata['M'],const_metadata['N'],const_metadata['BLOCK_SIZE_M'],const_metadata['BLOCK_SIZE_N'])
-
-    def _gen_constants_0213(self,num_stages,func_gen_smem_size):
-        const_metadata={}
-        D0,D1,D2,D3=self._attrs['inputs'][0].shape
-        const_metadata['D0'] = D0
-        const_metadata['D1'] = D1
-        const_metadata['D2'] = D2
-        const_metadata['D3'] = D3
-
-        const_metadata['stride_x0'],const_metadata['stride_x1'],const_metadata['stride_x2'],const_metadata['stride_x3'] = shape2stride(self._attrs['inputs'][0].shape)
-        const_metadata['stride_y0'],const_metadata['stride_y1'],const_metadata['stride_y2'],const_metadata['stride_y3']= shape2stride(self._attrs['outputs'][0].shape)
-
-        const_metadata['BLOCK_SIZE_D1'] = self._block_size(D1)
-        const_metadata['BLOCK_SIZE_D2'] = self._block_size(D2)
+            const_metadata['BLOCK_SIZE_M'] = self._block_size(M)
+            const_metadata['BLOCK_SIZE_N'] = self._block_size(N)
+        elif self._attrs['permutation'] == '0213':
+            D0,D1,D2,D3=self._attrs['inputs'][0].shape
+            const_metadata['D0'] = D0
+            const_metadata['D1'] = D1
+            const_metadata['D2'] = D2
+            const_metadata['D3'] = D3
+
+            const_metadata['stride_x0'],const_metadata['stride_x1'],const_metadata['stride_x2'],const_metadata['stride_x3'] = shape2stride(self._attrs['inputs'][0].shape)
+            const_metadata['stride_y0'],const_metadata['stride_y1'],const_metadata['stride_y2'],const_metadata['stride_y3']= shape2stride(self._attrs['outputs'][0].shape)
+
+            const_metadata['BLOCK_SIZE_D1'] = self._block_size(D1)
+            const_metadata['BLOCK_SIZE_D2'] = self._block_size(D2)
 
         self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages,get_dtype_size(self._attrs['inputs'][0].dtype))
 
         return const_metadata
 
-    def _gen_grid_0213(self,target_name,const_metadata):
-        gen_grid = getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.transpose'),'gen_grid_transpose_0213')
-        return gen_grid(const_metadata['D0'],const_metadata['D1'],const_metadata['D2'],const_metadata['D3'],const_metadata['BLOCK_SIZE_D1'],const_metadata['BLOCK_SIZE_D2'])
-
     def _gen_exec_metadata(self):
-        return _exec_metadata.copy()
+        return {
+            'num_warps': 4,
+            'num_stages': 2,
+        }
 
     def compile(self, target_name, workdir, enable_tf32)->TritonExecutor:
-        triton_kernel_name= 'transpose_' + self._attrs['permutation']
-        triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.transpose'),triton_kernel_name)
-        func_gen_smem_size=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.transpose'),f'gen_smem_size_transpose_{self._attrs["permutation"]}')
-        signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs'])
-        exec_metadata=self._gen_exec_metadata()
-        num_warps=exec_metadata['num_warps']
-        num_stages=exec_metadata['num_stages']
-
-        if self._attrs['permutation'] == '10':
-            constants=self._gen_constants_10(num_stages,func_gen_smem_size)
-            exec_grid = self._gen_grid_10(target_name,constants)
-        elif self._attrs['permutation'] == '0213':
-            constants=self._gen_constants_0213(num_stages,func_gen_smem_size)
-            exec_grid = self._gen_grid_0213(target_name,constants)
-        else:
-            raise ValueError(f"Unsupported permutation {self._attrs['permutation']}")
-
-        config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1])
-
-        triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False)
-
-
-        return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name),constants)
\ No newline at end of file
+        return super().compile(target_name, workdir, enable_tf32)
\ No newline at end of file

From f9366158b4b629ab2f6e35c973e5adde5fefcbcb Mon Sep 17 00:00:00 2001
From: liushanghao
Date: Thu, 31 Jul 2025 17:23:15 +0800
Subject: [PATCH 8/9] [feat] triton 3.3.0 supported

---
 external/TritonTemplate/python/pyproject.toml  |  4 +--
 .../python/tritontemplate/__init__.py          |  1 -
 .../python/tritontemplate/_libinfo.py          |  1 -
 .../python/tritontemplate/compiler/base.py     | 26 ++++++++++++-------
 .../python/tritontemplate/compiler/kernel.py   |  8 +++---
 5 files changed, 23 insertions(+), 17 deletions(-)
 delete mode 100644 external/TritonTemplate/python/tritontemplate/_libinfo.py

diff --git a/external/TritonTemplate/python/pyproject.toml b/external/TritonTemplate/python/pyproject.toml
index d72a6fd76..eec475967 100644
--- a/external/TritonTemplate/python/pyproject.toml
+++ b/external/TritonTemplate/python/pyproject.toml
@@ -4,11 +4,11 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "tritontemplate"
-version = "0.0.1" # You'll need to get this from _libinfo.py differently
+version = "0.0.2"
 description = "TritonTemplate: Make Flex Triton Templates for AI"
 requires-python = ">=3.7,<4"
 dependencies = [
-    "triton==2.1.0",
+    "triton>=3.3.0",
     "torch>=2.1.0",
     "numpy<2",
    "pytest",
diff --git a/external/TritonTemplate/python/tritontemplate/__init__.py b/external/TritonTemplate/python/tritontemplate/__init__.py
index a1bf64a27..057741191 100644
--- a/external/TritonTemplate/python/tritontemplate/__init__.py
+++ b/external/TritonTemplate/python/tritontemplate/__init__.py
@@ -1,5 +1,4 @@
 import sys
 from tritontemplate import backend,compiler,testing,utils
-from tritontemplate._libinfo import __version__
 
 __all__ = ["backend", "compiler", "testing", "utils"]
diff --git a/external/TritonTemplate/python/tritontemplate/_libinfo.py b/external/TritonTemplate/python/tritontemplate/_libinfo.py
deleted file mode 100644
index 45c421e85..000000000
--- a/external/TritonTemplate/python/tritontemplate/_libinfo.py
+++ /dev/null
@@ -1 +0,0 @@
-__version__ = "dev0"
\ No newline at end of file
diff --git a/external/TritonTemplate/python/tritontemplate/compiler/base.py b/external/TritonTemplate/python/tritontemplate/compiler/base.py
index 7348fe808..3da291150 100644
--- a/external/TritonTemplate/python/tritontemplate/compiler/base.py
+++ b/external/TritonTemplate/python/tritontemplate/compiler/base.py
@@ -132,19 +132,27 @@ def compile(self, target_name, workdir, enable_tf32: bool = False) -> TritonExec
         constants = self._gen_constants(enable_tf32, num_stages, func_gen_smem_size)
 
         import triton
-        config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1])
+        attrs = {(v,): [["tt.divisibility", 16]] for v in divisiability[16]}
+        signature = {triton_kernel.arg_names[i]: s for i, s in signature.items()}
+        constexprs ={}
+        constexprs_id={}
+        constexprs_entry = []
+        for key in constants:
+            signature[key] = 'constexpr'
+            constexprs[(triton_kernel.arg_names.index(key),)] = constants[key]
+            constexprs_id[constants[key]] = triton_kernel.arg_names.index(key)
+            constexprs_entry.append(constants[key])
+
+        constexprs_entry.sort(key=lambda x: constexprs_id[x])
+
+        src = triton.compiler.ASTSource(fn=triton_kernel, constexprs=constexprs, signature=signature, attrs=attrs)
         triton_compiled_kernel = triton.compile(
-            fn=triton_kernel,
-            signature=signature,
-            constants=constants,
-            num_warps=num_warps,
-            num_stages=num_stages,
-            configs=[config],
-            debug=False
+            src=src,
+            options=exec_metadata,
         )
 
         exec_grid = self._gen_exec_grid(gen_grid, constants)
-        return TritonExecutor(triton_compiled_kernel, exec_grid, get_warpsize(target_name), constants)
+        return TritonExecutor(triton_compiled_kernel, exec_grid, get_warpsize(target_name), constexprs_entry)
diff --git a/external/TritonTemplate/python/tritontemplate/compiler/kernel.py b/external/TritonTemplate/python/tritontemplate/compiler/kernel.py
index 7d1b5d2ca..bc198c2c5 100644
--- a/external/TritonTemplate/python/tritontemplate/compiler/kernel.py
+++ b/external/TritonTemplate/python/tritontemplate/compiler/kernel.py
@@ -8,10 +8,10 @@ def __init__(self,triton_kernel:triton.compiler.compiler.CompiledKernel,grid_siz
         self.call_constants = constants
         self.triton_kernel = triton_kernel
         self.gridsize = grid_size
-        self.blocksize = triton_kernel.num_warps * warp_size
+        self.blocksize = triton_kernel.metadata.num_warps * warp_size
         self.warpsize = warp_size
-        self.name = triton_kernel.metadata['name']
-        self.smemsize = triton_kernel.shared
+        self.name = triton_kernel.metadata.name
+        self.smemsize = triton_kernel.metadata.shared
         self.device_name = get_cuda_device_name()
 
         assert self.smemsize <= get_cuda_device_max_shared_memory(), \
@@ -19,7 +19,7 @@ def __init__(self,triton_kernel:triton.compiler.compiler.CompiledKernel,grid_siz
 
 
     def __call__(self, *args, **kwds):
-        return self.triton_kernel[self.gridsize](*args, **kwds)
+        return self.triton_kernel[self.gridsize](*args,*self.call_constants, **kwds)
 
     def kernel_ptx(self,func_name:str):
         ptx = self.triton_kernel.asm['ptx']

From 670d47a5a9f791522a17a580a3bd03834d2875f4 Mon Sep 17 00:00:00 2001
From: liushanghao
Date: Thu, 31 Jul 2025 17:28:19 +0800
Subject: [PATCH 9/9] [bug] code clean

---
 tests/build_and_test_e2e.sh | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/build_and_test_e2e.sh b/tests/build_and_test_e2e.sh
index 43a3f1e86..ab98d1022 100755
--- a/tests/build_and_test_e2e.sh
+++ b/tests/build_and_test_e2e.sh
@@ -20,7 +20,6 @@ pip3 install $ROOT_PROJ_DIR/external/AITemplate/python/dist/*.whl
 pip3 install $ROOT_PROJ_DIR/compiler/build/python/dist/*.whl
 pip3 install $ROOT_PROJ_DIR/runtime/python/dist/*.whl
 pip3 install $ROOT_PROJ_DIR/frontends/torch-frontend/build/torch-frontend/python/dist/*.whl
-
 source scripts/prepare.sh
 install_mhlo_tools