Skip to content

Conversation

tanmaysachan commented Jan 13, 2026

Addresses #862
Kernel implementation borrowed from https://github.com/rdyro/gpu_ragged_dot/blob/main/gpu_ragged_dot.py

Benchmarking in progress.

Results:

1xA100


Best config: block_m=128, block_k=64, block_n=64

[1/32] K=256, N=256, G=16... LAX: 0.602ms, Pallas: 0.973ms, Speedup: 0.62x
[2/32] K=256, N=512, G=16... LAX: 0.682ms, Pallas: 0.979ms, Speedup: 0.70x
[3/32] K=256, N=1024, G=16... LAX: 0.750ms, Pallas: 0.978ms, Speedup: 0.77x
[4/32] K=256, N=2048, G=16... LAX: 0.959ms, Pallas: 1.054ms, Speedup: 0.91x
[5/32] K=512, N=256, G=16... LAX: 0.683ms, Pallas: 0.961ms, Speedup: 0.71x
[6/32] K=512, N=512, G=16... LAX: 0.756ms, Pallas: 0.957ms, Speedup: 0.79x
[7/32] K=512, N=1024, G=16... LAX: 0.999ms, Pallas: 0.977ms, Speedup: 1.02x
[8/32] K=512, N=2048, G=16... LAX: 1.180ms, Pallas: 1.048ms, Speedup: 1.13x
[9/32] K=1024, N=256, G=16... LAX: 0.789ms, Pallas: 0.968ms, Speedup: 0.82x
[10/32] K=1024, N=512, G=16... LAX: 0.991ms, Pallas: 1.009ms, Speedup: 0.98x
[11/32] K=1024, N=1024, G=16... LAX: 1.277ms, Pallas: 1.087ms, Speedup: 1.18x
[12/32] K=1024, N=2048, G=16... LAX: 1.782ms, Pallas: 1.270ms, Speedup: 1.40x
[13/32] K=2048, N=256, G=16... LAX: 1.078ms, Pallas: 1.113ms, Speedup: 0.97x
[14/32] K=2048, N=512, G=16... LAX: 1.346ms, Pallas: 1.142ms, Speedup: 1.18x
[15/32] K=2048, N=1024, G=16... LAX: 1.866ms, Pallas: 1.256ms, Speedup: 1.49x
[16/32] K=2048, N=2048, G=16... LAX: 3.180ms, Pallas: 1.534ms, Speedup: 2.07x
[17/32] K=256, N=256, G=32... LAX: 0.654ms, Pallas: 0.977ms, Speedup: 0.67x
[18/32] K=256, N=512, G=32... LAX: 0.788ms, Pallas: 1.001ms, Speedup: 0.79x
[19/32] K=256, N=1024, G=32... LAX: 0.986ms, Pallas: 0.985ms, Speedup: 1.00x
[20/32] K=256, N=2048, G=32... LAX: 1.228ms, Pallas: 0.997ms, Speedup: 1.23x
[21/32] K=512, N=256, G=32... LAX: 0.780ms, Pallas: 0.985ms, Speedup: 0.79x
[22/32] K=512, N=512, G=32... LAX: 1.029ms, Pallas: 1.060ms, Speedup: 0.97x
[23/32] K=512, N=1024, G=32... LAX: 1.271ms, Pallas: 1.003ms, Speedup: 1.27x
[24/32] K=512, N=2048, G=32... LAX: 1.800ms, Pallas: 1.100ms, Speedup: 1.64x
[25/32] K=1024, N=256, G=32... LAX: 1.073ms, Pallas: 1.001ms, Speedup: 1.07x
[26/32] K=1024, N=512, G=32... LAX: 1.290ms, Pallas: 1.006ms, Speedup: 1.28x
[27/32] K=1024, N=1024, G=32... LAX: 1.949ms, Pallas: 1.140ms, Speedup: 1.71x
[28/32] K=1024, N=2048, G=32... LAX: 3.235ms, Pallas: 1.330ms, Speedup: 2.43x
[29/32] K=2048, N=256, G=32... LAX: 1.537ms, Pallas: 1.028ms, Speedup: 1.50x
[30/32] K=2048, N=512, G=32... LAX: 2.063ms, Pallas: 1.101ms, Speedup: 1.87x
[31/32] K=2048, N=1024, G=32... LAX: 3.233ms, Pallas: 1.314ms, Speedup: 2.46x
[32/32] K=2048, N=2048, G=32... LAX: 5.508ms, Pallas: 1.618ms, Speedup: 3.40x

For M < 4096, Pallas was always empirically slower. Swept over K, N, and G.

For M >= 4096, max(N, K) >= 1024, min(N, K) >= 512, and G >= 16, Pallas is faster on the A100.

1xH100

Best config: block_m=128, block_k=64, block_n=64

[1/32] K=256, N=256, G=16... Speedup: 0.67x
[2/32] K=256, N=512, G=16... Speedup: 0.53x
[3/32] K=256, N=1024, G=16... Speedup: 0.59x
[4/32] K=256, N=2048, G=16... Speedup: 0.86x
[5/32] K=512, N=256, G=16... Speedup: 0.74x
[6/32] K=512, N=512, G=16... Speedup: 0.63x
[7/32] K=512, N=1024, G=16... Speedup: 0.85x
[8/32] K=512, N=2048, G=16... Speedup: 0.81x
[9/32] K=1024, N=256, G=16... Speedup: 0.79x
[10/32] K=1024, N=512, G=16... Speedup: 0.78x
[11/32] K=1024, N=1024, G=16... Speedup: 0.85x
[12/32] K=1024, N=2048, G=16... Speedup: 0.92x
[13/32] K=2048, N=256, G=16... Speedup: 0.88x
[14/32] K=2048, N=512, G=16... Speedup: 0.85x
[15/32] K=2048, N=1024, G=16... Speedup: 1.04x
[16/32] K=2048, N=2048, G=16... Speedup: 1.13x
[17/32] K=256, N=256, G=32... Speedup: 0.71x
[18/32] K=256, N=512, G=32... Speedup: 0.56x
[19/32] K=256, N=1024, G=32... Speedup: 0.67x
[20/32] K=256, N=2048, G=32... Speedup: 0.75x
[21/32] K=512, N=256, G=32... Speedup: 0.72x
[22/32] K=512, N=512, G=32... Speedup: 0.72x
[23/32] K=512, N=1024, G=32... Speedup: 0.90x
[24/32] K=512, N=2048, G=32... Speedup: 0.92x
[25/32] K=1024, N=256, G=32... Speedup: 0.72x
[26/32] K=1024, N=512, G=32... Speedup: 0.84x
[27/32] K=1024, N=1024, G=32... Speedup: 0.93x
[28/32] K=1024, N=2048, G=32... Speedup: 1.11x
[29/32] K=2048, N=256, G=32... Speedup: 0.82x
[30/32] K=2048, N=512, G=32... Speedup: 0.92x
[31/32] K=2048, N=1024, G=32... Speedup: 1.15x
[32/32] K=2048, N=2048, G=32... Speedup: 1.48x

For M < 4096, Pallas was always empirically slower. Swept over K, N, and G.

For M >= 4096, max(N, K) >= 2048, min(N, K) >= 1024, and G >= 16, Pallas is faster on the H100.
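Putting the two sweeps together, here is a minimal sketch of what the architecture-aware dispatch inside fast_ragged_dot() could look like. The helper names, the device_kind substring checks, and the hard-coded thresholds are assumptions for illustration, not the final implementation; ragged_dot_pallas is the kernel added in this PR.

import jax

# Hypothetical dispatch heuristic derived from the A100/H100 sweeps above; the
# function names and the device_kind string matching are illustrative assumptions.
def _use_pallas(m: int, k: int, n: int, g: int, device_kind: str) -> bool:
    if m < 4096 or g < 16:
        return False  # Pallas was never faster in these regimes
    if "A100" in device_kind:
        return max(n, k) >= 1024 and min(n, k) >= 512
    if "H100" in device_kind:
        return max(n, k) >= 2048 and min(n, k) >= 1024
    return False  # unknown hardware: fall back to the default implementation

def fast_ragged_dot_dispatch(lhs: jax.Array, rhs: jax.Array, group_sizes: jax.Array) -> jax.Array:
    (m, k), (g, _, n) = lhs.shape, rhs.shape
    if _use_pallas(m, k, n, g, jax.devices()[0].device_kind):
        return ragged_dot_pallas(lhs, rhs, group_sizes)  # Pallas kernel from this PR
    return jax.lax.ragged_dot(lhs, rhs, group_sizes)  # XLA fallback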

Todo:

  • 1xH100 benchmarking
  • Add tuned block sizes for known architectures; fall back to the default JAX implementation for others in fast_ragged_dot().
  • Add autotune? Tuning notebook/script?
  • More tests

Questions

Should the benchmarking/tuning code be added to the repo? What about using tune_jax or autotuning?
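On the tuning question, one lightweight option is a plain grid sweep over block sizes timed with wall-clock. A minimal sketch, assuming the kernel entry point accepts block_m/block_k/block_n keyword arguments as in the borrowed gpu_ragged_dot implementation (tune_jax would replace this hand-rolled loop):

import itertools
import time
from functools import partial

import jax

def sweep_block_sizes(lhs, rhs, group_sizes, candidates=(64, 128, 256), iters=20):
    # Brute-force sweep over (block_m, block_k, block_n); returns the fastest config.
    best_time, best_cfg = float("inf"), None
    for bm, bk, bn in itertools.product(candidates, repeat=3):
        fn = jax.jit(partial(ragged_dot_pallas, block_m=bm, block_k=bk, block_n=bn))
        try:
            fn(lhs, rhs, group_sizes).block_until_ready()  # compile + warm up
        except Exception:
            continue  # skip configs the kernel rejects (e.g. resource limits)
        start = time.perf_counter()
        for _ in range(iters):
            out = fn(lhs, rhs, group_sizes)
        out.block_until_ready()
        elapsed = (time.perf_counter() - start) / iters
        if elapsed < best_time:
            best_time, best_cfg = elapsed, (bm, bk, bn)
    return best_cfg, best_time

Results from a sweep like this could be cached in a small table keyed by device kind and problem shape, which the dispatch heuristic above could then look up.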

@tanmaysachan tanmaysachan changed the title Integrate ragged_dot pallas kernel with group_offset support [SkyRL-tx] Integrate ragged_dot pallas kernel with group_offset support Jan 13, 2026
gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a Pallas kernel for ragged_dot to improve performance, including support for group_offset. The kernel implementation is mostly solid, but I found a critical bug in the gradient calculation for trans_ragged_dot. Additionally, the new fast_ragged_dot function is currently a stub, which will cause the new tests to fail. I've suggested an implementation for it to make the feature functional. There are also some minor maintainability issues in the kernel code that could be improved for clarity.

kw = dict(kw, compute_dtype=compute_dtype, acc_dtype=acc_dtype, num_warps=num_warps, num_stages=num_stages)
x, y, group_sizes = res
dx = ragged_dot(y, do, group_sizes, **kw, trans_rhs=True)
dy = trans_ragged_dot(x, do, group_sizes, **kw)

critical

There is a critical bug in the backward pass for trans_ragged_dot. The gradient dy with respect to y is incorrectly calculated using trans_ragged_dot. Based on the derivation, it should be calculated using ragged_dot.

The current implementation trans_ragged_dot(x, do, ...) would also cause a shape mismatch error, as do has shape [g, k, n] while the function expects an array of shape [m, n] for its second argument.

Suggested change
-dy = trans_ragged_dot(x, do, group_sizes, **kw)
+dy = ragged_dot(x, do, group_sizes, **kw)
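For reference, a sketch of the derivation this comment refers to (my notation; shapes follow the comments in this file: x is [m, k], y is [m, n], do is [g, k, n]). Writing R_i for the rows belonging to group i, the forward pass and its gradients are

$$
O_i = X_{R_i}^{\top} Y_{R_i}, \qquad
\bar{X}_{R_i} = Y_{R_i}\,\bar{O}_i^{\top}, \qquad
\bar{Y}_{R_i} = X_{R_i}\,\bar{O}_i .
$$

The first gradient applies each group's transposed cotangent to the rows of y, which is ragged_dot(y, do, ..., trans_rhs=True) as in the dx line above; the second applies the cotangent directly to the rows of x, which is plain ragged_dot(x, do, ...), i.e. the suggested change.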

Comment on lines 62 to 74
def fast_ragged_dot(
    lhs: jax.Array,
    rhs: jax.Array,
    group_sizes: jax.Array,
    group_offset: jax.Array | None = None,
) -> jax.Array:
    """Fast ragged dot product with group_offset support using Pallas kernels.

    Uses GPU info to configure Pallas kernels. Defaults to using ragged_dot().
    """

    # TODO: Benchmark and fill


critical

The fast_ragged_dot function is currently a stub with a TODO. This will cause the newly added tests to fail, as they rely on this function's implementation. The docstring also states that it defaults to ragged_dot(), but the function body is empty.

To make this feature functional and align with the PR's goal, you should implement the logic to handle group_offset by adapting the existing ragged_dot implementation to use the new ragged_dot_pallas kernel. This will ensure the tests pass and provide the intended performance improvement.

def fast_ragged_dot(
    lhs: jax.Array,
    rhs: jax.Array,
    group_sizes: jax.Array,
    group_offset: jax.Array | None = None,
) -> jax.Array:
    """Fast ragged dot product with group_offset support using Pallas kernels.

    Uses GPU info to configure Pallas kernels. Defaults to using ragged_dot().
    """
    # TODO: Benchmark and add fast paths for specific GPU architectures.
    if group_offset is None:
        return ragged_dot_pallas(lhs, rhs, group_sizes)

    assert group_offset.shape == (1,), "group_offset must have shape (1,)"
    offset = group_offset[0]
    m = lhs.shape[0]
    g_local = rhs.shape[0]

    assert g_local > 0, "rhs must have at least one group"

    # Compute token boundaries for local groups
    cumsum = jnp.cumulative_sum(group_sizes, include_initial=True)
    shard_start = cumsum[offset]
    shard_end = cumsum[offset + g_local]

    # Valid mask for tokens in local groups
    token_idx = jnp.arange(m)
    valid_mask = (token_idx >= shard_start) & (token_idx < shard_end)

    # Adjust group sizes: absorb extra tokens at boundaries
    local_group_sizes = lax.dynamic_slice_in_dim(group_sizes, offset, g_local, axis=0)
    adjusted_group_sizes = local_group_sizes.at[0].add(shard_start).at[-1].add(m - shard_end)

    # Call pallas kernel - extra tokens use boundary groups but get masked out
    result = ragged_dot_pallas(
        lhs,
        rhs,
        adjusted_group_sizes,
    )

    return jnp.where(valid_mask[:, None], result, 0)
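A quick way to sanity-check the no-offset path of the suggestion above in a unit test is a dense one-hot reference. A minimal sketch (reference_ragged_dot is a hypothetical helper, not a repo utility):

import jax
import jax.numpy as jnp

def reference_ragged_dot(lhs, rhs, group_sizes):
    # Dense oracle: one-hot group membership per row, then contract everything.
    g = rhs.shape[0]
    group_ids = jnp.repeat(jnp.arange(g), group_sizes, total_repeat_length=lhs.shape[0])
    onehot = jax.nn.one_hot(group_ids, g, dtype=lhs.dtype)  # [m, g]
    return jnp.einsum("mg,mk,gkn->mn", onehot, lhs, rhs)

key_lhs, key_rhs = jax.random.split(jax.random.PRNGKey(0))
m, k, n, g = 64, 16, 8, 4
lhs = jax.random.normal(key_lhs, (m, k))
rhs = jax.random.normal(key_rhs, (g, k, n))
group_sizes = jnp.array([16, 16, 16, 16])

ref = reference_ragged_dot(lhs, rhs, group_sizes)
# out = fast_ragged_dot(lhs, rhs, group_sizes)  # from this PR
# assert jnp.allclose(out, ref, atol=1e-4)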

Comment on lines +237 to +261
def _gpu_trans_ragged_dot_fwd(
    x: jax.Array,  # [m, k]
    y: jax.Array,  # [m, n]
    group_sizes: jax.Array,  # [g]
    block_m: int = DEFAULT_BLOCK_M,  # shape[0] of A_i tile (block_m, block_n)
    block_n: int = DEFAULT_BLOCK_N,  # shape[1] of A_i tile (block_m, block_n)
    block_k: int = DEFAULT_BLOCK_K,  # how many rows in the accumulation loop over block_m
    interpret: bool = False,
    compute_dtype: Optional["jnp.dtype"] = None,
    acc_dtype: Optional["jnp.dtype"] = jnp.float32,
    num_warps: int | None = None,
    num_stages: int | None = None,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array, jax.Array]]:
    """Compute grouped matmul on GPU via a Pallas lowering."""
    assert y.ndim == 2 and x.ndim == 2 and x.shape[0] == y.shape[0]
    (m, k), n = x.shape, y.shape[-1]
    size = ProblemSizes(m=m, k=k, n=n, g=group_sizes.size)

    block_m, block_n = min(block_m, m), min(block_n, n)

    # normalize the block sizes for GPU
    block_m, block_k, block_n = [
        max(pl.next_power_of_2(min(b, s)), 16)
        for b, s in zip([block_m, block_k, block_n, block_k], [size.m, size.k, size.n])
    ]

medium

The argument order for block sizes in the function signature is a bit confusing and inconsistent with other functions. Also, the list comprehension for normalizing block sizes contains a redundant block_k and doesn't match the argument order, which harms readability.

To improve clarity and maintainability, I suggest:

  1. Reordering the arguments in the function signature to block_m, block_k, block_n.
  2. Simplifying the list comprehension to remove the redundant element and match the argument order.
def _gpu_trans_ragged_dot_fwd(
    x: jax.Array,  # [m, k]
    y: jax.Array,  # [m, n]
    group_sizes: jax.Array,  # [g]
    block_m: int = DEFAULT_BLOCK_M,  # shape[0] of A_i tile (block_m, block_n)
    block_k: int = DEFAULT_BLOCK_K,  # how many rows in the accumulation loop over block_m
    block_n: int = DEFAULT_BLOCK_N,  # shape[1] of A_i tile (block_m, block_n)
    interpret: bool = False,
    compute_dtype: Optional["jnp.dtype"] = None,
    acc_dtype: Optional["jnp.dtype"] = jnp.float32,
    num_warps: int | None = None,
    num_stages: int | None = None,
) -> tuple[jax.Array, tuple[jax.Array, jax.Array, jax.Array]]:
    """Compute grouped matmul on GPU via a Pallas lowering."""
    assert y.ndim == 2 and x.ndim == 2 and x.shape[0] == y.shape[0]
    (m, k), n = x.shape, y.shape[-1]
    size = ProblemSizes(m=m, k=k, n=n, g=group_sizes.size)

    block_m, block_n = min(block_m, m), min(block_n, n)

    # normalize the block sizes for GPU
    block_m, block_k, block_n = [
        max(pl.next_power_of_2(min(b, s)), 16)
        for b, s in zip([block_m, block_k, block_n], [size.m, size.k, size.n])
    ]
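As a concrete check of the normalization above (my numbers, not from the PR): with the defaults block_m=128, block_k=64, block_n=64 and a problem of size m=4096, k=48, n=20, the comprehension yields block_m = max(next_power_of_2(min(128, 4096)), 16) = 128, block_k = max(next_power_of_2(min(64, 48)), 16) = 64, and block_n = max(next_power_of_2(min(64, 20)), 16) = 32.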

@pcmoritz pcmoritz added the tx label Jan 13, 2026
tanmaysachan (Author) commented Jan 14, 2026

@pcmoritz Could you help with a couple of questions here:

  • Should we use something like tune-jax to adapt to the user's hardware at build time, or would it be preferable to just have if-else blocks?
  • Do you expect to have more kernels inside tx in general? In that case we can add build-time configurations.
