Description
Describe the bug
On Apple Silicon (Metal backend), conv_transpose1d/2d/3d with stride > 1 maps stride -> input_dilation and routes to the explicit GEMM path, which allocates a full im2col/unfold buffer. For large outputs this exceeds Metal’s maxBufferLength and fails with a Metal allocation error. This affects transposed convs (and any use of conv_general with input_dilation > 1) on Metal.
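A stride-`s` transposed convolution is numerically the same as three steps: insert `s - 1` zeros between input samples (input dilation `s`), pad by `kernel_size - 1` on both sides, then run an ordinary stride-1 convolution with the flipped kernel. That textbook identity is presumably what the stride -> input_dilation mapping relies on. The NumPy sketch below (single channel, padding=0) checks the identity. The helper names are hypothetical and this is not MLX's code.

```python
import numpy as np

def conv_transpose1d_direct(x, w, stride):
    """Reference transposed conv (single channel, padding=0): scatter-add
    each input sample times the kernel into the output at offset i * stride."""
    k = len(w)
    y = np.zeros((len(x) - 1) * stride + k)
    for i, xi in enumerate(x):
        y[i * stride : i * stride + k] += xi * w
    return y

def conv_transpose1d_via_input_dilation(x, w, stride):
    """Same result via input dilation: insert (stride - 1) zeros between
    samples, pad by k - 1 on both sides, then do a plain stride-1
    cross-correlation with the flipped kernel."""
    k = len(w)
    dilated = np.zeros((len(x) - 1) * stride + 1)
    dilated[::stride] = x  # input_dilation = stride
    padded = np.pad(dilated, (k - 1, k - 1))
    w_flipped = w[::-1]
    out_len = len(padded) - k + 1  # == (len(x) - 1) * stride + k
    return np.array([padded[t : t + k] @ w_flipped for t in range(out_len)])
```

The zero-inserted input has roughly `stride` times as many rows as `x`. This is why any buffer sized per output position scales with `L_out` rather than with `n_frames`.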
To Reproduce
```python
import mlx.core as mx

print("mlx", mx.__version__)
print("device_info", mx.device_info())

def try_n(n_frames, kernel_size=4096, stride=1024):
    x = mx.ones((1, n_frames, 1), dtype=mx.float32)
    w = mx.ones((1, kernel_size, 1), dtype=mx.float32)
    y = mx.conv_transpose1d(x, w, stride=stride, padding=0)
    mx.eval(y)
    return y.shape

for n in [5150, 5155, 5156, 5157, 5158, 5160]:
    try:
        out = try_n(n)
        print("OK", n, out)
    except Exception as e:
        print("FAIL", n, e)
```

Observed output (M4 Max, mlx 0.30.4):

```
mlx 0.30.4
device_info {'device_name': 'Apple M4 Max', 'max_recommended_working_set_size': 115448725504, 'memory_size': 137438953472, 'architecture': 'applegpu_g16s', 'max_buffer_length': 86586540032, 'resource_limit': 499000}
OK 5150 (1, 5276672, 1)
OK 5155 (1, 5281792, 1)
OK 5156 (1, 5282816, 1)
OK 5157 (1, 5283840, 1)
FAIL 5158 [metal::malloc] Attempting to allocate 86587211776 bytes which is greater than the maximum allowed buffer size of 86586540032 bytes.
FAIL 5160 [metal::malloc] Attempting to allocate 86620766208 bytes which is greater than the maximum allowed buffer size of 86586540032 bytes.
```
Expected behavior
The operation should complete successfully when the output size fits in memory, or MLX should provide a clear, MLX-level error before Metal allocation fails (explaining the unfolding buffer size requirement).
Desktop (please complete the following information):
- OS Version: macOS 26.2
- Version: mlx 0.30.4
- Device: Apple M4 Max (max_buffer_length = 86,586,540,032 bytes)
Additional context
Note: This issue persists in both mlx 0.30.3 and 0.30.4.
Root cause (code path):
- `conv_transpose_general` maps `stride` -> `input_dilation` and calls `conv_general` (`mlx/ops.cpp`).
- In Metal `conv_1D_gpu`, implicit GEMM is only used when `input_dilation == 1`; otherwise it routes to explicit GEMM (`mlx/backend/metal/conv.cpp`).
- `explicit_gemm_conv_ND_gpu` allocates a full unfolding buffer (`mlx/backend/metal/conv.cpp`).
- `MetalAllocator::malloc` throws when the requested buffer exceeds `maxBufferLength` (`mlx/backend/metal/allocator.cpp`).
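The explicit-GEMM path materializes one row per output position before running a single matrix multiply. The NumPy sketch below illustrates that shape math for a stride-1 cross-correlation over an input that has already been dilated and padded. The unfold buffer has `L_out` rows and `K = kernel_size * C` columns, so its size relative to the output is `K / O`. The function name and the `(O, K, C)` weight layout are assumptions chosen to mirror MLX's channels-last convention. This is an illustration of the technique, not the Metal kernel.

```python
import numpy as np

def explicit_gemm_conv1d(x_padded, w):
    """Stride-1 cross-correlation via a full im2col buffer + one GEMM.

    x_padded: (L_in, C) input, already dilated/padded
    w:        (O, K, C) weights
    Returns (y, unfold_nbytes), with y of shape (L_out, O).
    """
    L_in, C = x_padded.shape
    O, K, _ = w.shape
    L_out = L_in - K + 1
    # Full unfold buffer: one row per output position, K*C columns.
    unfold = np.empty((L_out, K * C), dtype=x_padded.dtype)
    for t in range(L_out):
        unfold[t] = x_padded[t : t + K].reshape(-1)
    # GEMM: (L_out, K*C) @ (K*C, O) -> (L_out, O)
    y = unfold @ w.reshape(O, K * C).T
    return y, unfold.nbytes
```

With float32 data, `unfold_nbytes == L_out * K * C * 4`. That buffer is what hits `maxBufferLength` long before the output itself gets large.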
Exact math for transposed 1D (float32, i.e. 4 bytes per element; batch size 1; padding=0, dilation=1):
L_out = (n_frames - 1) * stride + kernel_size
unfold_bytes = L_out * K * 4 (K = kernel_size * C)
out_bytes = L_out * O * 4
ratio = unfold_bytes / out_bytes = K / O
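For the repro's shapes (C=O=1, kernel_size=4096, stride=1024) at n_frames = 10,000:
- L_out = 10,243,072
- out_bytes = 40,972,288 (~39.1 MiB)
- unfold_bytes = 167,822,491,648 (~156.3 GiB)
- ratio = 4096x

The same formula reproduces the failing requests in the output above. For n_frames = 5158, L_out = 5,284,864 and unfold_bytes = 86,587,211,776, which is just over max_buffer_length (86,586,540,032). For n_frames = 5157, the buffer needs 86,570,434,560 bytes and still fits.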
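The formulas above can be checked mechanically. The pure-Python sketch below reproduces the worked numbers and predicts the OK/FAIL boundary seen in the repro. The helper names are hypothetical. It assumes float32 and batch size 1, as in the repro. It also assumes the allocator is asked for exactly `unfold_bytes`, which matches the byte counts in the error messages.

```python
FLOAT32_BYTES = 4  # the repro uses mx.float32

def transposed_conv1d_sizes(n_frames, kernel_size, stride, C=1, O=1):
    """(L_out, unfold_bytes, out_bytes) for conv_transpose1d with batch 1,
    padding=0, dilation=1, using the formulas from this issue."""
    L_out = (n_frames - 1) * stride + kernel_size
    K = kernel_size * C
    unfold_bytes = L_out * K * FLOAT32_BYTES
    out_bytes = L_out * O * FLOAT32_BYTES
    return L_out, unfold_bytes, out_bytes

def max_frames_before_overflow(max_buffer_length, kernel_size, stride, C=1):
    """Largest n_frames whose unfold buffer is <= max_buffer_length
    (the error fires only when the request is greater than the limit)."""
    max_L_out = max_buffer_length // (kernel_size * C * FLOAT32_BYTES)
    return (max_L_out - kernel_size) // stride + 1

print(transposed_conv1d_sizes(10_000, 4096, 1024))
# -> (10243072, 167822491648, 40972288)
print(max_frames_before_overflow(86_586_540_032, 4096, 1024))
# -> 5157
```

The predicted limit of 5157 frames matches the observed output: 5157 succeeds and 5158 fails.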
Scope:
- Affected: `conv_transpose1d`, `conv_transpose2d`, `conv_transpose3d`, and `conv_general` with `input_dilation > 1` on Metal.
- Not affected by this path: `conv1d`/`conv2d`/`conv3d` in standard usage (`input_dilation = 1`).
Suggested fixes:
- Short-term: Add a size check before allocating the unfolding buffer and throw an MLX-level exception with a clear error message explaining the buffer size limitation.
- Long-term: Implement tiled/chunked unfolding or extend implicit GEMM to support `input_dilation > 1`.
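The tiled/chunked unfolding idea can be sketched in NumPy. The sketch builds the im2col buffer for one tile of output rows at a time, so peak unfold memory is bounded by a byte budget instead of growing with `L_out`. In the spirit of the short-term fix, it raises a descriptive error only when a single row cannot fit. The function name and the `max_unfold_bytes` parameter are hypothetical. This is a CPU illustration of the technique, not a proposed patch to `explicit_gemm_conv_ND_gpu`.

```python
import numpy as np

def tiled_gemm_conv1d(x_padded, w, max_unfold_bytes):
    """Stride-1 cross-correlation whose im2col buffer is built one tile of
    output rows at a time, never exceeding max_unfold_bytes.

    x_padded: (L_in, C) input, already dilated/padded
    w:        (O, K, C) weights
    Returns y of shape (L_out, O).
    """
    L_in, C = x_padded.shape
    O, K, _ = w.shape
    L_out = L_in - K + 1
    row_bytes = K * C * x_padded.itemsize
    tile_rows = max_unfold_bytes // row_bytes
    if tile_rows < 1:
        # Clear, library-level error instead of a failed backend allocation.
        raise ValueError(
            f"One unfold row needs {row_bytes} bytes, which exceeds the "
            f"budget of {max_unfold_bytes} bytes."
        )
    w_mat = w.reshape(O, K * C).T  # (K*C, O)
    y = np.empty((L_out, O), dtype=x_padded.dtype)
    # Reusable tile buffer: at most tile_rows rows of K*C columns.
    tile = np.empty((min(tile_rows, L_out), K * C), dtype=x_padded.dtype)
    for start in range(0, L_out, tile_rows):
        stop = min(start + tile_rows, L_out)
        n = stop - start
        for r in range(n):
            tile[r] = x_padded[start + r : start + r + K].reshape(-1)
        y[start:stop] = tile[:n] @ w_mat  # one smaller GEMM per tile
    return y
```

In a Metal implementation, the budget would presumably be derived from `maxBufferLength`. The trade-off is several smaller GEMMs in place of one large one.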
File hashes (SHA-256) for reference:
- `mlx/ops.cpp`: `914974ab5cbd62796b0930cfc2803fd2449c99fd530599e031b78964318d80a0`
- `mlx/backend/metal/conv.cpp`: `eda371524e4670a1d5157c1ed74459a95cf0460d6f641449e940bcff1c25771b`
- `mlx/backend/metal/allocator.cpp`: `883a3e66a62d9c97f3c97bd40e53633e90e3e41fb7eadec3e6a73861f82b8726`