[SkyRL-tx] Integrate ragged_dot pallas kernel with group_offset support #867
Conversation
Code Review
This pull request introduces a Pallas kernel for ragged_dot to improve performance, including support for group_offset. The kernel implementation is mostly solid, but I found a critical bug in the gradient calculation for trans_ragged_dot. Additionally, the new fast_ragged_dot function is currently a stub, which will cause the new tests to fail. I've suggested an implementation for it to make the feature functional. There are also some minor maintainability issues in the kernel code that could be improved for clarity.
kw = dict(kw, compute_dtype=compute_dtype, acc_dtype=acc_dtype, num_warps=num_warps, num_stages=num_stages)
x, y, group_sizes = res
dx = ragged_dot(y, do, group_sizes, **kw, trans_rhs=True)
dy = trans_ragged_dot(x, do, group_sizes, **kw)
There is a critical bug in the backward pass for trans_ragged_dot. The gradient dy with respect to y is incorrectly calculated using trans_ragged_dot. Based on the derivation, it should be calculated using ragged_dot.
The current implementation trans_ragged_dot(x, do, ...) would also cause a shape mismatch error, as do has shape [g, k, n] while the function expects an array of shape [m, n] for its second argument.
Suggested change:
-dy = trans_ragged_dot(x, do, group_sizes, **kw)
+dy = ragged_dot(x, do, group_sizes, **kw)
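For what it's worth, the claimed gradient can be sanity-checked numerically. The sketch below is illustrative only: it assumes the usual ragged_dot semantics (lhs [m, k], rhs [g, k, n] -> [m, n]), uses jax.lax.ragged_dot in place of the PR's ragged_dot, and defines reference_trans_ragged_dot as an einsum-based stand-in for trans_ragged_dot.

import jax
import jax.numpy as jnp

def reference_trans_ragged_dot(x, y, group_sizes):
    # out[g] = x_g^T @ y_g, where x_g / y_g are the rows assigned to group g.
    g = group_sizes.shape[0]
    group_ids = jnp.repeat(jnp.arange(g), group_sizes, total_repeat_length=x.shape[0])
    one_hot = jax.nn.one_hot(group_ids, g, dtype=x.dtype)  # [m, g]
    return jnp.einsum("mg,mk,mn->gkn", one_hot, x, y)

m, k, n, g = 32, 8, 16, 4
kx, ky, kd = jax.random.split(jax.random.PRNGKey(0), 3)
group_sizes = jnp.array([10, 6, 12, 4])  # sums to m
x = jax.random.normal(kx, (m, k))
y = jax.random.normal(ky, (m, n))
do = jax.random.normal(kd, (g, k, n))

# d/dy of <trans_ragged_dot(x, y), do> should be a ragged_dot of x with do.
dy_autodiff = jax.grad(lambda y_: jnp.vdot(reference_trans_ragged_dot(x, y_, group_sizes), do))(y)
dy_expected = jax.lax.ragged_dot(x, do, group_sizes)
assert jnp.allclose(dy_autodiff, dy_expected, atol=1e-5)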
def fast_ragged_dot(
    lhs: jax.Array,
    rhs: jax.Array,
    group_sizes: jax.Array,
    group_offset: jax.Array | None = None,
) -> jax.Array:
    """Fast ragged dot product with group_offset support using Pallas kernels.

    Uses GPU info to configure Pallas kernels. Defaults to using ragged_dot().
    """
    # TODO: Benchmark and fill
The fast_ragged_dot function is currently a stub with a TODO. This will cause the newly added tests to fail, as they rely on this function's implementation. The docstring also states that it defaults to ragged_dot(), but the function body is empty.
To make this feature functional and align with the PR's goal, you should implement the logic to handle group_offset by adapting the existing ragged_dot implementation to use the new ragged_dot_pallas kernel. This will ensure the tests pass and provide the intended performance improvement.
def fast_ragged_dot(
    lhs: jax.Array,
    rhs: jax.Array,
    group_sizes: jax.Array,
    group_offset: jax.Array | None = None,
) -> jax.Array:
    """Fast ragged dot product with group_offset support using Pallas kernels.

    Uses GPU info to configure Pallas kernels. Defaults to using ragged_dot().
    """
    # TODO: Benchmark and add fast paths for specific GPU architectures.
    if group_offset is None:
        return ragged_dot_pallas(lhs, rhs, group_sizes)
    assert group_offset.shape == (1,), "group_offset must have shape (1,)"
    offset = group_offset[0]
    m = lhs.shape[0]
    g_local = rhs.shape[0]
    assert g_local > 0, "rhs must have at least one group"
    # Compute token boundaries for local groups
    cumsum = jnp.cumulative_sum(group_sizes, include_initial=True)
    shard_start = cumsum[offset]
    shard_end = cumsum[offset + g_local]
    # Valid mask for tokens in local groups
    token_idx = jnp.arange(m)
    valid_mask = (token_idx >= shard_start) & (token_idx < shard_end)
    # Adjust group sizes: absorb extra tokens at boundaries
    local_group_sizes = lax.dynamic_slice_in_dim(group_sizes, offset, g_local, axis=0)
    adjusted_group_sizes = local_group_sizes.at[0].add(shard_start).at[-1].add(m - shard_end)
    # Call pallas kernel - extra tokens use boundary groups but get masked out
    result = ragged_dot_pallas(
        lhs,
        rhs,
        adjusted_group_sizes,
    )
    return jnp.where(valid_mask[:, None], result, 0)
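Assuming the suggested implementation above, a hypothetical call under expert-parallel sharding could look like the following; all names, shapes, and values here are illustrative, not from the PR.

# 8 experts globally; this shard holds experts 4..7, so rhs carries only 4 weight
# matrices and group_offset points at the first local expert.
m, k, n = 64, 16, 32
group_sizes = jnp.array([8, 8, 8, 8, 8, 8, 8, 8])   # tokens per global expert, sums to m
lhs = jnp.ones((m, k), jnp.bfloat16)                 # all tokens, grouped by expert
rhs = jnp.ones((4, k, n), jnp.bfloat16)              # weights for experts 4..7 only
out = fast_ragged_dot(lhs, rhs, group_sizes, group_offset=jnp.array([4]))
# out is [m, n]; rows for tokens routed to experts 0..3 are zeroed by the valid mask.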
def _gpu_trans_ragged_dot_fwd(
    x: jax.Array,  # [m, k]
    y: jax.Array,  # [m, n]
    group_sizes: jax.Array,  # [g]
    block_m: int = DEFAULT_BLOCK_M,  # shape[0] of A_i tile (block_m, block_n)
    block_n: int = DEFAULT_BLOCK_N,  # shape[1] of A_i tile (block_m, block_n)
    block_k: int = DEFAULT_BLOCK_K,  # how many rows in the accumulation loop over block_m
    interpret: bool = False,
    compute_dtype: Optional["jnp.dtype"] = None,
    acc_dtype: Optional["jnp.dtype"] = jnp.float32,
    num_warps: int | None = None,
    num_stages: int | None = None,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array, jax.Array]]:
    """Compute grouped matmul on GPU via a Pallas lowering."""
    assert y.ndim == 2 and x.ndim == 2 and x.shape[0] == y.shape[0]
    (m, k), n = x.shape, y.shape[-1]
    size = ProblemSizes(m=m, k=k, n=n, g=group_sizes.size)

    block_m, block_n = min(block_m, m), min(block_n, n)

    # normalize the block sizes for GPU
    block_m, block_k, block_n = [
        max(pl.next_power_of_2(min(b, s)), 16)
        for b, s in zip([block_m, block_k, block_n, block_k], [size.m, size.k, size.n])
    ]
The argument order for block sizes in the function signature is a bit confusing and inconsistent with other functions. Also, the list comprehension for normalizing block sizes contains a redundant block_k and doesn't match the argument order, which harms readability.
To improve clarity and maintainability, I suggest:
- Reordering the arguments in the function signature to block_m, block_k, block_n.
- Simplifying the list comprehension to remove the redundant element and match the argument order.
def _gpu_trans_ragged_dot_fwd(
    x: jax.Array,  # [m, k]
    y: jax.Array,  # [m, n]
    group_sizes: jax.Array,  # [g]
    block_m: int = DEFAULT_BLOCK_M,  # shape[0] of A_i tile (block_m, block_n)
    block_k: int = DEFAULT_BLOCK_K,  # how many rows in the accumulation loop over block_m
    block_n: int = DEFAULT_BLOCK_N,  # shape[1] of A_i tile (block_m, block_n)
    interpret: bool = False,
    compute_dtype: Optional["jnp.dtype"] = None,
    acc_dtype: Optional["jnp.dtype"] = jnp.float32,
    num_warps: int | None = None,
    num_stages: int | None = None,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array, jax.Array]]:
    """Compute grouped matmul on GPU via a Pallas lowering."""
    assert y.ndim == 2 and x.ndim == 2 and x.shape[0] == y.shape[0]
    (m, k), n = x.shape, y.shape[-1]
    size = ProblemSizes(m=m, k=k, n=n, g=group_sizes.size)
    block_m, block_n = min(block_m, m), min(block_n, n)
    # normalize the block sizes for GPU
    block_m, block_k, block_n = [
        max(pl.next_power_of_2(min(b, s)), 16)
        for b, s in zip([block_m, block_k, block_n], [size.m, size.k, size.n])
    ]
@pcmoritz Could you help with a couple of questions here: do you expect to have more kernels inside tx in general? In that case we can add build-time configurations.
Addresses #862
Kernel implementation borrowed from https://github.com/rdyro/gpu_ragged_dot/blob/main/gpu_ragged_dot.py
Benchmarking in progress.
Results:
1xA100
Best config: block_m=128, block_k=64, block_n=64
[1/32] K=256, N=256, G=16... LAX: 0.602ms, Pallas: 0.973ms, Speedup: 0.62x
[2/32] K=256, N=512, G=16... LAX: 0.682ms, Pallas: 0.979ms, Speedup: 0.70x
[3/32] K=256, N=1024, G=16... LAX: 0.750ms, Pallas: 0.978ms, Speedup: 0.77x
[4/32] K=256, N=2048, G=16... LAX: 0.959ms, Pallas: 1.054ms, Speedup: 0.91x
[5/32] K=512, N=256, G=16... LAX: 0.683ms, Pallas: 0.961ms, Speedup: 0.71x
[6/32] K=512, N=512, G=16... LAX: 0.756ms, Pallas: 0.957ms, Speedup: 0.79x
[7/32] K=512, N=1024, G=16... LAX: 0.999ms, Pallas: 0.977ms, Speedup: 1.02x
[8/32] K=512, N=2048, G=16... LAX: 1.180ms, Pallas: 1.048ms, Speedup: 1.13x
[9/32] K=1024, N=256, G=16... LAX: 0.789ms, Pallas: 0.968ms, Speedup: 0.82x
[10/32] K=1024, N=512, G=16... LAX: 0.991ms, Pallas: 1.009ms, Speedup: 0.98x
[11/32] K=1024, N=1024, G=16... LAX: 1.277ms, Pallas: 1.087ms, Speedup: 1.18x
[12/32] K=1024, N=2048, G=16... LAX: 1.782ms, Pallas: 1.270ms, Speedup: 1.40x
[13/32] K=2048, N=256, G=16... LAX: 1.078ms, Pallas: 1.113ms, Speedup: 0.97x
[14/32] K=2048, N=512, G=16... LAX: 1.346ms, Pallas: 1.142ms, Speedup: 1.18x
[15/32] K=2048, N=1024, G=16... LAX: 1.866ms, Pallas: 1.256ms, Speedup: 1.49x
[16/32] K=2048, N=2048, G=16... LAX: 3.180ms, Pallas: 1.534ms, Speedup: 2.07x
[17/32] K=256, N=256, G=32... LAX: 0.654ms, Pallas: 0.977ms, Speedup: 0.67x
[18/32] K=256, N=512, G=32... LAX: 0.788ms, Pallas: 1.001ms, Speedup: 0.79x
[19/32] K=256, N=1024, G=32... LAX: 0.986ms, Pallas: 0.985ms, Speedup: 1.00x
[20/32] K=256, N=2048, G=32... LAX: 1.228ms, Pallas: 0.997ms, Speedup: 1.23x
[21/32] K=512, N=256, G=32... LAX: 0.780ms, Pallas: 0.985ms, Speedup: 0.79x
[22/32] K=512, N=512, G=32... LAX: 1.029ms, Pallas: 1.060ms, Speedup: 0.97x
[23/32] K=512, N=1024, G=32... LAX: 1.271ms, Pallas: 1.003ms, Speedup: 1.27x
[24/32] K=512, N=2048, G=32... LAX: 1.800ms, Pallas: 1.100ms, Speedup: 1.64x
[25/32] K=1024, N=256, G=32... LAX: 1.073ms, Pallas: 1.001ms, Speedup: 1.07x
[26/32] K=1024, N=512, G=32... LAX: 1.290ms, Pallas: 1.006ms, Speedup: 1.28x
[27/32] K=1024, N=1024, G=32... LAX: 1.949ms, Pallas: 1.140ms, Speedup: 1.71x
[28/32] K=1024, N=2048, G=32... LAX: 3.235ms, Pallas: 1.330ms, Speedup: 2.43x
[29/32] K=2048, N=256, G=32... LAX: 1.537ms, Pallas: 1.028ms, Speedup: 1.50x
[30/32] K=2048, N=512, G=32... LAX: 2.063ms, Pallas: 1.101ms, Speedup: 1.87x
[31/32] K=2048, N=1024, G=32... LAX: 3.233ms, Pallas: 1.314ms, Speedup: 2.46x
[32/32] K=2048, N=2048, G=32... LAX: 5.508ms, Pallas: 1.618ms, Speedup: 3.40x
For M < 4096, Pallas was always empirically worse (swept over K, N, G).
For M >= 4096, (N or K) >= 1024, min(N, K) >= 512, and G >= 16, Pallas is faster on A100.
1xH100
Best config: block_m=128, block_k=64, block_n=64
[1/32] K=256, N=256, G=16... Speedup: 0.67x
[2/32] K=256, N=512, G=16... Speedup: 0.53x
[3/32] K=256, N=1024, G=16... Speedup: 0.59x
[4/32] K=256, N=2048, G=16... Speedup: 0.86x
[5/32] K=512, N=256, G=16... Speedup: 0.74x
[6/32] K=512, N=512, G=16... Speedup: 0.63x
[7/32] K=512, N=1024, G=16... Speedup: 0.85x
[8/32] K=512, N=2048, G=16... Speedup: 0.81x
[9/32] K=1024, N=256, G=16... Speedup: 0.79x
[10/32] K=1024, N=512, G=16... Speedup: 0.78x
[11/32] K=1024, N=1024, G=16... Speedup: 0.85x
[12/32] K=1024, N=2048, G=16... Speedup: 0.92x
[13/32] K=2048, N=256, G=16... Speedup: 0.88x
[14/32] K=2048, N=512, G=16... Speedup: 0.85x
[15/32] K=2048, N=1024, G=16... Speedup: 1.04x
[16/32] K=2048, N=2048, G=16... Speedup: 1.13x
[17/32] K=256, N=256, G=32... Speedup: 0.71x
[18/32] K=256, N=512, G=32... Speedup: 0.56x
[19/32] K=256, N=1024, G=32... Speedup: 0.67x
[20/32] K=256, N=2048, G=32... Speedup: 0.75x
[21/32] K=512, N=256, G=32... Speedup: 0.72x
[22/32] K=512, N=512, G=32... Speedup: 0.72x
[23/32] K=512, N=1024, G=32... Speedup: 0.90x
[24/32] K=512, N=2048, G=32... Speedup: 0.92x
[25/32] K=1024, N=256, G=32... Speedup: 0.72x
[26/32] K=1024, N=512, G=32... Speedup: 0.84x
[27/32] K=1024, N=1024, G=32... Speedup: 0.93x
[28/32] K=1024, N=2048, G=32... Speedup: 1.11x
[29/32] K=2048, N=256, G=32... Speedup: 0.82x
[30/32] K=2048, N=512, G=32... Speedup: 0.92x
[31/32] K=2048, N=1024, G=32... Speedup: 1.15x
[32/32] K=2048, N=2048, G=32... Speedup: 1.48x
For M < 4096, Pallas was always empirically worse (swept over K, N, G).
For M >= 4096, (N or K) >= 2048, min(N, K) >= 1024, and G >= 16, Pallas is faster on H100.
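One possible way to encode these empirical thresholds as a dispatch heuristic (a sketch only; the function name and the gpu_kind string are hypothetical, and the thresholds are simply the ones reported above):

def use_pallas_ragged_dot(m: int, k: int, n: int, g: int, gpu_kind: str) -> bool:
    # Thresholds from the A100/H100 sweeps above; everything else falls back to lax.ragged_dot.
    if m < 4096 or g < 16:
        return False
    if gpu_kind == "A100":
        return max(n, k) >= 1024 and min(n, k) >= 512
    if gpu_kind == "H100":
        return max(n, k) >= 2048 and min(n, k) >= 1024
    return False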
Todo:
Questions
Should the benchmarking/tuning code be added to the repo? What about using tune_jax, or autotuning?
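If a small in-repo sweep is preferred over pulling in tune_jax, something along these lines could be a starting point. sweep_block_sizes is a hypothetical helper; it assumes the kernel entry point accepts block_m/block_k/block_n keyword arguments, as the signatures above suggest.

import itertools
import time

import jax

def sweep_block_sizes(fn, lhs, rhs, group_sizes, candidates=(32, 64, 128), iters=20):
    # Minimal manual sweep over block sizes; times the jitted kernel for each config.
    best = None
    for bm, bk, bn in itertools.product(candidates, repeat=3):
        run = jax.jit(lambda a, b, g, bm=bm, bk=bk, bn=bn:
                      fn(a, b, g, block_m=bm, block_k=bk, block_n=bn))
        run(lhs, rhs, group_sizes).block_until_ready()  # compile + warm up
        t0 = time.perf_counter()
        for _ in range(iters):
            out = run(lhs, rhs, group_sizes)
        out.block_until_ready()
        dt = (time.perf_counter() - t0) / iters
        if best is None or dt < best[0]:
            best = (dt, dict(block_m=bm, block_k=bk, block_n=bn))
    return best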