Expose unroll parameter in make_ring for GPU acceleration #17
Open
daniel-om-weber wants to merge 1 commit into simon-bachhuber:main
Pass the `unroll` kwarg through to `hk.dynamic_unroll` in both `make_ring` and `rnno_v1_forward_factory`. It controls how many `lax.scan` iterations are unrolled into a single XLA kernel on GPU, which reduces kernel launch overhead, especially during backprop. The default remains `unroll=1` (no behavior change). Users can opt in via `make_ring(lam=lam, unroll=20)` or `RING(lam=lam, unroll=20)`.
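For readers unfamiliar with the `unroll` argument, here is a minimal sketch using plain `jax.lax.scan` rather than the RING code itself. The names `rnn_step` and `run_rnn`, the toy tanh cell, and the shapes are illustrative assumptions, not part of this PR. It only demonstrates that changing `unroll` affects how the loop is compiled, not what it computes:

```python
from functools import partial

import jax
import jax.numpy as jnp


def rnn_step(params, h, x):
    # Toy recurrent cell (assumed for illustration): h_new = tanh(h W_h + x W_x).
    W_h, W_x = params
    h_new = jnp.tanh(h @ W_h + x @ W_x)
    return h_new, h_new  # (next carry, per-step output)


def run_rnn(params, xs, unroll=1):
    # `unroll` is forwarded to lax.scan, mirroring how the PR forwards it
    # to hk.dynamic_unroll.
    h0 = jnp.zeros(params[0].shape[0])
    _, ys = jax.lax.scan(partial(rnn_step, params), h0, xs, unroll=unroll)
    return ys


H, D, T = 8, 4, 100  # small sizes so the sketch runs quickly on CPU
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (
    0.1 * jax.random.normal(k1, (H, H)),
    0.1 * jax.random.normal(k2, (D, H)),
)
xs = jax.random.normal(k3, (T, D))

ys_base = run_rnn(params, xs, unroll=1)
ys_unrolled = run_rnn(params, xs, unroll=20)

# Same outputs up to floating-point tolerance; only the compiled loop differs.
assert jnp.allclose(ys_base, ys_unrolled, atol=1e-5)
```

Any speedup depends on hardware and model size, so a sketch like this is only a correctness check, not a benchmark.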
27118cb to 2964b54
Summary
- Adds an `unroll` parameter (default `1`) to both `make_ring()` and `rnno_v1_forward_factory()`, passed through to `hk.dynamic_unroll`
- On GPU, `lax.scan` launches a separate kernel per iteration. Unrolling lets XLA fuse multiple iterations into one kernel, reducing launch overhead
- The default `unroll=1` preserves existing behavior

Benchmark (RING, RTX 4080)
T=1000, lam=[-1,0,1], H=400, D=200:
Sweet spot is `unroll=20`: ~1.4x forward, ~4.2x backward speedup with reasonable one-time compile overhead (~20s vs ~4s baseline).

Usage