diff --git a/mlx/backend/metal/fft.cpp b/mlx/backend/metal/fft.cpp
index d99e1badb3..effca1014f 100644
--- a/mlx/backend/metal/fft.cpp
+++ b/mlx/backend/metal/fft.cpp
@@ -20,6 +20,7 @@ namespace mlx::core {
 using MTLFC = std::tuple<const void*, MTLDataType, NSUInteger>;
 
 #define MAX_STOCKHAM_FFT_SIZE 4096
+#define MAX_SAFE_FFT_SIZE 1024
 #define MAX_RADER_FFT_SIZE 2048
 #define MAX_BLUESTEIN_FFT_SIZE 2048
 // Threadgroup memory batching improves throughput for small n
@@ -121,7 +122,7 @@ FFTPlan plan_fft(int n) {
   int remaining_n = n;
 
   // Four Step FFT when N is too large for shared mem.
-  if (n > MAX_STOCKHAM_FFT_SIZE && is_power_of_2(n)) {
+  if (n > MAX_SAFE_FFT_SIZE && is_power_of_2(n)) {
     // For powers of two we have a fast, no transpose four step implementation.
     plan.four_step = true;
     // Rough heuristic for choosing faster powers of two when we can
diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py
index 07ab626722..641b7989d9 100644
--- a/python/tests/test_fft.py
+++ b/python/tests/test_fft.py
@@ -318,6 +318,14 @@ def g(x):
         dgdx = torch.func.grad(g)(torch.tensor(x))
         self.assertLess((dfdx - dgdx).abs().max() / dgdx.abs().mean(), 1e-4)
 
+    def test_fft_large_powers_of_two(self):
+        for power in [21, 22, 23]:
+            size = 2**power
+            x = mx.ones(size, dtype=mx.complex64)
+            result = mx.fft.fft(x, stream=mx.gpu)
+            mx.eval(result)
+            self.assertEqual(result.shape[0], size)
+
 
 if __name__ == "__main__":
     mlx_tests.MLXTestRunner()
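For reviewers who want a quick smoke check outside the test suite, the following is a minimal sketch (not part of the patch) of exercising the lowered threshold. It assumes an MLX build with Metal/GPU support; with the change above, any power-of-two size greater than MAX_SAFE_FFT_SIZE (1024) should route through the four-step plan:

    import mlx.core as mx

    # 2**21 is well above MAX_SAFE_FFT_SIZE (1024), so plan_fft should
    # now pick the four-step decomposition instead of trying a single
    # shared-memory Stockham kernel.
    n = 2**21
    x = mx.ones(n, dtype=mx.complex64)
    out = mx.fft.fft(x, stream=mx.gpu)
    mx.eval(out)

    # The DFT of an all-ones vector is an impulse: bin 0 holds n and
    # every other bin is numerically close to zero, which gives a cheap
    # correctness check beyond the shape assertion in the new test.
    print(out.shape)  # (2097152,)
    print(out[0])     # expect n + 0j

The new test only asserts the output shape after mx.eval, which is sufficient to catch a dispatch failure at these sizes; the impulse check above is an optional extra sanity pass on the values.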