diff --git a/.github/workflows/compiler-ci.yaml b/.github/workflows/compiler-ci.yaml
index 50bf4eea6..d51c15f20 100644
--- a/.github/workflows/compiler-ci.yaml
+++ b/.github/workflows/compiler-ci.yaml
@@ -43,6 +43,9 @@ jobs:
       - name: Run build and test
         run: ./scripts/compiler/build_and_test.sh
         shell: bash
+      - name: Run build and test TritonTemplate
+        run: ./external/TritonTemplate/scripts/build_and_test.sh
+        shell: bash
       - name: Run E2E testcases and check different
         run: ./compiler/scripts/gen_testcases_and_check_diff.sh
         shell: bash
diff --git a/external/TritonTemplate/python/pyproject.toml b/external/TritonTemplate/python/pyproject.toml
new file mode 100644
index 000000000..eec475967
--- /dev/null
+++ b/external/TritonTemplate/python/pyproject.toml
@@ -0,0 +1,33 @@
+[build-system]
+requires = ["setuptools>=42"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "tritontemplate"
+version = "0.0.2"
+description = "TritonTemplate: Make Flex Triton Templates for AI"
+requires-python = ">=3.7,<4"
+dependencies = [
+    "triton>=3.3.0",
+    "torch>=2.1.0",
+    "numpy<2",
+    "pytest",
+]
+
+[project.optional-dependencies]
+# Add any optional dependencies here if needed
+
+[tool.setuptools]
+zip-safe = true
+packages = ["tritontemplate"]
+
+# Note: For package_data, you'll need to either:
+# 1. List all files explicitly here, or
+# 2. Keep a minimal setup.py just for the file listing logic
+# Here's an example of explicit package data:
+[tool.setuptools.package-data]
+tritontemplate = [
+    "backend/*.py",
+    "utils/*.py",
+    "compiler/*.py"
+]
diff --git a/external/TritonTemplate/python/setup.py b/external/TritonTemplate/python/setup.py
deleted file mode 100644
index 8b369e707..000000000
--- a/external/TritonTemplate/python/setup.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import os
-import shutil
-
-from setuptools import find_packages, setup
-
-CURRENT_DIR = os.path.dirname(__file__)
-libinfo_py = os.path.join(CURRENT_DIR, "tritontemplate", "_libinfo.py")
-libinfo = {}
-with open(libinfo_py, "r") as f:
-    exec(f.read(), libinfo)
-__version__ = libinfo["__version__"]
-
-def gen_file_list(srcs, f_cond):
-    file_list = []
-    for src in srcs:
-        for root, _, files in os.walk(src):
-            value = []
-            for file in files:
-                if f_cond(file):
-                    path = os.path.join(root, file)
-                    value.append(path.replace("tritontemplate/", ""))
-            file_list.extend(value)
-    return file_list
-
-def gen_backend_common_file_list():
-    srcs = ["tritontemplate/backend"]
-    f_cond = lambda x: True if x.endswith(".py") else False
-    return gen_file_list(srcs, f_cond)
-
-def gen_utils_file_list():
-    srcs = ["tritontemplate/utils"]
-    f_cond = lambda x: True if x.endswith(".py") else False
-    return gen_file_list(srcs, f_cond)
-
-def gen_compiler_file_list():
-    srcs = ["tritontemplate/compiler"]
-    f_cond = lambda x: True if x.endswith(".py") else False
-    return gen_file_list(srcs, f_cond)
-
-setup_kwargs = {}
-include_libs = True
-wheel_include_libs = True
-
-setup(
-    name="tritontemplate",
-    version=__version__,
-    description="TritonTemplate: Make Flex Triton Templates for AI",
-    zip_safe=True,
-    install_requires=["torch>=2.1.0","triton"],
-    packages=find_packages(),
-    package_data={
-        "tritontemplate": []
-        + gen_utils_file_list()
-        + gen_backend_common_file_list()
-        + gen_compiler_file_list()
-    },
-    python_requires=">=3.7, <4",
-    **setup_kwargs
-)
\ No newline at end of file
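A note on the packaging switch: the deleted setup.py collected package files recursively with os.walk, while the new `[tool.setuptools.package-data]` globs match a single directory level, so nested modules such as backend/cuda/bmm/bmm.py may need extra patterns (e.g. "backend/cuda/*/*.py"). A minimal sketch of a check, run from external/TritonTemplate/python (names assumed, not part of the diff):

```python
import glob
import os

STATIC_GLOBS = ["backend/*.py", "utils/*.py", "compiler/*.py"]  # from pyproject.toml

def walk_py_files(subdir):
    # Recursive collection, mirroring the deleted gen_file_list().
    found = []
    for root, _, files in os.walk(os.path.join("tritontemplate", subdir)):
        found += [os.path.join(root, f) for f in files if f.endswith(".py")]
    return found

recursive = {p for d in ("backend", "utils", "compiler") for p in walk_py_files(d)}
static = {p for g in STATIC_GLOBS for p in glob.glob(os.path.join("tritontemplate", g))}
print(sorted(recursive - static))  # non-empty output = files the new globs would miss
```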
diff --git a/external/TritonTemplate/python/tritontemplate/__init__.py b/external/TritonTemplate/python/tritontemplate/__init__.py
index a1bf64a27..057741191 100644
--- a/external/TritonTemplate/python/tritontemplate/__init__.py
+++ b/external/TritonTemplate/python/tritontemplate/__init__.py
@@ -1,5 +1,4 @@
 import sys
 from tritontemplate import backend,compiler,testing,utils
-from tritontemplate._libinfo import __version__
 
 __all__ = ["backend", "compiler", "testing", "utils"]
diff --git a/external/TritonTemplate/python/tritontemplate/_libinfo.py b/external/TritonTemplate/python/tritontemplate/_libinfo.py
deleted file mode 100644
index 45c421e85..000000000
--- a/external/TritonTemplate/python/tritontemplate/_libinfo.py
+++ /dev/null
@@ -1 +0,0 @@
-__version__ = "dev0"
\ No newline at end of file
diff --git a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py
index 5f8c6587a..26bea6627 100644
--- a/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py
+++ b/external/TritonTemplate/python/tritontemplate/backend/cuda/bmm/__init__.py
@@ -1 +1,3 @@
-from tritontemplate.backend.cuda.bmm.bmm import bmm,bmm_bias, gen_grid_bmm
\ No newline at end of file
+from tritontemplate.backend.cuda.bmm.bmm import bmm,bmm_bias, gen_grid_bmm, gen_smem_size_bmm
+gen_grid_bmm_bias = gen_grid_bmm
+gen_smem_size_bmm_bias = gen_smem_size_bmm
""" - return (batch_size,triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), 1) + return (BATCH_SIZE,triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), 1) + +def gen_smem_size_bmm(BLOCK_SIZE_M:int, BLOCK_SIZE_K:int, BLOCK_SIZE_N:int,num_stages:int,size_dtype:int): + return max((BLOCK_SIZE_N*BLOCK_SIZE_K+BLOCK_SIZE_K*BLOCK_SIZE_M)*num_stages*size_dtype,(BLOCK_SIZE_M*BLOCK_SIZE_N)*4) @triton.jit def bmm_bias( @@ -40,7 +43,7 @@ def bmm_bias( if is_transpose_a: # A(B,K,M) - a_ptrs = a_ptr+stride_a0*batch_id+stride_a1*offs_k[:,None]+stride_a2*offs_am[None,:] + a_ptrs = a_ptr+stride_a0*batch_id+stride_a1*offs_k[None,:]+stride_a2*offs_am[:,None] stride_ak=stride_a1 else: # A(B,M,K) @@ -49,7 +52,7 @@ def bmm_bias( if is_transpose_b: # B(B,N,K) - b_ptrs = b_ptr+stride_b0*batch_id+stride_b1*offs_bn[:,None]+stride_b2*offs_k[None,:] + b_ptrs = b_ptr+stride_b0*batch_id+stride_b1*offs_bn[None,:]+stride_b2*offs_k[:,None] stride_bk=stride_b2 else: # B(B,K,N) @@ -59,20 +62,16 @@ def bmm_bias( accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): if is_transpose_a: - a_mask= (offs_k[:,None]+BLOCK_SIZE_K*k List[Tensor]: @property def outputs(self) -> Optional[List[Tensor]]: return self._attrs['outputs'] + + + def _gen_exec_grid(self,gen_grid,constants): + sig = dict(inspect.signature(gen_grid).parameters) + sig = {k:constants[k] for k in sig.keys()} + return gen_grid(**sig) - @abstractmethod - def compile(self,target_name,workdir): - raise NotImplementedError + def compile(self, target_name, workdir, enable_tf32: bool = False) -> TritonExecutor: + + kernel_name = self._kernel_name + backend_module = self._backend_module_name + triton_kernel = getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.{backend_module}'), kernel_name) + gen_grid = getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.{backend_module}'), f'gen_grid_{kernel_name}') + func_gen_smem_size = getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.{backend_module}'), f'gen_smem_size_{kernel_name}') + + exec_metadata = self._gen_exec_metadata() + num_warps = exec_metadata['num_warps'] + num_stages = exec_metadata['num_stages'] + + signature, divisiability = self._gen_tensor_signature_divisiability(['inputs', 'outputs']) + constants = self._gen_constants(enable_tf32, num_stages, func_gen_smem_size) + + import triton + attrs = {(v,): [["tt.divisibility", 16]] for v in divisiability[16]} + signature = {triton_kernel.arg_names[i]: s for i, s in signature.items()} + constexprs ={} + constexprs_id={} + constexprs_entry = [] + for key in constants: + signature[key] = 'constexpr' + constexprs[(triton_kernel.arg_names.index(key),)] = constants[key] + constexprs_id[constants[key]] = triton_kernel.arg_names.index(key) + constexprs_entry.append(constants[key]) + + constexprs_entry.sort(key=lambda x: constexprs_id[x]) + + src = triton.compiler.ASTSource(fn=triton_kernel, constexprs=constexprs, signature=signature, attrs=attrs) + triton_compiled_kernel = triton.compile( + src=src, + options=exec_metadata, + ) + + exec_grid = self._gen_exec_grid(gen_grid, constants) + return TritonExecutor(triton_compiled_kernel, exec_grid, get_warpsize(target_name), constexprs_entry) + def _gen_tensor_signature_divisiability(self,tensors_names:List[str]): @@ -141,3 +187,27 @@ def _block_size(x): return 64 else: return 128 + + @staticmethod + def _shrink_shared_mem(func_gen_smem_size:Callable,const_metadata:Dict, 
diff --git a/external/TritonTemplate/python/tritontemplate/compiler/base.py b/external/TritonTemplate/python/tritontemplate/compiler/base.py
--- a/external/TritonTemplate/python/tritontemplate/compiler/base.py
+++ b/external/TritonTemplate/python/tritontemplate/compiler/base.py
@@ ... @@ def inputs(self) -> List[Tensor]:
     @property
     def outputs(self) -> Optional[List[Tensor]]:
         return self._attrs['outputs']
+
+
+    def _gen_exec_grid(self,gen_grid,constants):
+        sig = dict(inspect.signature(gen_grid).parameters)
+        sig = {k:constants[k] for k in sig.keys()}
+        return gen_grid(**sig)
 
-    @abstractmethod
-    def compile(self,target_name,workdir):
-        raise NotImplementedError
+    def compile(self, target_name, workdir, enable_tf32: bool = False) -> TritonExecutor:
+
+        kernel_name = self._kernel_name
+        backend_module = self._backend_module_name
+        triton_kernel = getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.{backend_module}'), kernel_name)
+        gen_grid = getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.{backend_module}'), f'gen_grid_{kernel_name}')
+        func_gen_smem_size = getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.{backend_module}'), f'gen_smem_size_{kernel_name}')
+
+        exec_metadata = self._gen_exec_metadata()
+        num_warps = exec_metadata['num_warps']
+        num_stages = exec_metadata['num_stages']
+
+        signature, divisiability = self._gen_tensor_signature_divisiability(['inputs', 'outputs'])
+        constants = self._gen_constants(enable_tf32, num_stages, func_gen_smem_size)
+
+        import triton
+        attrs = {(v,): [["tt.divisibility", 16]] for v in divisiability[16]}
+        signature = {triton_kernel.arg_names[i]: s for i, s in signature.items()}
+        constexprs ={}
+        constexprs_id={}
+        constexprs_entry = []
+        for key in constants:
+            signature[key] = 'constexpr'
+            constexprs[(triton_kernel.arg_names.index(key),)] = constants[key]
+            constexprs_id[constants[key]] = triton_kernel.arg_names.index(key)
+            constexprs_entry.append(constants[key])
+
+        constexprs_entry.sort(key=lambda x: constexprs_id[x])
+
+        src = triton.compiler.ASTSource(fn=triton_kernel, constexprs=constexprs, signature=signature, attrs=attrs)
+        triton_compiled_kernel = triton.compile(
+            src=src,
+            options=exec_metadata,
+        )
+
+        exec_grid = self._gen_exec_grid(gen_grid, constants)
+        return TritonExecutor(triton_compiled_kernel, exec_grid, get_warpsize(target_name), constexprs_entry)
+
 
     def _gen_tensor_signature_divisiability(self,tensors_names:List[str]):
@@ -141,3 +187,27 @@ def _block_size(x):
             return 64
         else:
             return 128
+
+    @staticmethod
+    def _shrink_shared_mem(func_gen_smem_size:Callable,const_metadata:Dict,
+                           dev_smem_size:int,num_stages:int,size_dtype:int):
+
+        sig = dict(inspect.signature(func_gen_smem_size).parameters)
+        keys = [key for key in sig.keys() if key != "num_stages" and key != "size_dtype"]
+        sig.update({"num_stages": num_stages, "size_dtype": size_dtype})
+        for key in keys:
+            sig[key] = const_metadata[key]
+
+        it = 0
+        len_keys = len(keys)
+        tolerance = len_keys
+        while tolerance and dev_smem_size < func_gen_smem_size(**sig):
+            if sig[keys[it]] > 32:
+                sig[keys[it]] //= 2
+            else:
+                tolerance -= 1
+            it = (it+1)%len_keys
+        for key in keys:
+            val = sig[key]
+            if val < 32:
+                raise ValueError(f'Shrinking resulted in block size < 32. The exec_params = {sig}')
+            const_metadata[key] = val
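_shrink_shared_mem introspects the gen_smem_size_* signature, then round-robins over the block-size parameters, halving one per pass until the estimate fits the device budget; tolerance ends the loop once every dimension has been driven to the 32 floor. A standalone sketch of the same strategy (function and names hypothetical, not the library API):

```python
def shrink(params, smem_needed, dev_smem_size, floor=32):
    """Halve block sizes round-robin until smem_needed(**params) fits."""
    keys = list(params)
    it, tolerance = 0, len(keys)
    while tolerance and dev_smem_size < smem_needed(**params):
        if params[keys[it]] > floor:
            params[keys[it]] //= 2   # halve one dimension per pass
        else:
            tolerance -= 1           # this dimension is exhausted
        it = (it + 1) % len(keys)
    if any(v < floor for v in params.values()):
        raise ValueError(f"block sizes shrank below {floor}: {params}")
    return params

smem = lambda BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: \
    (BLOCK_SIZE_N * BLOCK_SIZE_K + BLOCK_SIZE_K * BLOCK_SIZE_M) * 2 * 2  # 2 stages, fp16
print(shrink({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128}, smem, 99 * 1024))
# -> {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128}
```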
diff --git a/external/TritonTemplate/python/tritontemplate/compiler/kernel.py b/external/TritonTemplate/python/tritontemplate/compiler/kernel.py
index 012c54c19..bc198c2c5 100644
--- a/external/TritonTemplate/python/tritontemplate/compiler/kernel.py
+++ b/external/TritonTemplate/python/tritontemplate/compiler/kernel.py
@@ -1,31 +1,25 @@
 from typing import Sequence
 
 import triton
 
-from tritontemplate.compiler.utils import get_device_max_shared_memory,get_cuda_device_name
+from tritontemplate.compiler.utils import get_cuda_device_max_shared_memory,get_cuda_device_name
 
 class TritonExecutor:
     def __init__(self,triton_kernel:triton.compiler.compiler.CompiledKernel,grid_size:Sequence[int],warp_size:int=32,constants:dict=None):
         self.call_constants = constants
         self.triton_kernel = triton_kernel
         self.gridsize = grid_size
-        self.blocksize = triton_kernel.num_warps * warp_size
+        self.blocksize = triton_kernel.metadata.num_warps * warp_size
        self.warpsize = warp_size
-        self.name = triton_kernel.metadata['name']
-        self.smemsize = triton_kernel.shared
+        self.name = triton_kernel.metadata.name
+        self.smemsize = triton_kernel.metadata.shared
         self.device_name = get_cuda_device_name()
-        try:
-            self.device_name = get_cuda_device_name()
-            assert self.smemsize <= get_device_max_shared_memory(self.device_name), \
-                f'kernel {self.name} smem size {self.smemsize} exceeds device {self.device_name} max smem size {get_device_max_shared_memory(self.device_name)}'
-        except KeyError as e:
-            # Log the error and continue with default values
-            import logging
-            logging.warning(f"Unsupported device detected: {str(e)}. Continuing with default configuration.")
-            self.device_name = "unknown"
+        assert self.smemsize <= get_cuda_device_max_shared_memory(), \
+            f'kernel {self.name} smem size {self.smemsize} exceeds device {self.device_name} max smem size {get_cuda_device_max_shared_memory()}'
+
 
     def __call__(self, *args, **kwds):
-        return self.triton_kernel[self.gridsize](*args, **kwds)
+        return self.triton_kernel[self.gridsize](*args,*self.call_constants, **kwds)
 
     def kernel_ptx(self,func_name:str):
         ptx = self.triton_kernel.asm['ptx']
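Hypothetical end-to-end use of the refactored path, modeled on the constructors visible in test_layernorm.py (import paths and sizes assumed): compile() returns a TritonExecutor whose __call__ now appends the sorted constexpr values after the runtime tensors.

```python
import torch
from tritontemplate.compiler.base import Tensor
from tritontemplate.compiler.ops.layernorm.layernorm import Layernorm  # path assumed

X = Tensor([8, 128, 256], 'float16')
op = Layernorm(inputs=[X], outputs=None, axises=[2], eps=1e-5)
executor = op.compile('cuda', workdir='/tmp')  # base Operation.compile -> TritonExecutor

x = torch.randn(8, 128, 256, dtype=torch.float16, device='cuda')
y = torch.empty_like(x)
executor(x, y)  # __call__ forwards (x, y), then *self.call_constants
```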
diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py
index 3563d6e49..ba78431fc 100644
--- a/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py
+++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/bmm/bmm.py
@@ -1,19 +1,14 @@
 from typing import List,Optional
 import importlib
 
-import triton
-from tritontemplate.compiler.base import IntImm, Tensor, Operation
-from tritontemplate.compiler.dtype import dtype_str_to_triton_signature
+from tritontemplate.compiler.base import Tensor, Operation
+from tritontemplate.compiler.dtype import get_dtype_size
 from tritontemplate.compiler.kernel import TritonExecutor
-from tritontemplate.compiler.utils import get_warpsize
+from tritontemplate.compiler.utils import get_cuda_device_max_shared_memory
 from tritontemplate.backend.cuda.utils.utils import shape2stride
 
 _supported_layouts = ['rcr','rrr','crr','ccr']
-_exec_metadata = {
-    'num_warps': 4,
-    'num_stages': 1,
-}
 
 class Bmm(Operation):
     def __init__(
@@ -27,7 +22,10 @@ def __init__(
         super().__init__(inputs, outputs,name)
         self.layout = layout
         self.is_bias = is_bias
+        self._backend_module_name = 'bmm'
+        self._kernel_name = self._backend_module_name + ('' if not self.is_bias else '_bias')
         self._deduce_output_shape()
+
 
     def _deduce_output_shape(self):
         BATCH_SIZE = self._attrs['inputs'][0].shape[0]
@@ -50,7 +48,7 @@ def _deduce_output_shape(self):
         else:
             assert self._attrs['outputs'][0].shape == res_shape, f"output shape {self._attrs['outputs'][0].shape} not match {res_shape}"
 
-    def _gen_constants(self,enable_tf32):
+    def _gen_constants(self,enable_tf32,num_stages, func_gen_smem_size):
         const_metadata={}
         any_float32=False
         for input in self._attrs['inputs']:
@@ -60,6 +58,8 @@ def _gen_constants(self,enable_tf32):
         const_metadata['BLOCK_SIZE_M']= self._block_size(self._attrs['M'])
         const_metadata['BLOCK_SIZE_N']= self._block_size(self._attrs['N'])
         const_metadata['BLOCK_SIZE_K']= self._block_size(self._attrs['K'])
+
+        self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages,get_dtype_size(self._attrs['inputs'][0].dtype))
 
         const_metadata['enable_tf32'] = True if (enable_tf32 and any_float32) else False
         input=self._attrs['inputs']
@@ -82,22 +82,10 @@ def _gen_constants(self,enable_tf32):
         return const_metadata
 
     def _gen_exec_metadata(self):
-        return _exec_metadata.copy()
+        return {
+            'num_warps': 4,
+            'num_stages': 2,
+        }
 
     def compile(self, target_name, workdir,enable_tf32: bool = False,)->TritonExecutor:
-        triton_kernel_name=f'bmm'+ ('' if not self.is_bias else '_bias')
-        triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.bmm'),triton_kernel_name)
-        gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.bmm'),f'gen_grid_bmm')
-
-        signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs'])
-        constants=self._gen_constants(enable_tf32)
-        exec_metadata=self._gen_exec_metadata()
-
-        num_warps=exec_metadata['num_warps']
-        num_stages=exec_metadata['num_stages']
-        config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1])
-
-        triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False)
-
-        exec_grid=gen_grid(constants['BATCH_SIZE'],constants['M'],constants['N'],constants['BLOCK_SIZE_M'],constants['BLOCK_SIZE_N'])
-        return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name),constants)
\ No newline at end of file
+        return super().compile(target_name,workdir,enable_tf32)
\ No newline at end of file
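With compile() hoisted into Operation, an op now only declares its naming attributes and the two _gen_* hooks; the base class resolves tritontemplate.backend.<target>.<module> for the kernel, gen_grid_<kernel> and gen_smem_size_<kernel>, and drives triton.compile. A sketch of the contract a hypothetical new op would follow (not part of the diff):

```python
from tritontemplate.compiler.base import Operation

class Relu(Operation):  # hypothetical example op
    def __init__(self, inputs, outputs=None, name=None):
        super().__init__(inputs, outputs, name)
        self._backend_module_name = 'relu'  # backend/cuda/relu would have to exist,
        self._kernel_name = 'relu'          # along with gen_grid_relu / gen_smem_size_relu

    def _gen_constants(self, enable_tf32, num_stages, func_gen_smem_size):
        # Keys must match the kernel's constexpr argument names.
        return {'N': self._attrs['inputs'][0].shape[0], 'BLOCK_SIZE_N': 128}

    def _gen_exec_metadata(self):
        return {'num_warps': 4, 'num_stages': 2}
```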
diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py
index 6b88d9c48..9dc40c60d 100644
--- a/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py
+++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/gemm/gemm.py
@@ -4,20 +4,15 @@
 import triton
 
 from tritontemplate.compiler.base import IntImm, Tensor, Operation
-from tritontemplate.compiler.dtype import dtype_str_to_triton_signature
+from tritontemplate.compiler.dtype import get_dtype_size
 from tritontemplate.compiler.kernel import TritonExecutor
-from tritontemplate.compiler.utils import get_warpsize
-from tritontemplate.backend.cuda.utils.utils import shape2stride
+from tritontemplate.compiler.utils import get_warpsize,get_cuda_device_max_shared_memory
+from tritontemplate.backend.cuda.utils import shape2stride
 
 _supported_layouts = ['rcr','rrr','ccr','crr']
 _supported_activations = ['relu',None]
 
-_exec_metadata = {
-    'num_warps': 4,
-    'num_stages': 1,
-}
-
 class Gemm(Operation):
     def __init__(
         self,
@@ -36,7 +31,9 @@ def __init__(
         self.is_bias= is_bias
         self._attrs['activation'] = activation
         self._deduce_output_shape()
-
+        self._backend_module_name = 'gemm'
+        self._kernel_name = self._backend_module_name + ('' if not self.is_bias else '_bias')
+
 
     def _deduce_output_shape(self):
@@ -59,7 +56,7 @@ def _deduce_output_shape(self):
         else:
             assert self._attrs['outputs'][0].shape == res_shape, f"output shape {self._attrs['outputs'][0].shape} not match {res_shape}"
 
-    def _gen_constants(self,enable_tf32):
+    def _gen_constants(self,enable_tf32,num_stages, func_gen_smem_size):
         const_metadata={}
 
         const_metadata['ACTIVATION'] = self._attrs['activation']
@@ -74,7 +71,8 @@ def _gen_constants(self,enable_tf32):
         const_metadata['BLOCK_SIZE_M']= self._block_size(self._attrs['M'])
         const_metadata['BLOCK_SIZE_N']= self._block_size(self._attrs['N'])
         const_metadata['BLOCK_SIZE_K']= self._block_size(self._attrs['K'])
-
+        self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages,get_dtype_size(self._attrs['inputs'][0].dtype))
+
         input=self._attrs['inputs']
         output=self._attrs['outputs']
         const_metadata['M']=self._attrs['M']
@@ -93,24 +91,11 @@ def _gen_constants(self,enable_tf32):
         return const_metadata
 
     def _gen_exec_metadata(self):
-        return _exec_metadata.copy()
+        return {
+            'num_warps': 4,
+            'num_stages': 2,
+        }
 
     #TODO:enable_tf32 https://github.com/triton-lang/triton/issues/4574
     def compile(self,target_name,workdir,enable_tf32: bool = False,)->TritonExecutor:
-        triton_kernel_name=f'gemm'+ ('' if not self.is_bias else '_bias')
-        triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.gemm'),triton_kernel_name)
-        gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.gemm'),f'gen_grid_gemm')
-
-        signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs'])
-        constants=self._gen_constants(enable_tf32)
-        exec_metadata=self._gen_exec_metadata()
-
-        num_warps=exec_metadata['num_warps']
-        num_stages=exec_metadata['num_stages']
-        config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1])
-        triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False)
-
-        exec_grid=gen_grid(constants['M'],constants['N'],constants['BLOCK_SIZE_M'],constants['BLOCK_SIZE_N'])
-        return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name),constants)
-
-
+        return super().compile(target_name,workdir,enable_tf32)
diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py
index 4e2641884..3958ed1a0 100644
--- a/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py
+++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/layernorm/layernorm.py
@@ -5,15 +5,9 @@
 import triton
 
 from tritontemplate.compiler.base import IntImm, Tensor, Operation
-from tritontemplate.compiler.dtype import dtype_str_to_triton_signature
+from tritontemplate.compiler.dtype import get_dtype_size
 from tritontemplate.compiler.kernel import TritonExecutor
-from tritontemplate.compiler.utils import get_warpsize
-from tritontemplate.backend.cuda.utils.utils import shape2stride
-
-_exec_metadata = {
-    'num_warps': 4,
-    'num_stages': 1,
-}
+from tritontemplate.compiler.utils import get_cuda_device_max_shared_memory
 
 class Layernorm(Operation):
     def __init__(
@@ -25,11 +19,13 @@ def __init__(
         name: Optional[str] = None,
     ) -> None:
         super().__init__(inputs, outputs, name)
-        assert len(axises)==1 and axises[0] == len(inputs[0].shape)-1, f'Only last axis normalization is supported (axis={axis}, input shape={inputs[0].shape})'
+        assert len(axises)==1 and axises[0] == len(inputs[0].shape)-1, f'Only last axis normalization is supported (axis={axises}, input shape={inputs[0].shape})'
         self._attrs['axis'] = axises[0]
         self._attrs['eps'] = eps
         self._deduce_output_shape()
+        self._backend_module_name = 'layernorm'
+        self._kernel_name = self._backend_module_name + ('_weight_bias' if self._attrs['with_weight_bias'] else '')
 
     def _deduce_output_shape(self):
         M = prod(self._attrs['inputs'][0].shape[:-1])
@@ -40,7 +36,7 @@ def _deduce_output_shape(self):
         if self._attrs['outputs'] is None:
             self._attrs['outputs'] = [Tensor(shape=self._attrs['inputs'][0].shape,dtype=self._attrs['inputs'][0].dtype)]
 
-    def _gen_constants(self):
+    def _gen_constants(self,enable_tf32,num_stages,func_gen_smem_size):
         const_metadata={}
         const_metadata['M']= self._attrs['M']
         const_metadata['N']= self._attrs['N']
@@ -53,6 +49,8 @@ def _gen_constants(self):
         const_metadata['BLOCK_SIZE_M'] = self._block_size(self._attrs['M'])
         const_metadata['BLOCK_SIZE_N'] = self._block_size(self._attrs['N'])
 
+        self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages,get_dtype_size(self._attrs['inputs'][0].dtype))
+
         if self._attrs['with_weight_bias']:
             const_metadata['stride_weight'] = 1
             const_metadata['stride_bias'] = 1
@@ -60,22 +58,11 @@ def _gen_constants(self):
         return const_metadata
 
     def _gen_exec_metadata(self):
-        return _exec_metadata.copy()
+        return {
+            'num_warps': 4,
+            'num_stages': 2,
+        }
 
     def compile(self, target_name, workdir,enable_tf32)->TritonExecutor:
-        triton_kernel_name= 'layernorm_weight_bias' if self._attrs['with_weight_bias'] else 'layernorm'
-        triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.layernorm'),triton_kernel_name)
-        gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.layernorm'),f'gen_grid_layernorm')
-        signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs'])
-        constants=self._gen_constants()
-        exec_metadata=self._gen_exec_metadata()
-
-        num_warps=exec_metadata['num_warps']
-        num_stages=exec_metadata['num_stages']
-        config = config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1])
-        triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False)
-
-        exec_grid=gen_grid(constants['M'],constants['BLOCK_SIZE_M'])
-        return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name),constants)
-
+        return super().compile(target_name,workdir,enable_tf32)
\ No newline at end of file
diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py
index fb0373f82..9365b21c1 100644
--- a/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py
+++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/softmax/softmax.py
@@ -5,15 +5,10 @@
 import triton
 
 from tritontemplate.compiler.base import IntImm, Tensor, Operation
-from tritontemplate.compiler.dtype import dtype_str_to_triton_signature
+from tritontemplate.compiler.dtype import get_dtype_size
 from tritontemplate.compiler.kernel import TritonExecutor
-from tritontemplate.compiler.utils import get_warpsize
-from tritontemplate.backend.cuda.utils.utils import shape2stride
+from tritontemplate.compiler.utils import get_cuda_device_max_shared_memory
 
-_exec_metadata = {
-    'num_warps': 4,
-    'num_stages': 3,
-}
 
 class Softmax(Operation):
     def __init__(self, inputs: List[Tensor], dim: int,enable_online:bool=True, outputs: Optional[List[Tensor]] = None, name: Optional[str] = None):
@@ -22,6 +17,8 @@ def __init__(self, inputs: List[Tensor], dim: int,enable_online:bool=True, outpu
         self._attrs['dim'] = dim
         self._attrs['enable_online'] = enable_online
         self._deduce_output_shape()
+        self._backend_module_name = 'softmax'
+        self._kernel_name = 'online_softmax' if self._attrs['enable_online'] else 'softmax'
 
     def _deduce_output_shape(self):
         M = prod(self._attrs['inputs'][0].shape[:-1])
@@ -35,7 +32,7 @@ def _deduce_output_shape(self):
             self._attrs['outputs'] = [Tensor(shape=self._attrs['inputs'][0].shape,dtype='float32')]
             # self._attrs['outputs'] = [Tensor(shape=self._attrs['inputs'][0].shape,dtype=self._attrs['inputs'][0].dtype)]
 
-    def _gen_constants(self):
+    def _gen_constants(self,enable_tf32, num_stages,func_gen_smem_size):
         const_metadata={}
         const_metadata['M']= self._attrs['M']
         const_metadata['N']= self._attrs['N']
@@ -46,24 +43,15 @@ def _gen_constants(self):
         const_metadata['BLOCK_SIZE_M'] = self._block_size(self._attrs['M'])
         const_metadata['BLOCK_SIZE_N'] = self._block_size(self._attrs['N'])
 
+        self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages,get_dtype_size(self._attrs['inputs'][0].dtype))
         return const_metadata
 
     def _gen_exec_metadata(self):
-        return _exec_metadata.copy()
+        return {
+            'num_warps': 4,
+            'num_stages': 2,
+        }
 
     def compile(self, target_name, workdir, enable_tf32)->TritonExecutor:
-        triton_kernel_name= 'online_softmax' if self._attrs['enable_online'] else 'softmax'
-        triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.softmax'),triton_kernel_name)
-        gen_grid=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.softmax'),f'gen_grid_softmax')
-        signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs'])
-        constants=self._gen_constants()
-        exec_metadata=self._gen_exec_metadata()
-
-        num_warps=exec_metadata['num_warps']
-        num_stages=exec_metadata['num_stages']
-        config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1])
-        triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False)
-
-        exec_grid=gen_grid(constants['M'],constants['BLOCK_SIZE_M'])
-        return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name),constants)
\ No newline at end of file
+        return super().compile(target_name,workdir,enable_tf32)
\ No newline at end of file
diff --git a/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py b/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py
index 1a37495c5..c72228cd9 100644
--- a/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py
+++ b/external/TritonTemplate/python/tritontemplate/compiler/ops/transpose/transpose.py
@@ -4,17 +4,13 @@
 import triton
 
 from tritontemplate.compiler.base import IntImm, Tensor, Operation
-from tritontemplate.compiler.dtype import dtype_str_to_triton_signature
+from tritontemplate.compiler.dtype import get_dtype_size
 from tritontemplate.compiler.kernel import TritonExecutor
-from tritontemplate.compiler.utils import get_warpsize
+from tritontemplate.compiler.utils import get_warpsize,get_cuda_device_max_shared_memory
 from tritontemplate.backend.cuda.utils.utils import shape2stride
 
 _supported_permutations = ['10','0213']
-_exec_metadata = {
-    'num_warps': 4,
-    'num_stages': 1,
-}
 
 class Transpose(Operation):
     def __init__(self,
@@ -25,6 +21,8 @@ def __init__(self,
         super().__init__(inputs, outputs, name)
         assert permutation in _supported_permutations, f"Unsupported permutation {permutation}"
         self._attrs['permutation'] = permutation
+        self._backend_module_name = 'transpose'
+        self._kernel_name = self._backend_module_name + '_' + self._attrs['permutation']
         self._deduce_output_shape()
 
@@ -38,71 +36,43 @@ def _deduce_output_shape(self):
         else:
             assert self._attrs['outputs'][0].shape == output_shape, f"Transpose op output shape {self._attrs['outputs'][0].shape} does not match expected shape {output_shape}"
 
-    def _gen_constants_10(self):
+    def _gen_constants(self, enable_tf32, num_stages, func_gen_smem_size):
         const_metadata={}
-        M,N=self._attrs['inputs'][0].shape
+        if self._attrs['permutation'] == '10':
+            M,N=self._attrs['inputs'][0].shape
 
-        const_metadata['M'] = M
-        const_metadata['N'] = N
-        const_metadata['stride_x0'] = N
-        const_metadata['stride_x1'] = 1
-        const_metadata['stride_y0'] = M
-        const_metadata['stride_y1'] = 1
+            const_metadata['M'] = M
+            const_metadata['N'] = N
+            const_metadata['stride_x0'] = N
+            const_metadata['stride_x1'] = 1
+            const_metadata['stride_y0'] = M
+            const_metadata['stride_y1'] = 1
 
-        const_metadata['BLOCK_SIZE_M'] = self._block_size(M)
-        const_metadata['BLOCK_SIZE_N'] = self._block_size(N)
+            const_metadata['BLOCK_SIZE_M'] = self._block_size(M)
+            const_metadata['BLOCK_SIZE_N'] = self._block_size(N)
+        elif self._attrs['permutation'] == '0213':
+            D0,D1,D2,D3=self._attrs['inputs'][0].shape
+            const_metadata['D0'] = D0
+            const_metadata['D1'] = D1
+            const_metadata['D2'] = D2
+            const_metadata['D3'] = D3
+
+            const_metadata['stride_x0'],const_metadata['stride_x1'],const_metadata['stride_x2'],const_metadata['stride_x3'] = shape2stride(self._attrs['inputs'][0].shape)
+            const_metadata['stride_y0'],const_metadata['stride_y1'],const_metadata['stride_y2'],const_metadata['stride_y3']= shape2stride(self._attrs['outputs'][0].shape)
 
-        return const_metadata
-
-    def _gen_grid_10(self,target_name,const_metadata):
-        gen_grid = getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.transpose'),'gen_grid_transpose_10')
-        return gen_grid(const_metadata['M'],const_metadata['N'],const_metadata['BLOCK_SIZE_M'],const_metadata['BLOCK_SIZE_N'])
-
-    def _gen_constants_0213(self):
-        const_metadata={}
-        D0,D1,D2,D3=self._attrs['inputs'][0].shape
-        const_metadata['D0'] = D0
-        const_metadata['D1'] = D1
-        const_metadata['D2'] = D2
-        const_metadata['D3'] = D3
-
-        const_metadata['stride_x0'],const_metadata['stride_x1'],const_metadata['stride_x2'],const_metadata['stride_x3'] = shape2stride(self._attrs['inputs'][0].shape)
-        const_metadata['stride_y0'],const_metadata['stride_y1'],const_metadata['stride_y2'],const_metadata['stride_y3']= shape2stride(self._attrs['outputs'][0].shape)
+            const_metadata['BLOCK_SIZE_D1'] = self._block_size(D1)
+            const_metadata['BLOCK_SIZE_D2'] = self._block_size(D2)
 
-        const_metadata['BLOCK_SIZE_D1'] = self._block_size(D1)
-        const_metadata['BLOCK_SIZE_D2'] = self._block_size(D2)
+        self._shrink_shared_mem(func_gen_smem_size,const_metadata,get_cuda_device_max_shared_memory(),num_stages,get_dtype_size(self._attrs['inputs'][0].dtype))
 
         return const_metadata
 
-    def _gen_grid_0213(self,target_name,const_metadata):
-        gen_grid = getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.transpose'),'gen_grid_transpose_0213')
-        return gen_grid(const_metadata['D0'],const_metadata['D1'],const_metadata['D2'],const_metadata['D3'],const_metadata['BLOCK_SIZE_D1'],const_metadata['BLOCK_SIZE_D2'])
-
-    def _gen_exec_metadata(self):
-        return _exec_metadata.copy()
+    def _gen_exec_metadata(self):
+        return {
+            'num_warps': 4,
+            'num_stages': 2,
+        }
 
     def compile(self, target_name, workdir, enable_tf32)->TritonExecutor:
-        triton_kernel_name= 'transpose_' + self._attrs['permutation']
-        triton_kernel=getattr(importlib.import_module(f'tritontemplate.backend.{target_name}.transpose'),triton_kernel_name)
-        signature,divisiability=self._gen_tensor_signature_divisiability(['inputs','outputs'])
-        exec_metadata=self._gen_exec_metadata()
-
-        if self._attrs['permutation'] == '10':
-            constants=self._gen_constants_10()
-            exec_grid = self._gen_grid_10(target_name,constants)
-        elif self._attrs['permutation'] == '0213':
-            constants=self._gen_constants_0213()
-            exec_grid = self._gen_grid_0213(target_name,constants)
-        else:
-            raise ValueError(f"Unsupported permutation {self._attrs['permutation']}")
-
-        num_warps=exec_metadata['num_warps']
-        num_stages=exec_metadata['num_stages']
-        config = triton.compiler.instance_descriptor(divisible_by_16=divisiability[16], equal_to_1=divisiability[1])
-
-        triton_compiled_kernel=triton.compile(fn=triton_kernel,signature=signature,constants=constants,num_warps=num_warps,num_stages=num_stages,configs=[config],debug=False)
-
-
-        return TritonExecutor(triton_compiled_kernel,exec_grid,get_warpsize(target_name),constants)
\ No newline at end of file
+        return super().compile(target_name, workdir, enable_tf32)
\ No newline at end of file
diff --git a/external/TritonTemplate/python/tritontemplate/compiler/utils.py b/external/TritonTemplate/python/tritontemplate/compiler/utils.py
index 28cb397eb..df1d98403 100644
--- a/external/TritonTemplate/python/tritontemplate/compiler/utils.py
+++ b/external/TritonTemplate/python/tritontemplate/compiler/utils.py
@@ -1,16 +1,18 @@
 import subprocess
+import torch
 
 _TARGET2WARPSIZE={
     'cuda':32,
 }
 
 _DEVICE_MAX_SHARED_MEMORY={
-    "NVIDIA H800": 227 * 1024,
-    "NVIDIA H100": 227 * 1024,
-    "NVIDIA A100": 164 * 1024,
-    "NVIDIA A800": 164 * 1024,
-    "NVIDIA V100": 96 * 1024,
-    "NVIDIA T4": 64 * 1024,
+    "8.0" : 163*1024,
+    "8.6" : 99*1024,
+    "8.7" : 163*1024,
+    "8.9" : 99*1024,
+    "9." : 227*1024,
+    "10." : 227*1024,
+    "12." : 99*1024,
 }
 
 def get_cuda_device_name(idx=0):
@@ -25,9 +27,10 @@ def get_warpsize(target_name):
     except KeyError:
         raise KeyError(f'target {target_name} not supported')
 
-def get_device_max_shared_memory(target_name):
-    try:
-        return _DEVICE_MAX_SHARED_MEMORY[target_name]
-    except KeyError:
-        raise KeyError(f'target {target_name} not supported, please add max smem size info')
-    
\ No newline at end of file
+def get_cuda_device_max_shared_memory():
+    compute_capability = torch.cuda.get_device_capability()
+    if compute_capability[0] == 8:
+        return _DEVICE_MAX_SHARED_MEMORY[str(compute_capability[0])+"."+str(compute_capability[1])]
+    elif compute_capability[0] < 8:
+        raise KeyError(f'cuda compute capability {compute_capability} does not support triton')
+    return _DEVICE_MAX_SHARED_MEMORY[str(compute_capability[0])+"."]
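The shared-memory table is now keyed by compute capability instead of device marketing names, tracking the per-SM carve-out Triton can request (163 KB on sm_80/sm_87, 99 KB on sm_86/sm_89, 227 KB on Hopper). A quick check, assuming a CUDA build of PyTorch:

```python
import torch
from tritontemplate.compiler.utils import get_cuda_device_max_shared_memory

major, minor = torch.cuda.get_device_capability()  # e.g. (8, 0) on an A100
print(f"sm_{major}{minor}:", get_cuda_device_max_shared_memory(), "bytes")
# On an A100 this prints 166912 (163*1024); pre-Ampere devices raise KeyError.
```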
diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test.sh b/external/TritonTemplate/python/tritontemplate/testing/cuda/test.sh
new file mode 100755
index 000000000..88cf2ec2c
--- /dev/null
+++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+set -e
+for file in *.py; do
+    if [ -f "$file" ]; then
+        echo "Running pytest on $file"
+        GITHUB_CI_TEST=true pytest "$file"
+    fi
+done
\ No newline at end of file
diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py
index 83a45fc33..997fc31f8 100644
--- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py
+++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_bmm.py
@@ -1,5 +1,6 @@
 import torch
 import pytest
+import os
 import triton
 
 from tritontemplate.compiler.base import IntImm,Tensor
@@ -50,19 +51,25 @@ def gen_bmm(format, batch_size, M, N, K, stype):
 
 FORMATS = ['rcr','rrr','ccr','crr']
-MATRIX_PARAMS = [
-    (2, 2, 128, 31, 'float32'),
-    (2, 128, 2, 31, 'float16'),
-    (2, 128, 128, 31, 'float32'),
-    (2, 31, 128, 2, 'float16'),
-    (2, 129, 128, 128, 'float32'),
-    (2, 128, 257, 512, 'float16'),
-    (2, 128, 512, 257, 'float32'),
-    (2, 127, 256, 256, 'float16'),
-    (2, 128, 511, 512, 'float32'),
-    (2, 256, 128, 255, 'float16'),
-    (2, 1, 256, 256, 'float32'),
-]
+if os.environ.get('GITHUB_CI_TEST'):
+    MATRIX_PARAMS = [
+        (2, 128, 128, 31, 'float32'),
+        (2, 128, 257, 512, 'float16'),
+    ]
+else:
+    MATRIX_PARAMS = [
+        (2, 2, 128, 31, 'float32'),
+        (2, 128, 2, 31, 'float16'),
+        (2, 128, 128, 31, 'float32'),
+        (2, 31, 128, 2, 'float16'),
+        (2, 129, 128, 128, 'float32'),
+        (2, 128, 257, 512, 'float16'),
+        (2, 128, 512, 257, 'float32'),
+        (2, 127, 256, 256, 'float16'),
+        (2, 128, 511, 512, 'float32'),
+        (2, 256, 128, 255, 'float16'),
+        (2, 1, 256, 256, 'float32'),
+    ]
 
 @pytest.mark.parametrize('format', FORMATS)
 @pytest.mark.parametrize(
diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py
index b1fdd93d9..3f3d72ce1 100644
--- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py
+++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_gemm.py
@@ -1,6 +1,7 @@
 import torch
 import pytest
 import triton
+import os
 
 from tritontemplate.backend.cuda.gemm import gemm_bias as gemm_bias_kernel
 from tritontemplate.backend.cuda.gemm import gemm as gemm_kernel
@@ -52,19 +53,25 @@ def gen_gemm(format, M, N, K, stype):
 
 FORMATS = ['rcr','rrr','ccr','crr']
-MATRIX_PARAMS = [
-    (2, 128, 31, 'float32'),
-    (128, 2, 31, 'float16'),
-    (128, 128, 31, 'float32'),
-    (128, 31, 2, 'float16'),
-    (128, 128, 128, 'float32'),
-    (128, 257, 512, 'float16'),
-    (128, 512, 257, 'float32'),
-    (127, 256, 256, 'float16'),
-    (128, 511, 512, 'float32'),
-    (256, 128, 255, 'float16'),
-    (1, 256, 256, 'float32'),
-]
+if os.environ.get('GITHUB_CI_TEST'):
+    MATRIX_PARAMS = [
+        (128, 128, 31, 'float32'),
+        (128, 2, 31, 'float16'),
+    ]
+else:
+    MATRIX_PARAMS = [
+        (2, 128, 31, 'float32'),
+        (128, 2, 31, 'float16'),
+        (128, 128, 31, 'float32'),
+        (128, 31, 2, 'float16'),
+        (128, 128, 128, 'float32'),
+        (128, 257, 512, 'float16'),
+        (128, 512, 257, 'float32'),
+        (127, 256, 256, 'float16'),
+        (128, 511, 512, 'float32'),
+        (256, 128, 255, 'float16'),
+        (1, 256, 256, 'float32'),
+    ]
 
 @pytest.mark.parametrize('format', FORMATS)
 @pytest.mark.parametrize(
@@ -119,20 +126,27 @@ def test_gemm_bias_relu(format, M, N, K, stype):
     assert torch.allclose(c_triton_aot, c_triton_jit, atol=1e-2, rtol=1e-2)
     assert torch.allclose(pytorch_result, c_triton_jit, atol=1e-2, rtol=1e-2)
 
+
 FORMATS = ['rcr','rrr','ccr','crr']
-MATRIX_PARAMS = [
-    (2, 128, 31, 'float32'),
-    (128, 2, 31, 'float16'),
-    (128, 128, 31, 'float32'),
-    (128, 31, 2, 'float16'),
-    (128, 128, 128, 'float32'),
-    (128, 257, 512, 'float16'),
-    (128, 512, 257, 'float32'),
-    (127, 256, 256, 'float16'),
-    (128, 511, 512, 'float32'),
-    (256, 128, 255, 'float16'),
-    (1, 256, 256, 'float32'),
-]
+if os.environ.get('GITHUB_CI_TEST'):
+    MATRIX_PARAMS = [
+        (128, 128, 31, 'float32'),
+        (128, 2, 31, 'float16'),
+    ]
+else:
+    MATRIX_PARAMS = [
+        (2, 128, 31, 'float32'),
+        (128, 2, 31, 'float16'),
+        (128, 128, 31, 'float32'),
+        (128, 31, 2, 'float16'),
+        (128, 128, 128, 'float32'),
+        (128, 257, 512, 'float16'),
+        (128, 512, 257, 'float32'),
+        (127, 256, 256, 'float16'),
+        (128, 511, 512, 'float32'),
+        (256, 128, 255, 'float16'),
+        (1, 256, 256, 'float32'),
+    ]
 
 @pytest.mark.parametrize('format', FORMATS)
 @pytest.mark.parametrize(
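The GITHUB_CI_TEST gating above is repeated verbatim in every test module; a shared helper (hypothetical, not part of this diff) could keep the pattern in one place:

```python
import os

def select_params(ci_params, full_params, env_var='GITHUB_CI_TEST'):
    """Return the reduced parameter set when running under CI."""
    return ci_params if os.environ.get(env_var) else full_params

MATRIX_PARAMS = select_params(
    ci_params=[(128, 128, 31, 'float32'), (128, 2, 31, 'float16')],
    full_params=[(2, 128, 31, 'float32'), (128, 2, 31, 'float16'),
                 (128, 128, 128, 'float32')],  # abbreviated
)
```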
diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_layernorm.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_layernorm.py
index aba410aae..940401c37 100644
--- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_layernorm.py
+++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_layernorm.py
@@ -1,6 +1,7 @@
 import torch
 import pytest
 import triton
+import os
 
 from tritontemplate.backend.cuda.layernorm import layernorm as kernel_layernorm
 from tritontemplate.backend.cuda.layernorm import layernorm_weight_bias as kernel_layernorm_weight_bias
@@ -17,29 +18,35 @@ def gen_layernorm(with_weight_bias,batch,seq_len,hidden_size,stype):
         op=Layernorm(
             inputs=[X,W,B],
             outputs=None,
-            axis=2,
+            axises=[2],
             eps=1e-5)
     else:
         op=Layernorm(
             inputs=[X],
             outputs=None,
-            axis=2,
+            axises=[2],
             eps=1e-5)
     return op
 
-MATRIX_PARAMS = [
-    (2, 128, 31, 'float32'),
-    (128, 2, 31, 'float16'),
-    (128, 128, 31, 'float32'),
-    (128, 31, 32, 'float16'),
-    (128, 128, 128, 'float32'),
-    (128, 257, 512, 'float16'),
-    (128, 512, 257, 'float32'),
-    (127, 256, 256, 'float16'),
-    (128, 511, 512, 'float32'),
-    (256, 128, 255, 'float16'),
-    (1, 256, 256, 'float32'),
-]
+if os.environ.get('GITHUB_CI_TEST'):
+    MATRIX_PARAMS = [
+        (128, 128, 31, 'float32'),
+        (128, 2, 31, 'float16'),
+    ]
+else:
+    MATRIX_PARAMS = [
+        (2, 128, 31, 'float32'),
+        (128, 2, 31, 'float16'),
+        (128, 128, 31, 'float32'),
+        (128, 31, 32, 'float16'),
+        (128, 128, 128, 'float32'),
+        (128, 257, 512, 'float16'),
+        (128, 512, 257, 'float32'),
+        (127, 256, 256, 'float16'),
+        (128, 511, 512, 'float32'),
+        (256, 128, 255, 'float16'),
+        (1, 256, 256, 'float32'),
+    ]
 FORMATS = ['layernorm','layernorm_weight_bias']
 
 @pytest.mark.parametrize('batch,seq_len,hidden_size,stype',MATRIX_PARAMS)
@@ -72,10 +79,10 @@ def test_layernorm(batch,seq_len,hidden_size,stype,format):
         weight = torch.randn(hidden_size,dtype=dtype,device='cuda')
         bias = torch.randn(hidden_size,dtype=dtype,device='cuda')
         y_torch = torch.nn.functional.layer_norm(x,(N,),weight,bias,eps=1e-5)
-        kernel_layernorm_weight_bias[grid](x,bias,weight,y_triton_jit,M,N,N,1,N,1,1,1,64,64,1e-5)
+        kernel_layernorm_weight_bias[grid](x,weight,bias,y_triton_jit,M,N,N,1,N,1,1,1,64,64,1e-5)
         kernel = gen_layernorm(True,batch,seq_len,hidden_size,stype)
         kernel_aot = compile_kernel(kernel)
-        kernel_aot(x,bias,weight,y_triton_aot)
+        kernel_aot(x,weight,bias,y_triton_aot)
 
         atol = 1e-2 if dtype == torch.float16 else 1e-4
         rtol = 1e-2 if dtype == torch.float16 else 1e-4
diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py
index ba8c7b188..29e9a6de4 100644
--- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py
+++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_softmax.py
@@ -1,6 +1,7 @@
 import torch
 import pytest
 import triton
+import os
 
 from tritontemplate.backend.cuda.softmax import softmax as kernel_softmax
 from tritontemplate.backend.cuda.softmax import online_softmax as kernel_online_softmax
@@ -24,18 +25,24 @@
     'softmax',
     'online_softmax',
 ]
-MATRIX_PARAMS = [
-    (128, 16, 8, 255),
-    (64, 8, 8, 255),
-    (64, 16, 2, 66),
-    (128, 16, 8, 257),
-    (128, 8, 4, 127),
-    (128, 8, 8, 63),
-    (64, 8, 4, 129),
-    (128, 8, 4, 255),
-    (64, 8, 4, 63),
-    (64, 8, 2, 255)
-]
+if os.environ.get('GITHUB_CI_TEST'):
+    MATRIX_PARAMS = [
+        (64, 8, 8, 255),
+        (64, 16, 2, 66),
+    ]
+else:
+    MATRIX_PARAMS = [
+        (128, 16, 8, 255),
+        (64, 8, 8, 255),
+        (64, 16, 2, 66),
+        (128, 16, 8, 257),
+        (128, 8, 4, 127),
+        (128, 8, 8, 63),
+        (64, 8, 4, 129),
+        (128, 8, 4, 255),
+        (64, 8, 4, 63),
+        (64, 8, 2, 255)
+    ]
 
 @pytest.mark.parametrize('hidden_dim, num_heads, batch, seqlen', MATRIX_PARAMS)
 @pytest.mark.parametrize('format', FORMATS)
diff --git a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_transpose.py b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_transpose.py
index 159867512..7a7ee654c 100644
--- a/external/TritonTemplate/python/tritontemplate/testing/cuda/test_transpose.py
+++ b/external/TritonTemplate/python/tritontemplate/testing/cuda/test_transpose.py
@@ -1,6 +1,7 @@
 import torch
 import pytest
 import triton
+import os
 
 from tritontemplate.backend.cuda.transpose import transpose_10 as kernel_transpose_10
 from tritontemplate.backend.cuda.transpose import transpose_0213 as kernel_transpose_0213
@@ -14,19 +15,31 @@
     'transpose_0213',
 ]
 
-MATRIX_PARAMS = [
-    (256,128,'float16'),
-    (255,257,'float32'),
-    (127,129,'float16'),
-    (2304,768,'float16'),
-]
+if os.environ.get('GITHUB_CI_TEST'):
+    MATRIX_PARAMS = [
+        (256,128,'float16'),
+        (255,257,'float32'),
+    ]
+else:
+    MATRIX_PARAMS = [
+        (256,128,'float16'),
+        (255,257,'float32'),
+        (127,129,'float16'),
+        (2304,768,'float16'),
+    ]
 
-TENSOR4D_PARAMS = [
-    (16, 64, 128, 32, 'float16'),
-    (4, 127, 255, 63, 'float32'),
-    (8, 256, 512, 64, 'float16'),
-    (8, 1024,8,96, 'float16'),
-]
+if os.environ.get('GITHUB_CI_TEST'):
+    TENSOR4D_PARAMS = [
+        (16, 64, 128, 32, 'float16'),
+        (4, 127, 255, 63, 'float32'),
+    ]
+else:
+    TENSOR4D_PARAMS = [
+        (16, 64, 128, 32, 'float16'),
+        (4, 127, 255, 63, 'float32'),
+        (8, 256, 512, 64, 'float16'),
+        (8, 1024,8,96, 'float16'),
+    ]
 
 def gen_transpose_10(M,N,stype):
     X = Tensor([M, N], stype)
diff --git a/external/TritonTemplate/scripts/build_and_test.sh b/external/TritonTemplate/scripts/build_and_test.sh
new file mode 100755
index 000000000..a196b34c3
--- /dev/null
+++ b/external/TritonTemplate/scripts/build_and_test.sh
@@ -0,0 +1,57 @@
+#!/bin/bash
+
+set -e
+set -x
+
+TRITON_BUILD=ON
+TRITON_TEST=ON
+
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        --no-test)
+            TRITON_TEST=OFF
+            shift
+            ;;
+        --no-build)
+            TRITON_BUILD=OFF
+            shift
+            ;;
+        *)
+            echo "Invalid option: $1"
+            exit 1
+            ;;
+    esac
+done
+
+
+TRITON_TEST=${TRITON_TEST:-ON}
+
+# path to script
+CUR_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
+# path to TritonTemplate root
+ROOT_PROJ_DIR="$CUR_DIR/.."
+
+# path to python
+PYTHON_PATH="$ROOT_PROJ_DIR/python"
+
+if [[ $TRITON_BUILD == "ON" ]]; then
+    pushd "$PYTHON_PATH"
+    # install tritontemplate
+    python3 -m pip install -e .
+    popd
+fi
+
+# path to python tests
+TEST_DIR="$ROOT_PROJ_DIR/python/tritontemplate/testing/cuda"
+
+if [[ $TRITON_TEST == "ON" ]]; then
+    # Run pytest for all test files in cuda directory
+    pushd "$TEST_DIR"
+    for file in *.py; do
+        if [ -f "$file" ]; then
+            echo "Running pytest on $file"
+            GITHUB_CI_TEST=true pytest "$file"
+        fi
+    done
+    popd
+fi
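Usage note: the script builds and tests by default; `./external/TritonTemplate/scripts/build_and_test.sh --no-build` reruns only the CUDA tests against an existing editable install, `--no-test` performs just the `pip install -e .`, and any other flag exits with "Invalid option". The tests are launched with `GITHUB_CI_TEST=true`, so they use the reduced parameter sets introduced above.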