Add tiled unfolding to fix Metal maxBufferLength failures in conv_transpose #3084
+169
−86
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Title: Add tiled unfolding to fix Metal maxBufferLength failures in conv_transpose
Summary
Fixes transposed convolution failures on Metal when unfolding buffers exceed
maxBufferLength.Previously,
conv_transpose1d/2d/3dwithstride > 1would allocate full im2col buffers that could exceed Metal'smaxBufferLength, causing allocation failures. This PR implements tiled unfolding to process the operation in chunks that fit within hardware limits.Changes:
implicit_MdimensionImpact
Before: Operations requiring unfold buffers larger than
maxBufferLengthfailed with Metal allocation errorsAfter: Same operations succeed by processing in tiles. Example on M4 Max:
(1, 5158, 1), kernel:(1, 4096, 1), stride: 1024maxBufferLengthof ~86.6 GB) → tiled into chunks withinmaxBufferLengthmaxBufferLengtherrorImplementation Details
explicit_gemm_conv_ND_gpunow tiles along rows and writes directly into row-window views of output (mlx/backend/metal/conv.cpp)explicit_gemm_conv_group_ND_gpuapplies same tiling logic for grouped convolutions (mlx/backend/metal/conv.cpp)naive_unfold_Nd/naive_unfold_transpose_Nd) accept a row offset parameter for correct global indexing (mlx/backend/metal/kernels/conv.metal)maxBufferLength / row_bytesmaxBufferLength(Metal API limit)Performance Note
Tiling introduces additional kernel launches and a temporary buffer per tile, which may slightly reduce performance for very large outputs. However, this tradeoff enables operations that would otherwise fail due to
maxBufferLengthconstraints.Validation
Tested on Apple M4 Max (macOS 26.2, MLX 0.30.5.dev20260129+590b4f1c):
Output:
Files Changed
mlx/backend/metal/conv.cpp- Tiling logic for explicit GEMM convolutionsmlx/backend/metal/kernels/conv.metal- Row-offset parameter in unfold kernelsRelated Issues
Fixes #3082