From 1ae42749fa0365ed1e382b4088d4bb4d1c207005 Mon Sep 17 00:00:00 2001 From: Anri Lombard Date: Wed, 14 Jan 2026 20:32:49 +0200 Subject: [PATCH] Fix FFT failure for array sizes 2^21 and 2^22 Fixes #1800 The FFT was failing for arrays of size 2^21 and 2^22 with a "kernel not found" error because the four-step decomposition was creating sub-FFTs (n1=2048 or 4096) that exceeded Metal's threadgroup memory limit. This fix lowers the four-step FFT threshold from 4096 to 1024, forcing recursive decomposition earlier and ensuring all constituent FFTs fit within Metal's 32KB threadgroup memory limit. Changes: - Added MAX_SAFE_FFT_SIZE constant (1024) - Updated plan_fft to use MAX_SAFE_FFT_SIZE instead of MAX_STOCKHAM_FFT_SIZE - Added test case for 2^21, 2^22, 2^23 to prevent regression --- mlx/backend/metal/fft.cpp | 3 ++- python/tests/test_fft.py | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp index d99e1badb3..effca1014f 100644 --- a/mlx/backend/metal/fft.cpp +++ b/mlx/backend/metal/fft.cpp @@ -20,6 +20,7 @@ namespace mlx::core { using MTLFC = std::tuple; #define MAX_STOCKHAM_FFT_SIZE 4096 +#define MAX_SAFE_FFT_SIZE 1024 #define MAX_RADER_FFT_SIZE 2048 #define MAX_BLUESTEIN_FFT_SIZE 2048 // Threadgroup memory batching improves throughput for small n @@ -121,7 +122,7 @@ FFTPlan plan_fft(int n) { int remaining_n = n; // Four Step FFT when N is too large for shared mem. - if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) { + if (n > MAX_SAFE_FFT_SIZE && is_power_of_2(n)) { // For power's of two we have a fast, no transpose four step implementation. plan.four_step = true; // Rough heuristic for choosing faster powers of two when we can diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py index 07ab626722..641b7989d9 100644 --- a/python/tests/test_fft.py +++ b/python/tests/test_fft.py @@ -318,6 +318,14 @@ def g(x): dgdx = torch.func.grad(g)(torch.tensor(x)) self.assertLess((dfdx - dgdx).abs().max() / dgdx.abs().mean(), 1e-4) + def test_fft_large_powers_of_two(self): + for power in [21, 22, 23]: + size = 2**power + x = mx.ones(size, dtype=mx.complex64) + result = mx.fft.fft(x, stream=mx.gpu) + mx.eval(result) + self.assertEqual(result.shape[0], size) + if __name__ == "__main__": mlx_tests.MLXTestRunner()