diff --git a/src/tilegym/ops/cutile/swiglu.py b/src/tilegym/ops/cutile/swiglu.py
index a73c111..4f15f92 100644
--- a/src/tilegym/ops/cutile/swiglu.py
+++ b/src/tilegym/ops/cutile/swiglu.py
@@ -6,6 +6,7 @@
 import cuda.tile as ct
 import torch
 import torch.nn as nn
+from cuda.tile._numeric_semantics import RoundingMode as RMd
 
 from tilegym.backend import register_impl
 
@@ -15,28 +16,29 @@
 
 
 def sigmoid(x):
-    return 1.0 / (1.0 + ct.exp(-x))
+    denom = ct.add(1.0, ct.exp(-x), flush_to_zero=True)
+    return ct.truediv(1.0, denom, flush_to_zero=True, rounding_mode=RMd.APPROX)
 
 
 def silu(x):
-    return x * sigmoid(x)
+    return ct.mul(x, sigmoid(x), flush_to_zero=True)
+
+
+def ceildiv(a, b):
+    return -(a // -b)
 
 
 @ct.kernel
 def swiglu_forward_kernel(a, b, c, TILE_SIZE: ct.Constant[int]):
     row = ct.bid(0)
-    col = ct.bid(1)
-
-    a_tile = ct.load(a, index=(row, col), shape=(1, TILE_SIZE), padding_mode=PAD_ZERO)
-    b_tile = ct.load(b, index=(row, col), shape=(1, TILE_SIZE), padding_mode=PAD_ZERO)
+    offsets = ct.arange(TILE_SIZE, dtype=ct.int32)
 
-    # Sigmoid requires type float32
-    c_tile = silu(a_tile.astype(ct.float32)).astype(a.dtype) * b_tile
-    ct.store(c, index=(row, col), tile=c_tile)
+    a_tile = ct.gather(a, (row, offsets), check_bounds=True, padding_value=0.0)
+    b_tile = ct.gather(b, (row, offsets), check_bounds=True, padding_value=0.0)
 
-
-def ceildiv(a, b):
-    return -(a // -b)
+    a_tile_f32 = a_tile.astype(ct.float32)
+    c_tile = silu(a_tile_f32).astype(a.dtype) * b_tile
+    ct.scatter(c, (row, offsets), c_tile, check_bounds=True)
 
 
 def swiglu_forward(a, b):
@@ -51,18 +53,16 @@ def swiglu_forward(a, b):
     c = torch.empty_like(a)
     n_rows = a.shape[0]
 
-    NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
-    TILE_N = ceildiv(NUM_SMS, n_rows)
-    TILE_SIZE = next_power_of_2(int(n_cols / TILE_N))
-    grid = (n_rows, ceildiv(n_cols, TILE_SIZE), 1)
+    TILE_SIZE = next_power_of_2(n_cols)
+    grid = (n_rows,)
     ct.launch(
         torch.cuda.current_stream(),
         grid,
         swiglu_forward_kernel,
         (
-            a.data,
-            b.data,
-            c.data,
+            a,
+            b,
+            c,
             TILE_SIZE,
         ),
     )
@@ -89,7 +89,7 @@ def swiglu_backward_kernel(dc, a, b, da, db, TILE_SIZE: ct.Constant[int]):
     b_tile_f32 = b_tile.astype(ct.float32)
 
     # Compute sigmoid(a) and silu(a)
-    sigmoid_a = sigmoid(a_tile_f32)
+    sigmoid_a = 1.0 / (1.0 + ct.exp(-a_tile_f32))
     silu_a = a_tile_f32 * sigmoid_a
 
     # db = dc * silu(a)
diff --git a/tests/benchmark/bench_swiglu.py b/tests/benchmark/bench_swiglu.py
index aa2d378..c9e2fc0 100644
--- a/tests/benchmark/bench_swiglu.py
+++ b/tests/benchmark/bench_swiglu.py
@@ -38,7 +38,7 @@ def get_supported_backends():
     return [p for p in ALL_BACKENDS if p is not None]
 
 
-def create_benchmark_config(batch_size, M):
+def create_benchmark_config(batch_size, M, dtype):
     """Create a benchmark configuration for given parameters"""
     available_backends = get_supported_backends()
     if not available_backends:
@@ -46,6 +46,7 @@ def create_benchmark_config(batch_size, M):
 
     backends, names, styles = zip(*available_backends)
 
+    dtype_name = str(dtype).split(".")[-1]
    return triton.testing.Benchmark(
         x_names=["N"],
         x_vals=[2**i for i in range(10, 15)],  # 1024 to 16384
@@ -54,19 +55,21 @@
         line_names=list(names),
         styles=list(styles),
         ylabel="GB/s",
-        plot_name=f"swiglu-batch{batch_size}-M{M}-GBps",
+        plot_name=f"swiglu-batch{batch_size}-M{M}-{dtype_name}-GBps",
         args={
             "batch_size": batch_size,
             "M": M,
+            "dtype": dtype,
         },
     )
 
 
 @triton.testing.perf_report(
     [
-        create_benchmark_config(batch_size, M)
+        create_benchmark_config(batch_size, M, dtype)
         for batch_size in [1, 4, 8]  # Different batch sizes
         for M in [128, 4096]  # Different rows
+        for dtype in [torch.float16, torch.bfloat16, torch.float32]
     ]
 )
 def bench_swiglu(
@@ -74,18 +77,22 @@
     M,
     N,
     backend,
+    dtype,
     device=DEVICE,
 ):
-    dtype = torch.float16
-
     # Generate input data: two tensors for SwiGLU operation
     a = torch.randn(batch_size, M, N, device=device, dtype=dtype)
     b = torch.randn(batch_size, M, N, device=device, dtype=dtype)
 
     # Use unified dispatch system
     fn = lambda: tilegym.ops.get_swiglu(backend=backend)(a, b)
-    ref = lambda: reference_swiglu(a, b)
-    torch.testing.assert_close(fn(), ref(), atol=1e-2, rtol=1e-2)
+    if dtype is torch.float32:
+        ref = lambda: reference_swiglu(a, b)
+        atol, rtol = 1e-5, 1e-5
+    else:
+        ref = lambda: reference_swiglu(a, b)
+        atol, rtol = 1e-2, 1e-2
+    torch.testing.assert_close(fn(), ref(), atol=atol, rtol=rtol)
 
     # Benchmark the function
     ms = triton.testing.do_bench_cudagraph(fn)