feat: cutile kernels to replace the quack kernels#3

Merged
aghilann merged 9 commits into main from cutile on Feb 18, 2026
@aghilann (Owner)

No description provided.

root and others added 9 commits February 14, 2026 10:50
- Add native CuTile RMSNorm implementation (rms_norm_cutile.py) with
  dual-strategy forward (gather/scatter + TMA) and autotuning
- Add CuTile RMSNorm unit tests and kernel-level comparison benchmark
- Add Qwen3 8B FSDP multi-GPU benchmark script
- Fix pyproject.toml: use stable cu130 index with explicit source pinning
  to prevent markupsafe/triton resolution failures on cp312
- Fix e2e/__init__.py: remove references to deleted modules (qwen_8b_liger,
  geglu) that caused ModuleNotFoundError on import
- Disable CuTile rope in comparison_small (tileiras compiler fails on B200
  sm_100 architecture)
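
For readers unfamiliar with the operation, the following is a minimal pure-Python reference for RMSNorm over a single row, not the CuTile kernel from this PR. It assumes the usual definition (scale each element by the reciprocal root-mean-square, then by a learned weight); the function names and the `eps` default are hypothetical. It also shows why a forward pass typically returns the reciprocal RMS (`rstd`): the gradient with respect to the input can be computed from it without redoing the reduction.

```python
import math


def rms_norm_forward(x, weight, eps=1e-6):
    """Reference RMSNorm over one row; returns (output, rstd)."""
    mean_sq = sum(v * v for v in x) / len(x)
    rstd = 1.0 / math.sqrt(mean_sq + eps)  # reciprocal RMS, saved for backward
    y = [v * rstd * w for v, w in zip(x, weight)]
    return y, rstd


def rms_norm_backward(grad_y, x, weight, rstd):
    """Gradient w.r.t. x, reusing the rstd saved by the forward pass."""
    n = len(x)
    x_hat = [v * rstd for v in x]                    # normalized input
    gw = [g * w for g, w in zip(grad_y, weight)]     # upstream grad times weight
    dot = sum(a * b for a, b in zip(gw, x_hat)) / n  # mean(gw * x_hat)
    return [rstd * (a - h * dot) for a, h in zip(gw, x_hat)]
```

A GPU kernel would run this per row in parallel and in reduced precision; the sketch only pins down the arithmetic that any implementation should match.
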
Rewrote CuTile RMSNorm with aggressive Blackwell B200 optimizations:

Forward kernels:
- Static-persistent 2D TMA: grid-stride loop over multi-row tiles,
  weight pre-loaded once, allow_tma=False stores (+30%), latency hints
- Gather/scatter 1D: tiled column loop for non-power-of-2 N (no padding waste)
- TMA 1-row: bulk DMA load for small M with power-of-2 N

Key optimizations:
- Heuristic config selection tuned per M/N regime (TMA occ=4 for small M,
  persistent tile_m=8 for medium M, persistent tile_m=4 for large M)
- High-padding routing: >20% waste routes to gather kernel (fixes N=5120)
- Tile size capping: max(2, min(8, 32768//TILE_N)) avoids register spills
- Python dispatch minimization: cached NUM_SMS, stream, dtype as dict key
- Rstd saved from persistent kernel for backward correctness
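
The padding and tile-cap rules above can be sketched in plain Python. This is an illustration under stated assumptions, not the dispatch code from the PR: the helper names are hypothetical, padding waste is assumed to be measured relative to the padded (next power-of-2) width, and the M-regime selection between the two TMA kernels is left out. Only the 20% threshold and the `max(2, min(8, 32768//TILE_N))` expression come from the commit message.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (assumes n >= 1)."""
    return 1 << (n - 1).bit_length()


def padding_waste(n: int) -> float:
    """Fraction of a power-of-2 tile that would be padding for row width n."""
    tile_n = next_power_of_2(n)
    return (tile_n - n) / tile_n


def cap_tile_m(tile_n: int) -> int:
    """Rows per tile, capped as in the commit message to limit register use."""
    return max(2, min(8, 32768 // tile_n))


def choose_forward_path(n: int, waste_threshold: float = 0.20) -> str:
    """Route high-padding widths to the gather kernel, the rest to a TMA kernel."""
    return "gather" if padding_waste(n) > waste_threshold else "tma"
```

Under this definition N=5120 pads to 8192, which is 37.5% waste, so it takes the gather path, consistent with the "fixes N=5120" note.
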

Benchmark results (B200, bf16, vs Quack cuteDSL):
- Average: 1.06x slower (down from ~1.75x initial)
- Best: CuTile beats Quack on 5/16 configs (M=8192: 1.06x faster)
- Worst: 1.24x slower (M=2048 N=5120, was 1.45x)

Also adds kernel-level benchmark: rms_norm_quack_vs_cutile.py
… update benchmarks

- Rename rms_norm_cutile.py → rms_norm.py (replaces old Quack-based version)
- Update all imports in __init__.py, autotune.py, and tests
- Delete temporary experiment scripts (bench_bwd_*, ncu_bwd_*, profile_*, test_dw_*)
- Update E2E benchmark charts (PyTorch vs Liger vs Bastile on Qwen3-8B)

E2E results on B200 (Qwen3-8B, bf16, batch_size=1):
  seq=1024:  Bastile +18.2% throughput vs PyTorch
  seq=2048:  Bastile +20.7% throughput vs PyTorch
  seq=4096:  Bastile +27.1% throughput vs PyTorch
  seq=8192:  Bastile +32.1% throughput vs PyTorch, 34GB less memory
  seq=16384: Bastile runs, PyTorch OOMs
  Bastile beats Liger by 1.1-1.9x across all configs
…to charts

Deleted:
- CONTEXT.md (outdated session briefing)
- test_minimal_kernel.py (temp debug script)
- tests/ops/test_geglu.py (imports nonexistent gpt_oss_moe)
- tests/ops/test_rms_norm_cutile.py (stale duplicate of test_rms_norm.py)
- tests/benchmarks/e2e/comparison_small.py (superseded by qwen_8b_seqlen.py)
- tests/benchmarks/e2e/profile_kernels.py (outdated small model profiler)

Cleaned:
- configs.py: removed unused GEGLUConfig, RMSNormConfig, SwiGLUConfig
- tests/ops/run_all.py: removed dead test_geglu reference
- tests/benchmarks/e2e/__init__.py: removed deleted file references
- tests/benchmarks/run_all.py: fixed recursive self-import bug
- tests/benchmarks/kernel/rms_norm.py: fixed outdated 4-arg call to rms_norm()

Added:
- tests/benchmarks/kernel/bench_fused_lce.py: fused LCE kernel benchmark
- qwen_8b_seqlen.py: auto-generates bar charts to assets/ after each run

Updated:
- fused_linear_cross_entropy.py: removed _ce_pytorch fallback, CuTile-only path
- Updated benchmark charts with latest results
- Compress rope.py: single kernel body + programmatic occupancy variants
  via ct.kernel(fn, occupancy=N), reuse forward kernels for backward by
  negating sin (inverse rotation identity), merge duplicated helpers
- Restore swiglu.py, rope.py, configs.py and their tests/benchmarks
- Add fused LCE unit tests, fix rms_norm test, add LICENSE
- Consolidate duplicate SM count logic into ops/utils.py
- Remove dead code, unused deps, stale benchmark files
- Fix missing inspect import in qwen_8b_seqlen.py
- Add CUDA_VISIBLE_DEVICES default to Makefile
- Simplify rope.py to single @ct.kernel (compiler-auto occupancy),
  removing autotune dependency, configs.py, and kernel dict — 119 lines
- Strip autotune.py to just clear_cache + warmup_all_kernels
- Set up ruff linter/formatter with config in pyproject.toml
- Apply ruff auto-fixes: import sorting, trailing whitespace, type upgrades
- Fix unused variables (scaled, rank, B/T unpacking)
- Remove ═══/=== separator comments from swiglu.py and benchmark files
- Restore ═══ separators in fused_linear_cross_entropy.py with no blank
  lines between closing separator and code
- Add make lint/fmt targets, add CUDA_VISIBLE_DEVICES default to Makefile
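
The rope.py item above reuses the forward kernel for backward by negating sin. A minimal pure-Python sketch of why that works, assuming the common pairwise rotation form of RoPE (the function names are hypothetical and this is not the CuTile kernel): the forward pass applies a rotation matrix to each pair, and the gradient is the transpose of that matrix, which is the same rotation with the sign of sin flipped.

```python
def rotate_pairs(x1, x2, cos, sin):
    """Rotate each (x1[i], x2[i]) pair by the angle given by cos[i], sin[i]."""
    y1 = [a * c - b * s for a, b, c, s in zip(x1, x2, cos, sin)]
    y2 = [b * c + a * s for a, b, c, s in zip(x1, x2, cos, sin)]
    return y1, y2


def rope_forward(x1, x2, cos, sin):
    return rotate_pairs(x1, x2, cos, sin)


def rope_backward(g1, g2, cos, sin):
    # The transpose of a rotation is the inverse rotation, so the same
    # routine with negated sin yields the gradient w.r.t. the inputs.
    return rotate_pairs(g1, g2, cos, [-s for s in sin])
```

Because the rotation is orthogonal, applying `rope_backward` to the output of `rope_forward` with the same cos/sin also recovers the original input, which makes for a cheap sanity check.
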
PyTorch=#EE4C2C (red), Liger=#0077B5 (LinkedIn blue), Bastile=#5CE97E (Baseten green)
@aghilann aghilann merged commit dbaa380 into main Feb 18, 2026
1 check passed