3 changes: 3 additions & 0 deletions .github/workflows/compiler-ci.yaml
Original file line number Diff line number Diff line change
@@ -43,6 +43,9 @@ jobs:
- name: Run build and test
run: ./scripts/compiler/build_and_test.sh
shell: bash
- name: Run build and test TritonTemplate
run: ./external/TritonTemplate/scripts/build_and_test.sh
shell: bash
- name: Run E2E testcases and check different
run: ./compiler/scripts/gen_testcases_and_check_diff.sh
shell: bash
33 changes: 33 additions & 0 deletions external/TritonTemplate/python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
[build-system]
requires = ["setuptools>=42"]
build-backend = "setuptools.build_meta"

[project]
name = "tritontemplate"
version = "0.0.2"
description = "TritonTemplate: Make Flex Triton Templates for AI"
requires-python = ">=3.7,<4"
dependencies = [
"triton>=3.3.0",
"torch>=2.1.0",
"numpy<2",
"pytest",
]

[project.optional-dependencies]
# Add any optional dependencies here if needed

[tool.setuptools]
zip-safe = true
packages = ["tritontemplate"]

# Note: For package_data, you'll need to either:
# 1. List all files explicitly here, or
# 2. Keep a minimal setup.py just for the file listing logic
# Here's an example of explicit package data:
[tool.setuptools.package-data]
tritontemplate = [
"backend/*.py",
"utils/*.py",
"compiler/*.py"
]
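
With setup.py and _libinfo.py removed elsewhere in this diff, the package version is now declared only in pyproject.toml. A minimal way to read it back after installation is through importlib.metadata (Python 3.8+); this is a sketch under the assumption that the distribution was installed from external/TritonTemplate/python under the "tritontemplate" name declared above.

# Hedged sketch: assumes `pip install .` was run in external/TritonTemplate/python.
from importlib.metadata import PackageNotFoundError, version

try:
    print(version("tritontemplate"))  # expected to print 0.0.2
except PackageNotFoundError:
    print("tritontemplate is not installed in this environment")
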
59 changes: 0 additions & 59 deletions external/TritonTemplate/python/setup.py

This file was deleted.

1 change: 0 additions & 1 deletion external/TritonTemplate/python/tritontemplate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import sys
from tritontemplate import backend,compiler,testing,utils
from tritontemplate._libinfo import __version__

__all__ = ["backend", "compiler", "testing", "utils"]
1 change: 0 additions & 1 deletion external/TritonTemplate/python/tritontemplate/_libinfo.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from tritontemplate.backend.cuda.bmm.bmm import bmm,bmm_bias, gen_grid_bmm
from tritontemplate.backend.cuda.bmm.bmm import bmm,bmm_bias, gen_grid_bmm, gen_smem_size_bmm
gen_grid_bmm_bias = gen_grid_bmm
gen_smem_size_bmm_bias = gen_smem_size_bmm
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import triton
import triton.language as tl

def gen_grid_bmm(batch_size,M, N, BLOCK_SIZE_M, BLOCK_SIZE_N):
def gen_grid_bmm(BATCH_SIZE:int,M:int, N:int, BLOCK_SIZE_M:int, BLOCK_SIZE_N:int):
"""
Generates the grid for a Batch GEMM kernel.
"""
return (batch_size,triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), 1)
return (BATCH_SIZE,triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N), 1)

def gen_smem_size_bmm(BLOCK_SIZE_M:int, BLOCK_SIZE_K:int, BLOCK_SIZE_N:int,num_stages:int,size_dtype:int):
return max((BLOCK_SIZE_N*BLOCK_SIZE_K+BLOCK_SIZE_K*BLOCK_SIZE_M)*num_stages*size_dtype,(BLOCK_SIZE_M*BLOCK_SIZE_N)*4)
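
For a quick feel of this bound (the block sizes below are illustrative assumptions, not values from the PR): with BLOCK_SIZE_M = BLOCK_SIZE_N = 128, BLOCK_SIZE_K = 32, num_stages = 3 and a 2-byte element type, the staged A and B tiles need (128*32 + 32*128) * 3 * 2 = 49152 bytes, while the second term, which appears to reserve room for a float32 (M, N) tile, needs 128*128*4 = 65536 bytes, so the helper returns 65536.

# Standalone restatement of the bound above, for a worked example only;
# the sizes passed in are hypothetical.
def smem_upper_bound(bm, bk, bn, num_stages, size_dtype):
    staged_tiles = (bn * bk + bk * bm) * num_stages * size_dtype  # A and B tiles across stages
    fp32_tile = bm * bn * 4                                       # float32 (M, N) tile
    return max(staged_tiles, fp32_tile)

print(smem_upper_bound(128, 32, 128, num_stages=3, size_dtype=2))  # 65536
print(smem_upper_bound(128, 64, 128, num_stages=3, size_dtype=2))  # 98304
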

@triton.jit
def bmm_bias(
@@ -40,7 +43,7 @@ def bmm_bias(

if is_transpose_a:
# A(B,K,M)
a_ptrs = a_ptr+stride_a0*batch_id+stride_a1*offs_k[:,None]+stride_a2*offs_am[None,:]
a_ptrs = a_ptr+stride_a0*batch_id+stride_a1*offs_k[None,:]+stride_a2*offs_am[:,None]
stride_ak=stride_a1
else:
# A(B,M,K)
@@ -49,7 +52,7 @@

if is_transpose_b:
# B(B,N,K)
b_ptrs = b_ptr+stride_b0*batch_id+stride_b1*offs_bn[:,None]+stride_b2*offs_k[None,:]
b_ptrs = b_ptr+stride_b0*batch_id+stride_b1*offs_bn[None,:]+stride_b2*offs_k[:,None]
stride_bk=stride_b2
else:
# B(B,K,N)
@@ -59,20 +62,16 @@
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
if is_transpose_a:
a_mask= (offs_k[:,None]+BLOCK_SIZE_K*k<K) & (offs_am[None,:]<M)
a_mask= (offs_k[None,:]+BLOCK_SIZE_K*k<K) & (offs_am[:,None]<M)
else:
a_mask= (offs_am[:,None]<M) & (offs_k[None,:]+BLOCK_SIZE_K*k<K)
a=tl.load(a_ptrs,mask=a_mask)
if is_transpose_b:
b_mask= (offs_bn[:,None]<N) & (offs_k[None,:]+BLOCK_SIZE_K*k<K)
b_mask= (offs_bn[None,:]<N) & (offs_k[:,None]+BLOCK_SIZE_K*k<K)
else:
b_mask= (offs_k[:,None]+BLOCK_SIZE_K*k<K) & (offs_bn[None,:]<N)
b=tl.load(b_ptrs,mask=b_mask)

if is_transpose_a:
a=tl.trans(a)
if is_transpose_b:
b=tl.trans(b)
accumulator += tl.dot(a, b, allow_tf32=enable_tf32)

a_ptrs+=BLOCK_SIZE_K*stride_ak
@@ -121,7 +120,7 @@ def bmm(

if is_transpose_a:
# A(B,K,M)
a_ptrs = a_ptr+stride_a0*batch_id+stride_a1*offs_k[:,None]+stride_a2*offs_am[None,:]
a_ptrs = a_ptr+stride_a0*batch_id+stride_a1*offs_k[None,:]+stride_a2*offs_am[:,None]
stride_ak=stride_a1
else:
# A(B,M,K)
@@ -130,7 +129,7 @@

if is_transpose_b:
# B(B,N,K)
b_ptrs = b_ptr+stride_b0*batch_id+stride_b1*offs_bn[:,None]+stride_b2*offs_k[None,:]
b_ptrs = b_ptr+stride_b0*batch_id+stride_b1*offs_bn[None,:]+stride_b2*offs_k[:,None]
stride_bk=stride_b2
else:
# B(B,K,N)
@@ -140,20 +139,17 @@
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
if is_transpose_a:
a_mask= (offs_k[:,None]+BLOCK_SIZE_K*k<K) & (offs_am[None,:]<M)
a_mask= (offs_k[None,:]+BLOCK_SIZE_K*k<K) & (offs_am[:,None]<M)
else:
a_mask= (offs_am[:,None]<M) & (offs_k[None,:]+BLOCK_SIZE_K*k<K)
a=tl.load(a_ptrs,mask=a_mask)
if is_transpose_b:
b_mask= (offs_bn[:,None]<N) & (offs_k[None,:]+BLOCK_SIZE_K*k<K)
b_mask= (offs_bn[None,:]<N) & (offs_k[:,None]+BLOCK_SIZE_K*k<K)
else:
b_mask= (offs_k[:,None]+BLOCK_SIZE_K*k<K) & (offs_bn[None,:]<N)
b=tl.load(b_ptrs,mask=b_mask)

if is_transpose_a:
a=tl.trans(a)
if is_transpose_b:
b=tl.trans(b)

accumulator += tl.dot(a, b, allow_tf32=enable_tf32)

a_ptrs+=BLOCK_SIZE_K*stride_ak
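
The bmm changes above remove the explicit tl.trans calls: when an operand is stored transposed, the pointer and mask expressions now swap the broadcast axes (offs_k[None, :] paired with offs_am[:, None], and likewise for B), so each tile is loaded directly in the orientation tl.dot expects. A small NumPy sketch, with arbitrary sizes chosen only for illustration, shows why the two formulations yield identical tiles:

import numpy as np

# Hypothetical sizes; any values work.
K, M, BK, BM = 16, 12, 4, 4
A_T = np.arange(K * M).reshape(K, M)  # A stored transposed: shape (K, M)
offs_k = np.arange(BK)
offs_m = np.arange(BM)

# Old formulation: load a (BK, BM) tile, then transpose it.
old_tile = A_T[offs_k[:, None], offs_m[None, :]].T

# New formulation: swap the broadcast axes so the tile is read
# directly in (BM, BK) orientation, with no transpose afterwards.
new_tile = A_T[offs_k[None, :], offs_m[:, None]]

assert np.array_equal(old_tile, new_tile)
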
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from tritontemplate.backend.cuda.gemm.gemm import gemm,gemm_bias,gen_grid_gemm
from tritontemplate.backend.cuda.gemm.gemm import gemm,gemm_bias,gen_grid_gemm,gen_smem_size_gemm
gen_grid_gemm_bias = gen_grid_gemm
gen_smem_size_gemm_bias = gen_smem_size_gemm
Original file line number Diff line number Diff line change
@@ -2,18 +2,16 @@
import triton.language as tl
from tritontemplate.backend.cuda.utils.activation import *

def gen_grid_gemm(M, N, BLOCK_SIZE_M, BLOCK_SIZE_N):
def gen_grid_gemm(M:int, N:int, BLOCK_SIZE_M:int, BLOCK_SIZE_N:int):
"""
Generates the grid for a GEMM kernel.
"""
return (
triton.cdiv(M, BLOCK_SIZE_M)*triton.cdiv(N, BLOCK_SIZE_N),1,1)

# smem_size=val*dtype_size*num_stage
smem_demand_per_stage ={
'gemm_bias': 128*128*2,
'gemm': 128*128*2,
}
def gen_smem_size_gemm(BLOCK_SIZE_M:int, BLOCK_SIZE_K:int, BLOCK_SIZE_N:int,num_stages:int,size_dtype:int):
# not an exact prediction, just an upper bound
return max((BLOCK_SIZE_N*BLOCK_SIZE_K+BLOCK_SIZE_K*BLOCK_SIZE_M)*num_stages*size_dtype,(BLOCK_SIZE_M*BLOCK_SIZE_N)*4)

@triton.jit
def gemm_bias(
@@ -50,7 +48,7 @@ def gemm_bias(

if is_transpose_a:
# A(K,M)
a_ptrs = a_ptr + offs_k[:, None] * stride_a0 + offs_am[None, :] * stride_a1
a_ptrs = a_ptr + offs_k[None,:] * stride_a0 + offs_am[:,None] * stride_a1
stride_ak = stride_a0
else:
# A(M,K)
@@ -59,31 +57,30 @@

if is_transpose_b:
# B(N,K)
b_ptrs = b_ptr + offs_bn[:, None] * stride_b0 + offs_k[None, :] * stride_b1
b_ptrs = b_ptr + offs_bn[None,:]* stride_b0 + offs_k[:,None] * stride_b1
stride_bk = stride_b1
else:
# B(K,N)
b_ptrs = b_ptr + offs_k[:, None] * stride_b0 + offs_bn[None, :] * stride_b1
stride_bk = stride_b0
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):

num_k_block:tl.constexpr = (K+BLOCK_SIZE_K-1)//BLOCK_SIZE_K

for k in range(0, num_k_block):
if is_transpose_a:
a_mask = (offs_k[:, None] + BLOCK_SIZE_K * k < K) & (offs_am[None, :] < M)
a_mask = (offs_k[None, :] + BLOCK_SIZE_K * k < K) & (offs_am[:, None] < M)
else:
a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + BLOCK_SIZE_K * k < K)
a = tl.load(a_ptrs, mask=a_mask, other=0.0)

if is_transpose_b:
b_mask = (offs_bn[:, None] < N) & (offs_k[None, :] + BLOCK_SIZE_K * k < K)
b_mask = (offs_bn[None,:] < N) & (offs_k[:,None] + BLOCK_SIZE_K * k < K)
else:
b_mask = (offs_k[:, None] + BLOCK_SIZE_K * k < K) & (offs_bn[None, :] < N)
b = tl.load(b_ptrs, mask=b_mask, other=0.0)

if is_transpose_a:
a = tl.trans(a)
if is_transpose_b:
b = tl.trans(b)
accumulator += tl.dot(a, b, allow_tf32=enable_tf32)

a_ptrs += BLOCK_SIZE_K * stride_ak
@@ -134,9 +131,11 @@ def gemm(
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)

num_k_block:tl.constexpr = (K+BLOCK_SIZE_K-1)//BLOCK_SIZE_K

if is_transpose_a:
# A(K,M)
a_ptrs = a_ptr + offs_k[:, None] * stride_a0 + offs_am[None, :] * stride_a1
a_ptrs = a_ptr + offs_k[None,:] * stride_a0 + offs_am[:,None] * stride_a1
stride_ak = stride_a0
else:
# A(M,K)
@@ -145,31 +144,27 @@

if is_transpose_b:
# B(N,K)
b_ptrs = b_ptr + offs_bn[:, None] * stride_b0 + offs_k[None, :] * stride_b1
b_ptrs = b_ptr + offs_bn[None,:] * stride_b0 + offs_k[:,None] * stride_b1
stride_bk = stride_b1
else:
# B(K,N)
b_ptrs = b_ptr + offs_k[:, None] * stride_b0 + offs_bn[None, :] * stride_b1
stride_bk = stride_b0
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
for k in range(0, num_k_block):
if is_transpose_a:
a_mask = (offs_k[:, None] + BLOCK_SIZE_K * k < K) & (offs_am[None, :] < M)
a_mask = (offs_k[None,:] + BLOCK_SIZE_K * k < K) & (offs_am[:,None] < M)
else:
a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + BLOCK_SIZE_K * k < K)
a = tl.load(a_ptrs, mask=a_mask, other=0.0)

if is_transpose_b:
b_mask = (offs_bn[:, None] < N) & (offs_k[None, :] + BLOCK_SIZE_K * k < K)
b_mask = (offs_bn[None,:] < N) & (offs_k[:,None] + BLOCK_SIZE_K * k < K)
else:
b_mask = (offs_k[:, None] + BLOCK_SIZE_K * k < K) & (offs_bn[None, :] < N)
b = tl.load(b_ptrs, mask=b_mask, other=0.0)

if is_transpose_a:
a = tl.trans(a)
if is_transpose_b:
b = tl.trans(b)

accumulator += tl.dot(a, b, allow_tf32=enable_tf32)

a_ptrs += BLOCK_SIZE_K * stride_ak
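
Two host-facing details in the gemm changes are easy to verify by hand: gen_grid_gemm flattens the (M, N) tiling into the first grid dimension, and the new constexpr trip count (K + BLOCK_SIZE_K - 1) // BLOCK_SIZE_K is plain ceiling division, matching the tl.cdiv(K, BLOCK_SIZE_K) expression it replaces. A short sketch with made-up sizes:

import math

# Hypothetical problem and block sizes, chosen only for illustration.
M, N, K = 512, 1024, 130
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32

# Same shape as gen_grid_gemm's return value: one program per output tile.
grid = (math.ceil(M / BLOCK_M) * math.ceil(N / BLOCK_N), 1, 1)
print(grid)  # (32, 1, 1)

# The constexpr K-loop trip count equals ceil(K / BLOCK_SIZE_K).
num_k_block = (K + BLOCK_K - 1) // BLOCK_K
assert num_k_block == math.ceil(K / BLOCK_K) == 5
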
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from tritontemplate.backend.cuda.layernorm.layernorm import layernorm,layernorm_weight_bias, gen_grid_layernorm
from tritontemplate.backend.cuda.layernorm.layernorm import layernorm,layernorm_weight_bias, gen_grid_layernorm,gen_smem_size_layernorm
gen_grid_layernorm_weight_bias = gen_grid_layernorm
gen_smem_size_layernorm_weight_bias = gen_smem_size_layernorm