@itsmedonttell
Title: Add tiled unfolding to fix Metal maxBufferLength failures in conv_transpose

Summary

Fixes transposed convolution failures on Metal when unfolding buffers exceed maxBufferLength.

Previously, conv_transpose1d/2d/3d with stride > 1 would allocate full im2col buffers that could exceed Metal's maxBufferLength, causing allocation failures. This PR implements tiled unfolding to process the operation in chunks that fit within hardware limits.

Changes:

  • Add row-offset support to Metal unfold kernels
  • Tile the explicit-GEMM convolution path (unfold + GEMM) along the implicit_M dimension
  • Add validation guard for edge case where a single unfold row exceeds max buffer size
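The idea behind the row offset can be illustrated with a minimal plain-Python sketch. This is not the Metal kernel or the MLX code path: `unfold_rows`, `gemv`, `conv1d_full`, and `conv1d_tiled` are hypothetical names, and a 1-D valid convolution with a single channel stands in for the general case. It only shows that unfolding a window of rows at a time, with an offset that maps tile-local rows back to global output positions, gives the same result as one full unfold.

```python
# Illustrative stand-in for the unfold kernel and GEMM; all names are hypothetical.

def unfold_rows(x, kernel_size, row_offset, n_rows):
    """Build im2col rows [row_offset, row_offset + n_rows) of a 1-D signal.

    Global row g holds the window x[g : g + kernel_size]. The offset turns a
    tile-local row index r into the global index g = row_offset + r.
    """
    return [x[row_offset + r : row_offset + r + kernel_size] for r in range(n_rows)]


def gemv(rows, w):
    """Multiply the unfolded rows by the flattened kernel (one output column)."""
    return [sum(a * b for a, b in zip(row, w)) for row in rows]


def conv1d_full(x, w):
    """Reference: unfold every output row at once, then multiply."""
    n_out = len(x) - len(w) + 1
    return gemv(unfold_rows(x, len(w), 0, n_out), w)


def conv1d_tiled(x, w, rows_per_tile):
    """Unfold and multiply one row window at a time, writing into the output slice."""
    n_out = len(x) - len(w) + 1
    out = [0.0] * n_out
    for start in range(0, n_out, rows_per_tile):
        n_rows = min(rows_per_tile, n_out - start)
        tile = unfold_rows(x, len(w), start, n_rows)  # temporary buffer for this tile only
        out[start : start + n_rows] = gemv(tile, w)
    return out


x = [float(i % 7) for i in range(50)]
w = [1.0, -2.0, 0.5, 3.0]
assert conv1d_tiled(x, w, rows_per_tile=5) == conv1d_full(x, w)
```

Only one tile of unfolded rows exists at any time, so peak temporary memory depends on `rows_per_tile` rather than on the total number of output rows.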

Impact

Before: Operations requiring unfold buffers larger than maxBufferLength failed with Metal allocation errors

After: Same operations succeed by processing in tiles. Example on M4 Max:

  • Input: (1, 5158, 1), kernel: (1, 4096, 1), stride: 1024
  • Unfold buffer required: ~86.59 GB as float32 (5,284,864 output rows × 4096 columns × 4 bytes), just over the device's maxBufferLength of ~86.586 GB → tiled into chunks within maxBufferLength
  • Previously: FAIL with maxBufferLength error
  • Now: SUCCESS
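A rough cross-check of the failure boundary can be done with a few lines of arithmetic. The sketch below assumes the unfold buffer has one row per output position and `kernel_size * c_in` float32 columns; that layout is an assumption made for illustration, not something read from the implementation, and both function names are hypothetical. The limit is the `max_buffer_length` value reported by `device_info` in the validation run.

```python
def conv_transpose1d_out_len(n_in, kernel_size, stride, padding=0):
    """Output length of a 1-D transposed convolution (no dilation or output padding)."""
    return (n_in - 1) * stride + kernel_size - 2 * padding


def estimated_unfold_bytes(n_in, kernel_size, stride, c_in=1, batch=1, itemsize=4):
    """Estimated full unfold buffer size, ASSUMING rows = batch * output length
    and columns = kernel_size * c_in, stored as float32 (itemsize=4)."""
    rows = batch * conv_transpose1d_out_len(n_in, kernel_size, stride)
    cols = kernel_size * c_in
    return rows * cols * itemsize


MAX_BUFFER_LENGTH = 86586540032  # value printed by device_info in the validation run

for n in (5157, 5158):
    size = estimated_unfold_bytes(n, kernel_size=4096, stride=1024)
    verdict = "exceeds" if size > MAX_BUFFER_LENGTH else "fits within"
    print(n, size, verdict, "max_buffer_length")
```

Under this assumption the estimate fits within the limit for 5157 frames and exceeds it for 5158, which matches the boundary between passing and previously failing sizes in the validation run.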

Implementation Details

  • explicit_gemm_conv_ND_gpu now tiles along rows and writes directly into row-window views of output (mlx/backend/metal/conv.cpp)
  • explicit_gemm_conv_group_ND_gpu applies same tiling logic for grouped convolutions (mlx/backend/metal/conv.cpp)
  • Unfold kernels (naive_unfold_Nd / naive_unfold_transpose_Nd) accept a row offset parameter for correct global indexing (mlx/backend/metal/kernels/conv.metal)
  • Tile size is computed from maxBufferLength / row_bytes
  • Returns error if a single unfold row exceeds the device's maxBufferLength (Metal API limit)
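The tile-size rule and the single-row guard can be sketched as follows. This is a hedged illustration rather than the code in conv.cpp: `plan_tiles` is a hypothetical helper, the exception type is arbitrary, and the actual implementation may cap or round the tile size differently.

```python
def plan_tiles(total_rows, row_bytes, max_buffer_length):
    """Split total_rows into (row_offset, n_rows) windows whose unfold buffers
    each fit in max_buffer_length bytes. Hypothetical helper for illustration."""
    if total_rows == 0:
        return []
    rows_per_tile = max_buffer_length // row_bytes
    if rows_per_tile == 0:
        # Guard: not even one unfold row fits, so tiling along rows cannot help.
        raise ValueError(
            f"single unfold row ({row_bytes} bytes) exceeds "
            f"max buffer length ({max_buffer_length} bytes)"
        )
    rows_per_tile = min(rows_per_tile, total_rows)
    return [
        (start, min(rows_per_tile, total_rows - start))
        for start in range(0, total_rows, rows_per_tile)
    ]


# Small example: 10 rows of 16 bytes with a 64-byte limit gives 4 rows per tile.
print(plan_tiles(total_rows=10, row_bytes=16, max_buffer_length=64))
# → [(0, 4), (4, 4), (8, 2)]
```

Each `(row_offset, n_rows)` pair corresponds to one unfold launch followed by one GEMM over that row window; the offset is what the unfold kernel would need in order to index the correct global rows.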

Performance Note

Tiling introduces additional kernel launches and a temporary buffer per tile, which may slightly reduce performance for very large outputs. However, this tradeoff enables operations that would otherwise fail due to maxBufferLength constraints.

Validation

Tested on Apple M4 Max (macOS 26.2, MLX 0.30.5.dev20260129+590b4f1c):

import mlx.core as mx

print("mlx", mx.__version__)
print("device_info", mx.device_info())

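def try_n(n_frames, kernel_size=4096, stride=1024):
    """Run conv_transpose1d on an all-ones input and return the output shape."""
    x = mx.ones((1, n_frames, 1), dtype=mx.float32)
    w = mx.ones((1, kernel_size, 1), dtype=mx.float32)
    y = mx.conv_transpose1d(x, w, stride=stride, padding=0)
    mx.eval(y)  # force evaluation so any allocation failure surfaces here
    return y.shape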

for n in [5150, 5155, 5156, 5157, 5158, 5160]:
    try:
        out = try_n(n)
        print("OK", n, out)
    except Exception as e:
        print("FAIL", n, e)

Output:

mlx 0.30.5.dev20260129+590b4f1c
device_info {'max_buffer_length': 86586540032, ...}

OK 5150 (1, 5276672, 1)
OK 5155 (1, 5281792, 1)
OK 5156 (1, 5282816, 1)
OK 5157 (1, 5283840, 1)
OK 5158 (1, 5284864, 1)  # Previously FAILED
OK 5160 (1, 5286912, 1)  # Previously FAILED

Files Changed

  • mlx/backend/metal/conv.cpp - Tiling logic for explicit GEMM convolutions
  • mlx/backend/metal/kernels/conv.metal - Row-offset parameter in unfold kernels

Related Issues

Fixes #3082

Development

Successfully merging this pull request may close these issues.

[BUG] conv_transpose* on Metal exceeds maxBufferLength due to full im2col buffer allocation