
Conversation


guillaume-osmo commented on Feb 2, 2026

Proposed changes

The goal is to improve GRU and LSTM speed with Metal. With the fast kernels, the GRU is now faster than the LSTM on GPU, as expected.

python benchmarks/python/compare_recurrent_speed.py

======================================================================
Full layer timings (ms)
======================================================================
  Layer           legacy GPU      fast GPU    legacy CPU      fast CPU
  nn.GRU               1.889         1.137         2.431         2.438
  nn.LSTM              1.619         1.186         4.649         4.586

Speedup (fast vs legacy) on same device:
  GPU  GRU:  1.66x
  GPU  LSTM: 1.37x
  CPU  GRU:  1.00x
  CPU  LSTM: 1.01x
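
For context, here is a minimal sketch of how such a comparison can be timed, assuming the MLX_RNN_IMPL switch added in this PR is consulted when the layer executes. The shapes, iteration count, and timing helper are illustrative; this is not the actual compare_recurrent_speed.py.

```python
# Illustrative sketch only -- not the actual compare_recurrent_speed.py.
# Assumes MLX_RNN_IMPL (added in this PR) selects the implementation at
# layer execution time; sizes and iteration counts are placeholders.
import os
import time

import mlx.core as mx
import mlx.nn as nn


def time_layer(layer, x, iters=100):
    mx.eval(layer(x))  # warm-up so kernel compilation is not measured
    start = time.perf_counter()
    for _ in range(iters):
        mx.eval(layer(x))  # mx.eval forces the lazy graph to execute
    return (time.perf_counter() - start) / iters * 1e3  # ms per call


x = mx.random.normal((32, 128, 256))  # (batch, time, features)
for impl in ("legacy", "fast"):
    os.environ["MLX_RNN_IMPL"] = impl
    gru = nn.GRU(256, 256)
    print(f"{impl} GRU: {time_layer(gru, x):.3f} ms")
```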

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

…marks

- Add fast_gru_cell / fast_lstm_cell Metal kernels and C++ dispatch
- GRU kernel fix: the n gate uses h_proj_n (not h_prev); see the reference cell sketch after this list
- Python: MLX_RNN_IMPL (legacy/fast), optional zeros_like materialize for GRU broadcast
- fast.h: MLX_API on fast:: declarations for shared lib
- Test: test_recurrent_fast_vs_legacy (save reference outputs from legacy, load + to_contiguous in fast, compare); see the test sketch after this list
- Benchmarks: recurrent_bench.py, compare_recurrent_speed.py
- RECURRENT_VERSIONS.md: doc and examples
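
The GRU kernel fix above is a correctness point: in the standard GRU, the candidate (n) gate combines the reset gate with the projection of the previous hidden state, not with the raw state itself. A reference cell in plain MLX ops makes this concrete; the variable names and gate layout here are ours, not the Metal kernel's.

```python
# Reference GRU cell written with plain MLX ops, to show what the
# "n gate uses h_proj_n" fix means. Names and gate layout are ours.
import mlx.core as mx


def gru_cell(x_proj, h_proj, h_prev):
    # x_proj = x @ Wx + bx and h_proj = h_prev @ Wh + bh, each holding
    # the three gate blocks (r, z, n) concatenated on the last axis.
    x_r, x_z, x_n = mx.split(x_proj, 3, axis=-1)
    h_r, h_z, h_n = mx.split(h_proj, 3, axis=-1)

    r = mx.sigmoid(x_r + h_r)  # reset gate
    z = mx.sigmoid(x_z + h_z)  # update gate
    # The fix: the candidate gate scales the hidden *projection* h_n
    # (h_proj_n in the PR) by r -- not the raw previous state h_prev.
    n = mx.tanh(x_n + r * h_n)
    return (1 - z) * n + z * h_prev  # new hidden state
```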
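
The test bullet describes a cross-implementation check: run the legacy path once, persist its output, then reload it under the fast path and compare. A sketch of that strategy follows; the file name, tolerance, and layer sizes are our own choices, and the PR's to_contiguous step on the loaded reference is noted in a comment but omitted here.

```python
# Sketch of the fast-vs-legacy test strategy described above; the file
# name, tolerance, and layer sizes are ours, not the PR's.
import os

import mlx.core as mx
import mlx.nn as nn


def test_recurrent_fast_vs_legacy(tmp_dir="/tmp"):
    x = mx.random.normal((8, 16, 32))
    ref_file = os.path.join(tmp_dir, "gru_ref.npz")

    # 1. Run the legacy implementation and save its output as reference.
    os.environ["MLX_RNN_IMPL"] = "legacy"
    layer = nn.GRU(32, 64)
    mx.savez(ref_file, out=layer(x))

    # 2. Re-run the same weights with the fast kernels and compare.
    #    (The PR additionally makes the loaded reference contiguous.)
    os.environ["MLX_RNN_IMPL"] = "fast"
    ref = mx.load(ref_file)["out"]
    assert mx.allclose(layer(x), ref, atol=1e-5).item()
```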
guillaume-osmo changed the title from "Rnn speedup only" to "[METAL] RNN speedup" on Feb 2, 2026
guillaume-osmo changed the title from "[METAL] RNN speedup" to "[Metal] [Performance] RNN speedup" on Feb 2, 2026