Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions src/tilegym/ops/cutile/swiglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import cuda.tile as ct
import torch
import torch.nn as nn
from cuda.tile._numeric_semantics import RoundingMode as RMd

from tilegym.backend import register_impl

Expand All @@ -15,28 +16,29 @@


def sigmoid(x):
    """Elementwise logistic sigmoid, 1 / (1 + exp(-x)), on a cuTile tile.

    Uses flush-to-zero arithmetic and the approximate-division rounding
    mode (RMd.APPROX) for throughput. NOTE(review): the PR author states
    precision was verified via tests; re-verify if tolerance requirements
    change.

    Args:
        x: input tile (expected float32 — callers upcast before calling;
           TODO confirm against kernel call sites).

    Returns:
        Tile of the same shape with sigmoid applied elementwise.
    """
    # 1 + exp(-x), computed with flush-to-zero denormal handling.
    denom = ct.add(1.0, ct.exp(-x), flush_to_zero=True)
    # Approximate reciprocal division: the main source of the speedup here.
    return ct.truediv(1.0, denom, flush_to_zero=True, rounding_mode=RMd.APPROX)
Copy link
Contributor Author

@aghilann aghilann Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A good chunk of the savings came from RMd.APPROX without losing precision — verified via tests.



def silu(x):
    """SiLU (swish) activation: x * sigmoid(x).

    Multiplication is performed with flush-to-zero to match the arithmetic
    mode used inside ``sigmoid``.

    Args:
        x: input tile (float32 expected — TODO confirm at call sites).

    Returns:
        Tile of the same shape with SiLU applied elementwise.
    """
    return ct.mul(x, sigmoid(x), flush_to_zero=True)


def ceildiv(a, b):
    """Return the ceiling of a / b using integer arithmetic only.

    Relies on Python's floor division: floor(-a / b) negated yields
    ceil(a / b) without any float round-trip.
    """
    quotient = -(-a // b)
    return quotient


@ct.kernel
def swiglu_forward_kernel(a, b, c, TILE_SIZE: ct.Constant[int]):
    """Forward SwiGLU kernel: c[row, :] = silu(a[row, :]) * b[row, :].

    One program instance per row (launched on a 1-D grid of n_rows);
    TILE_SIZE is the row width padded up to a power of two. Per the PR
    notes, gather/scatter with bounds checking replaced the earlier tiled
    load/store and accounted for a good chunk of the speedup.

    Args:
        a: gate input tensor (rows x cols).
        b: linear input tensor, same shape as ``a``.
        c: output tensor, same shape as ``a``; written in place.
        TILE_SIZE: compile-time tile width covering a full (padded) row.
    """
    row = ct.bid(0)
    # Column offsets across the whole row; out-of-range lanes are padded
    # with 0.0 on gather and masked on scatter via check_bounds.
    offsets = ct.arange(TILE_SIZE, dtype=ct.int32)

    a_tile = ct.gather(a, (row, offsets), check_bounds=True, padding_value=0.0)
    b_tile = ct.gather(b, (row, offsets), check_bounds=True, padding_value=0.0)

    # Sigmoid requires float32: upcast, apply SiLU, then cast back to the
    # input dtype before the elementwise product with b.
    a_tile_f32 = a_tile.astype(ct.float32)
    c_tile = silu(a_tile_f32).astype(a.dtype) * b_tile
    ct.scatter(c, (row, offsets), c_tile, check_bounds=True)


def swiglu_forward(a, b):
Expand All @@ -51,18 +53,16 @@ def swiglu_forward(a, b):
c = torch.empty_like(a)
n_rows = a.shape[0]

NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
TILE_N = ceildiv(NUM_SMS, n_rows)
TILE_SIZE = next_power_of_2(int(n_cols / TILE_N))
grid = (n_rows, ceildiv(n_cols, TILE_SIZE), 1)
TILE_SIZE = next_power_of_2(n_cols)
grid = (n_rows,)
ct.launch(
torch.cuda.current_stream(),
grid,
swiglu_forward_kernel,
(
a.data,
b.data,
c.data,
a,
b,
c,
TILE_SIZE,
),
)
Expand All @@ -89,7 +89,7 @@ def swiglu_backward_kernel(dc, a, b, da, db, TILE_SIZE: ct.Constant[int]):
b_tile_f32 = b_tile.astype(ct.float32)

# Compute sigmoid(a) and silu(a)
sigmoid_a = sigmoid(a_tile_f32)
sigmoid_a = 1.0 / (1.0 + ct.exp(-a_tile_f32))
Copy link
Contributor Author

@aghilann aghilann Feb 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inlined this for now because I didn’t want to modify the backward kernel in this PR - that would require re-benchmarking it as well. I have additional optimizations planned that I’ll include in a separate PR, which will also make use of the new sigmoid implementation I added.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a clear comment here? The current changes make the forward and backward codes a bit confusing. You can delete the comment when the backward PR is ready.

silu_a = a_tile_f32 * sigmoid_a

# db = dc * silu(a)
Expand Down
23 changes: 15 additions & 8 deletions tests/benchmark/bench_swiglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ def get_supported_backends():
return [p for p in ALL_BACKENDS if p is not None]


def create_benchmark_config(batch_size, M):
def create_benchmark_config(batch_size, M, dtype):
"""Create a benchmark configuration for given parameters"""
available_backends = get_supported_backends()
if not available_backends:
return None

backends, names, styles = zip(*available_backends)

dtype_name = str(dtype).split(".")[-1]
return triton.testing.Benchmark(
x_names=["N"],
x_vals=[2**i for i in range(10, 15)], # 1024 to 16384
Expand All @@ -54,38 +55,44 @@ def create_benchmark_config(batch_size, M):
line_names=list(names),
styles=list(styles),
ylabel="GB/s",
plot_name=f"swiglu-batch{batch_size}-M{M}-GBps",
plot_name=f"swiglu-batch{batch_size}-M{M}-{dtype_name}-GBps",
args={
"batch_size": batch_size,
"M": M,
"dtype": dtype,
},
)


@triton.testing.perf_report(
[
create_benchmark_config(batch_size, M)
for batch_size in [1, 8] # Different batch sizes
create_benchmark_config(batch_size, M, dtype)
for batch_size in [1, 4, 8] # Different batch sizes
for M in [128, 4096] # Different rows
for dtype in [torch.float16, torch.bfloat16, torch.float32]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most benchmarks test across various dtypes, so I thought this one should too.

]
)
def bench_swiglu(
batch_size,
M,
N,
backend,
dtype,
device=DEVICE,
):
dtype = torch.float16

# Generate input data: two tensors for SwiGLU operation
a = torch.randn(batch_size, M, N, device=device, dtype=dtype)
b = torch.randn(batch_size, M, N, device=device, dtype=dtype)

# Use unified dispatch system
fn = lambda: tilegym.ops.get_swiglu(backend=backend)(a, b)
ref = lambda: reference_swiglu(a, b)
torch.testing.assert_close(fn(), ref(), atol=1e-2, rtol=1e-2)
if dtype is torch.float32:
ref = lambda: reference_swiglu(a, b)
atol, rtol = 1e-5, 1e-5
else:
ref = lambda: reference_swiglu(a, b)
atol, rtol = 1e-2, 1e-2
torch.testing.assert_close(fn(), ref(), atol=atol, rtol=rtol)

# Benchmark the function
ms = triton.testing.do_bench_cudagraph(fn)
Expand Down