Skip to content

Conversation

@agolajko
Copy link
Contributor

@agolajko agolajko commented Jan 15, 2026

Draft PR re #862

Replaces the Jax ragged_dot with a cuda tile implementation
Inspired by https://github.com/NVIDIA/cutile-python/blob/main/samples/MoE.py

Benchmarking cuda-tile and existing ragged_dot implementation


================================================================================
                      CUTILE vs RAGGED_DOT BENCHMARK SUITE                      
================================================================================

Small (original)
  Config: 1024 tokens × 512 hidden → 512 out, 16 experts
  ----------------------------------------------------------------------------

Benchmark Results:
  ragged_dot: 1.739 ms
  cutile:     1.721 ms
  Speedup:    1.01x

Medium (Qwen-0.6B scale)
  Config: 2048 tokens × 1024 hidden → 1024 out, 16 experts
  ----------------------------------------------------------------------------

Benchmark Results:
  ragged_dot: 3.222 ms
  cutile:     3.506 ms
  Speedup:    0.92x

Large (Qwen2.5-1.5B scale)
  Config: 4096 tokens × 1536 hidden → 1536 out, 32 experts
  ----------------------------------------------------------------------------

Benchmark Results:
  ragged_dot: 9.152 ms
  cutile:     9.142 ms
  Speedup:    1.00x

Large+ (2B scale)
  Config: 4096 tokens × 2048 hidden → 2048 out, 32 experts
  ----------------------------------------------------------------------------

Benchmark Results:
  ragged_dot: 14.459 ms
  cutile:     14.464 ms
  Speedup:    1.00x

XLarge (Llama 3 8B scale)
  Config: 8192 tokens × 4096 hidden → 4096 out, 64 experts
  ----------------------------------------------------------------------------

Benchmark Results:
  ragged_dot: 94.049 ms
  cutile:     92.585 ms
  Speedup:    1.02x

Todo:

  • Multi GPU support
  • backward pass
  • more tests
  • profile

@pcmoritz pcmoritz added the tx label Jan 15, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants