From ad6272bad975d06102966e037d1f79ad07c134ad Mon Sep 17 00:00:00 2001 From: Aghilan Nathan Date: Mon, 23 Feb 2026 06:27:09 +0000 Subject: [PATCH 1/6] Optimize SwiGLU forward path with fast sigmoid math --- src/tilegym/ops/cutile/swiglu.py | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/tilegym/ops/cutile/swiglu.py b/src/tilegym/ops/cutile/swiglu.py index a73c111..457ddd6 100644 --- a/src/tilegym/ops/cutile/swiglu.py +++ b/src/tilegym/ops/cutile/swiglu.py @@ -6,6 +6,7 @@ import cuda.tile as ct import torch import torch.nn as nn +from cuda.tile._numeric_semantics import RoundingMode as RMd from tilegym.backend import register_impl @@ -13,32 +14,33 @@ PAD_ZERO = ct.PaddingMode.ZERO - def sigmoid(x): - return 1.0 / (1.0 + ct.exp(-x)) + denom = ct.add(1.0, ct.exp(-x), flush_to_zero=True) + return ct.truediv(1.0, denom, flush_to_zero=True, rounding_mode=RMd.APPROX) def silu(x): return x * sigmoid(x) +def ceildiv(a, b): + return -(a // -b) + + @ct.kernel def swiglu_forward_kernel(a, b, c, TILE_SIZE: ct.Constant[int]): row = ct.bid(0) col = ct.bid(1) - a_tile = ct.load(a, index=(row, col), shape=(1, TILE_SIZE), padding_mode=PAD_ZERO) - b_tile = ct.load(b, index=(row, col), shape=(1, TILE_SIZE), padding_mode=PAD_ZERO) + a_tile = ct.load(a, index=(row, col), shape=(1, TILE_SIZE), padding_mode=ct.PaddingMode.ZERO) + b_tile = ct.load(b, index=(row, col), shape=(1, TILE_SIZE), padding_mode=ct.PaddingMode.ZERO) - # Sigmoid requires type float32 - c_tile = silu(a_tile.astype(ct.float32)).astype(a.dtype) * b_tile + # Forward uses fast math knobs for throughput on Blackwell. + a_tile_f32 = a_tile.astype(ct.float32) + c_tile = silu(a_tile_f32).astype(a.dtype) * b_tile ct.store(c, index=(row, col), tile=c_tile) -def ceildiv(a, b): - return -(a // -b) - - def swiglu_forward(a, b): """ a: (batch_size, seq_len, intermediate_size) @@ -60,9 +62,9 @@ def swiglu_forward(a, b): grid, swiglu_forward_kernel, ( - a.data, - b.data, - c.data, + a, + b, + c, TILE_SIZE, ), ) @@ -89,7 +91,7 @@ def swiglu_backward_kernel(dc, a, b, da, db, TILE_SIZE: ct.Constant[int]): b_tile_f32 = b_tile.astype(ct.float32) # Compute sigmoid(a) and silu(a) - sigmoid_a = sigmoid(a_tile_f32) + sigmoid_a = 1.0 / (1.0 + ct.exp(-a_tile_f32)) silu_a = a_tile_f32 * sigmoid_a # db = dc * silu(a) From a0c86c38fd2c39565b4ca223529f56c5051a8a37 Mon Sep 17 00:00:00 2001 From: Aghilan Nathan Date: Mon, 23 Feb 2026 06:27:09 +0000 Subject: [PATCH 2/6] feat: swiglu forward only changes --- src/tilegym/ops/cutile/swiglu.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/tilegym/ops/cutile/swiglu.py b/src/tilegym/ops/cutile/swiglu.py index 457ddd6..4e3b5c0 100644 --- a/src/tilegym/ops/cutile/swiglu.py +++ b/src/tilegym/ops/cutile/swiglu.py @@ -30,15 +30,15 @@ def ceildiv(a, b): @ct.kernel def swiglu_forward_kernel(a, b, c, TILE_SIZE: ct.Constant[int]): row = ct.bid(0) - col = ct.bid(1) + offsets = ct.arange(TILE_SIZE, dtype=ct.int32) - a_tile = ct.load(a, index=(row, col), shape=(1, TILE_SIZE), padding_mode=ct.PaddingMode.ZERO) - b_tile = ct.load(b, index=(row, col), shape=(1, TILE_SIZE), padding_mode=ct.PaddingMode.ZERO) + a_tile = ct.gather(a, (row, offsets), check_bounds=True, padding_value=0.0) + b_tile = ct.gather(b, (row, offsets), check_bounds=True, padding_value=0.0) # Forward uses fast math knobs for throughput on Blackwell. a_tile_f32 = a_tile.astype(ct.float32) c_tile = silu(a_tile_f32).astype(a.dtype) * b_tile - ct.store(c, index=(row, col), tile=c_tile) + ct.scatter(c, (row, offsets), c_tile, check_bounds=True) def swiglu_forward(a, b): @@ -53,10 +53,8 @@ def swiglu_forward(a, b): c = torch.empty_like(a) n_rows = a.shape[0] - NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count - TILE_N = ceildiv(NUM_SMS, n_rows) - TILE_SIZE = next_power_of_2(int(n_cols / TILE_N)) - grid = (n_rows, ceildiv(n_cols, TILE_SIZE), 1) + TILE_SIZE = next_power_of_2(n_cols) + grid = (n_rows,) ct.launch( torch.cuda.current_stream(), grid, From dd5d0ffeaa9ccf1d4d45e298869463b8cab96b83 Mon Sep 17 00:00:00 2001 From: Aghilan Nathan Date: Mon, 23 Feb 2026 06:27:09 +0000 Subject: [PATCH 3/6] Tune SwiGLU forward with minimal Blackwell fast-math tweak --- src/tilegym/ops/cutile/swiglu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tilegym/ops/cutile/swiglu.py b/src/tilegym/ops/cutile/swiglu.py index 4e3b5c0..89237d1 100644 --- a/src/tilegym/ops/cutile/swiglu.py +++ b/src/tilegym/ops/cutile/swiglu.py @@ -20,7 +20,7 @@ def sigmoid(x): def silu(x): - return x * sigmoid(x) + return ct.mul(x, sigmoid(x), flush_to_zero=True) def ceildiv(a, b): From ea516f12657f1912eb34c52fbe870eec8ea50993 Mon Sep 17 00:00:00 2001 From: Aghilan Nathan Date: Mon, 23 Feb 2026 06:27:09 +0000 Subject: [PATCH 4/6] feat: lint --- src/tilegym/ops/cutile/swiglu.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tilegym/ops/cutile/swiglu.py b/src/tilegym/ops/cutile/swiglu.py index 89237d1..964d478 100644 --- a/src/tilegym/ops/cutile/swiglu.py +++ b/src/tilegym/ops/cutile/swiglu.py @@ -14,6 +14,7 @@ PAD_ZERO = ct.PaddingMode.ZERO + def sigmoid(x): denom = ct.add(1.0, ct.exp(-x), flush_to_zero=True) return ct.truediv(1.0, denom, flush_to_zero=True, rounding_mode=RMd.APPROX) From 0461595bf71c4abfe711958f42cd67c793e55287 Mon Sep 17 00:00:00 2001 From: Aghilan Nathan Date: Mon, 23 Feb 2026 06:27:09 +0000 Subject: [PATCH 5/6] feat: lint --- src/tilegym/ops/cutile/swiglu.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tilegym/ops/cutile/swiglu.py b/src/tilegym/ops/cutile/swiglu.py index 964d478..4f15f92 100644 --- a/src/tilegym/ops/cutile/swiglu.py +++ b/src/tilegym/ops/cutile/swiglu.py @@ -36,7 +36,6 @@ def swiglu_forward_kernel(a, b, c, TILE_SIZE: ct.Constant[int]): a_tile = ct.gather(a, (row, offsets), check_bounds=True, padding_value=0.0) b_tile = ct.gather(b, (row, offsets), check_bounds=True, padding_value=0.0) - # Forward uses fast math knobs for throughput on Blackwell. a_tile_f32 = a_tile.astype(ct.float32) c_tile = silu(a_tile_f32).astype(a.dtype) * b_tile ct.scatter(c, (row, offsets), c_tile, check_bounds=True) From 3476f8a4e66d8b2a039d43bc23dd21282e7555ca Mon Sep 17 00:00:00 2001 From: Aghilan Nathan Date: Mon, 23 Feb 2026 06:57:41 +0000 Subject: [PATCH 6/6] feat: add bfloat16 and float32 benchmarks --- tests/benchmark/bench_swiglu.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/tests/benchmark/bench_swiglu.py b/tests/benchmark/bench_swiglu.py index aa2d378..c9e2fc0 100644 --- a/tests/benchmark/bench_swiglu.py +++ b/tests/benchmark/bench_swiglu.py @@ -38,7 +38,7 @@ def get_supported_backends(): return [p for p in ALL_BACKENDS if p is not None] -def create_benchmark_config(batch_size, M): +def create_benchmark_config(batch_size, M, dtype): """Create a benchmark configuration for given parameters""" available_backends = get_supported_backends() if not available_backends: @@ -46,6 +46,7 @@ def create_benchmark_config(batch_size, M): backends, names, styles = zip(*available_backends) + dtype_name = str(dtype).split(".")[-1] return triton.testing.Benchmark( x_names=["N"], x_vals=[2**i for i in range(10, 15)], # 1024 to 16384 @@ -54,19 +55,21 @@ def create_benchmark_config(batch_size, M): line_names=list(names), styles=list(styles), ylabel="GB/s", - plot_name=f"swiglu-batch{batch_size}-M{M}-GBps", + plot_name=f"swiglu-batch{batch_size}-M{M}-{dtype_name}-GBps", args={ "batch_size": batch_size, "M": M, + "dtype": dtype, }, ) @triton.testing.perf_report( [ - create_benchmark_config(batch_size, M) - for batch_size in [1, 8] # Different batch sizes + create_benchmark_config(batch_size, M, dtype) + for batch_size in [1, 4, 8] # Different batch sizes for M in [128, 4096] # Different rows + for dtype in [torch.float16, torch.bfloat16, torch.float32] ] ) def bench_swiglu( @@ -74,18 +77,22 @@ def bench_swiglu( M, N, backend, + dtype, device=DEVICE, ): - dtype = torch.float16 - # Generate input data: two tensors for SwiGLU operation a = torch.randn(batch_size, M, N, device=device, dtype=dtype) b = torch.randn(batch_size, M, N, device=device, dtype=dtype) # Use unified dispatch system fn = lambda: tilegym.ops.get_swiglu(backend=backend)(a, b) - ref = lambda: reference_swiglu(a, b) - torch.testing.assert_close(fn(), ref(), atol=1e-2, rtol=1e-2) + if dtype is torch.float32: + ref = lambda: reference_swiglu(a, b) + atol, rtol = 1e-5, 1e-5 + else: + ref = lambda: reference_swiglu(a, b) + atol, rtol = 1e-2, 1e-2 + torch.testing.assert_close(fn(), ref(), atol=atol, rtol=rtol) # Benchmark the function ms = triton.testing.do_bench_cudagraph(fn)