Expose unroll parameter in make_ring for GPU acceleration#17

Open
daniel-om-weber wants to merge 1 commit into simon-bachhuber:main from daniel-om-weber:feat/scan-unroll-parameter
Conversation


@daniel-om-weber daniel-om-weber commented Mar 4, 2026

Summary

  • Adds an unroll parameter (default 1) to both make_ring() and rnno_v1_forward_factory(), passed through to hk.dynamic_unroll
  • On GPU, lax.scan launches a separate kernel per iteration. Unrolling lets XLA fuse multiple iterations into one kernel, reducing launch overhead
  • Fully backward-compatible: default unroll=1 preserves existing behavior
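The fusion effect described above can be seen directly on `lax.scan`, which `hk.dynamic_unroll` is built on. A minimal sketch (toy step function standing in for the RNN cell, not the actual RING core): `unroll=n` asks XLA to replicate `n` iterations inside one loop body, so the numerics are identical and only compilation changes.

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # Toy recurrent update; stands in for the real RNN cell.
    new_carry = carry * 0.9 + x
    return new_carry, new_carry

xs = jnp.arange(10.0)

# unroll=5 fuses 5 scan iterations into a single XLA loop body,
# cutting per-iteration kernel launch overhead on GPU.
final, ys = jax.lax.scan(step, 0.0, xs, unroll=5)

# Default unroll=1 gives bit-identical results; only compilation differs.
final_ref, ys_ref = jax.lax.scan(step, 0.0, xs)
assert jnp.allclose(final, final_ref)
assert jnp.allclose(ys, ys_ref)
```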

Benchmark (RING, RTX 4080)

T=1000, lam=[-1,0,1], H=400, D=200:

| unroll | forward | fwd speedup | backward | bwd speedup |
| ------ | ------- | ----------- | -------- | ----------- |
| 1      | 46.4 ms | 1.00x       | 482.5 ms | 1.00x       |
| 5      | 34.4 ms | 1.35x       | 120.2 ms | 4.01x       |
| 10     | 32.7 ms | 1.42x       | 119.7 ms | 4.03x       |
| 20     | 32.4 ms | 1.43x       | 114.0 ms | 4.23x       |
| 50     | 31.6 ms | 1.47x       | 118.0 ms | 4.09x       |

The sweet spot is unroll=20: ~1.4x forward and ~4.2x backward speedup, at the cost of a modest one-time compile overhead (~20 s vs ~4 s at baseline).

Usage

```python
# RING
model = RING(lam=lam, unroll=20)

# RNNO
model = RING(forward_factory=rnno_v1_forward_factory, lam=lam, unroll=20)
```
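The pass-through itself is just a factory kwarg that ends up on the scan call. A hedged sketch of the pattern with a hypothetical `make_cell_forward` factory, using plain `lax.scan` in place of `hk.dynamic_unroll` (which forwards its `unroll` kwarg the same way):

```python
import jax
import jax.numpy as jnp

def make_cell_forward(unroll: int = 1):
    """Hypothetical stand-in for make_ring: builds a jitted forward fn
    whose scan unrolls `unroll` iterations per XLA loop body."""
    def forward(params, xs):
        def cell(h, x):
            # Toy recurrent cell; the real model uses an RNN core.
            h = jnp.tanh(params["w"] * h + x)
            return h, h
        h0 = jnp.zeros(())
        # The factory's kwarg lands here, exactly once.
        _, ys = jax.lax.scan(cell, h0, xs, unroll=unroll)
        return ys
    return jax.jit(forward)

params = {"w": jnp.asarray(0.5)}
xs = jnp.linspace(0.0, 1.0, 16)
fwd_1 = make_cell_forward(unroll=1)
fwd_8 = make_cell_forward(unroll=8)
# unroll only changes how the loop compiles, never the result.
assert jnp.allclose(fwd_1(params, xs), fwd_8(params, xs))
```

This is why the change is backward-compatible: with the default `unroll=1` the emitted computation is unchanged.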

…ration

Pass through the `unroll` kwarg to `hk.dynamic_unroll` in both
`make_ring` and `rnno_v1_forward_factory`, which controls how many
`lax.scan` iterations are unrolled into a single XLA kernel on GPU.
This reduces kernel launch overhead, especially during backprop.

Default remains `unroll=1` (no behavior change). Users can opt in via
`make_ring(lam=lam, unroll=20)` or `RING(lam=lam, unroll=20)`.
daniel-om-weber force-pushed the feat/scan-unroll-parameter branch from 27118cb to 2964b54 on March 4, 2026 at 21:46