-
Notifications
You must be signed in to change notification settings - Fork 50
feat: swiglu forward optimizations #63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ad6272b
a0c86c3
dd5d0ff
ea516f1
0461595
3476f8a
c856ac8
285b2cd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |
| import cuda.tile as ct | ||
| import torch | ||
| import torch.nn as nn | ||
| from cuda.tile._numeric_semantics import RoundingMode as RMd | ||
|
|
||
| from tilegym.backend import register_impl | ||
|
|
||
|
|
@@ -15,28 +16,29 @@ | |
|
|
||
|
|
||
| def sigmoid(x): | ||
| return 1.0 / (1.0 + ct.exp(-x)) | ||
| denom = ct.add(1.0, ct.exp(-x), flush_to_zero=True) | ||
| return ct.truediv(1.0, denom, flush_to_zero=True, rounding_mode=RMd.APPROX) | ||
|
|
||
|
|
||
| def silu(x): | ||
| return x * sigmoid(x) | ||
| return ct.mul(x, sigmoid(x), flush_to_zero=True) | ||
|
|
||
|
|
||
| def ceildiv(a, b): | ||
| return -(a // -b) | ||
|
|
||
|
|
||
| @ct.kernel | ||
| def swiglu_forward_kernel(a, b, c, TILE_SIZE: ct.Constant[int]): | ||
| row = ct.bid(0) | ||
| col = ct.bid(1) | ||
|
|
||
| a_tile = ct.load(a, index=(row, col), shape=(1, TILE_SIZE), padding_mode=PAD_ZERO) | ||
| b_tile = ct.load(b, index=(row, col), shape=(1, TILE_SIZE), padding_mode=PAD_ZERO) | ||
| offsets = ct.arange(TILE_SIZE, dtype=ct.int32) | ||
|
|
||
| # Sigmoid requires type float32 | ||
| c_tile = silu(a_tile.astype(ct.float32)).astype(a.dtype) * b_tile | ||
| ct.store(c, index=(row, col), tile=c_tile) | ||
| a_tile = ct.gather(a, (row, offsets), check_bounds=True, padding_value=0.0) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. A good chunk of the perf improvements came from gather/scatter vs. load/store |
||
| b_tile = ct.gather(b, (row, offsets), check_bounds=True, padding_value=0.0) | ||
|
|
||
|
|
||
| def ceildiv(a, b): | ||
| return -(a // -b) | ||
| a_tile_f32 = a_tile.astype(ct.float32) | ||
| c_tile = silu(a_tile_f32).astype(a.dtype) * b_tile | ||
| ct.scatter(c, (row, offsets), c_tile, check_bounds=True) | ||
|
|
||
|
|
||
| def swiglu_forward(a, b): | ||
|
|
@@ -51,18 +53,16 @@ def swiglu_forward(a, b): | |
| c = torch.empty_like(a) | ||
| n_rows = a.shape[0] | ||
|
|
||
| NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count | ||
| TILE_N = ceildiv(NUM_SMS, n_rows) | ||
| TILE_SIZE = next_power_of_2(int(n_cols / TILE_N)) | ||
| grid = (n_rows, ceildiv(n_cols, TILE_SIZE), 1) | ||
| TILE_SIZE = next_power_of_2(n_cols) | ||
| grid = (n_rows,) | ||
| ct.launch( | ||
| torch.cuda.current_stream(), | ||
| grid, | ||
| swiglu_forward_kernel, | ||
| ( | ||
| a.data, | ||
| b.data, | ||
| c.data, | ||
| a, | ||
| b, | ||
| c, | ||
| TILE_SIZE, | ||
| ), | ||
| ) | ||
|
|
@@ -89,7 +89,7 @@ def swiglu_backward_kernel(dc, a, b, da, db, TILE_SIZE: ct.Constant[int]): | |
| b_tile_f32 = b_tile.astype(ct.float32) | ||
|
|
||
| # Compute sigmoid(a) and silu(a) | ||
| sigmoid_a = sigmoid(a_tile_f32) | ||
| sigmoid_a = 1.0 / (1.0 + ct.exp(-a_tile_f32)) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Inlined this for now because I didn’t want to modify the backward kernel in this PR — that would require re-benchmarking it as well. I have additional optimizations planned that I’ll include in a separate PR, which will also make use of the new
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you add a clear comment here? The current changes make the forward and backward code a bit confusing. You can delete the comment when the backward PR is ready. |
||
| silu_a = a_tile_f32 * sigmoid_a | ||
|
|
||
| # db = dc * silu(a) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -38,14 +38,15 @@ def get_supported_backends(): | |
| return [p for p in ALL_BACKENDS if p is not None] | ||
|
|
||
|
|
||
| def create_benchmark_config(batch_size, M): | ||
| def create_benchmark_config(batch_size, M, dtype): | ||
| """Create a benchmark configuration for given parameters""" | ||
| available_backends = get_supported_backends() | ||
| if not available_backends: | ||
| return None | ||
|
|
||
| backends, names, styles = zip(*available_backends) | ||
|
|
||
| dtype_name = str(dtype).split(".")[-1] | ||
| return triton.testing.Benchmark( | ||
| x_names=["N"], | ||
| x_vals=[2**i for i in range(10, 15)], # 1024 to 16384 | ||
|
|
@@ -54,38 +55,44 @@ def create_benchmark_config(batch_size, M): | |
| line_names=list(names), | ||
| styles=list(styles), | ||
| ylabel="GB/s", | ||
| plot_name=f"swiglu-batch{batch_size}-M{M}-GBps", | ||
| plot_name=f"swiglu-batch{batch_size}-M{M}-{dtype_name}-GBps", | ||
| args={ | ||
| "batch_size": batch_size, | ||
| "M": M, | ||
| "dtype": dtype, | ||
| }, | ||
| ) | ||
|
|
||
|
|
||
| @triton.testing.perf_report( | ||
| [ | ||
| create_benchmark_config(batch_size, M) | ||
| for batch_size in [1, 8] # Different batch sizes | ||
| create_benchmark_config(batch_size, M, dtype) | ||
| for batch_size in [1, 4, 8] # Different batch sizes | ||
| for M in [128, 4096] # Different rows | ||
| for dtype in [torch.float16, torch.bfloat16, torch.float32] | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Most benchmarks test across various dtypes; I thought this one should, too. |
||
| ] | ||
| ) | ||
| def bench_swiglu( | ||
| batch_size, | ||
| M, | ||
| N, | ||
| backend, | ||
| dtype, | ||
| device=DEVICE, | ||
| ): | ||
| dtype = torch.float16 | ||
|
|
||
| # Generate input data: two tensors for SwiGLU operation | ||
| a = torch.randn(batch_size, M, N, device=device, dtype=dtype) | ||
| b = torch.randn(batch_size, M, N, device=device, dtype=dtype) | ||
|
|
||
| # Use unified dispatch system | ||
| fn = lambda: tilegym.ops.get_swiglu(backend=backend)(a, b) | ||
| ref = lambda: reference_swiglu(a, b) | ||
| torch.testing.assert_close(fn(), ref(), atol=1e-2, rtol=1e-2) | ||
| if dtype is torch.float32: | ||
| ref = lambda: reference_swiglu(a, b) | ||
| atol, rtol = 1e-5, 1e-5 | ||
| else: | ||
| ref = lambda: reference_swiglu(a, b) | ||
| atol, rtol = 1e-2, 1e-2 | ||
| torch.testing.assert_close(fn(), ref(), atol=atol, rtol=rtol) | ||
|
|
||
| # Benchmark the function | ||
| ms = triton.testing.do_bench_cudagraph(fn) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A good chunk of the savings came from
RMd.APPROX without losing precision — verified via tests