From 0d5f64f49bcf7159d639db467e0d7b1b9dfc5fa8 Mon Sep 17 00:00:00 2001
From: itsmedonttell <129792439+itsmedonttell@users.noreply.github.com>
Date: Thu, 29 Jan 2026 17:24:57 -0500
Subject: [PATCH] metal/conv: tile explicit GEMM unfold to avoid maxBufferLength

---
 mlx/backend/metal/conv.cpp           | 239 ++++++++++++++++++---------
 mlx/backend/metal/kernels/conv.metal |  16 +-
 2 files changed, 169 insertions(+), 86 deletions(-)

diff --git a/mlx/backend/metal/conv.cpp b/mlx/backend/metal/conv.cpp
index b4a674ff0e..40abe93731 100644
--- a/mlx/backend/metal/conv.cpp
+++ b/mlx/backend/metal/conv.cpp
@@ -31,37 +31,17 @@ void explicit_gemm_conv_ND_gpu(
   int implicit_M = out.size() / conv_params.O;
   int implicit_K = wt.size() / conv_params.O;
   int implicit_N = conv_params.O;
-  // Prepare unfolding array
-  Shape unfolded_shape{implicit_M, implicit_K};
-  array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
-
-  in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
 
   // Prepare unfolding kernel
   std::string kname;
   kname.reserve(32);
-  concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
+  concatenate(kname, "naive_unfold_nd_", type_to_name(in), "_", N);
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname);
   compute_encoder.set_compute_pipeline_state(kernel);
 
   compute_encoder.set_input_array(in, 0);
-  compute_encoder.set_output_array(in_unfolded, 1);
   compute_encoder.set_bytes(conv_params, 2);
 
-  // Launch unfolding kernel
-  size_t tgp_x = std::min(conv_params.C, 64);
-  tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
-  size_t tgp_y = 256 / tgp_x;
-
-  MTL::Size grid_dims = MTL::Size(
-      conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
-  MTL::Size group_dims = MTL::Size(
-      std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
-
-  compute_encoder.dispatch_threads(grid_dims, group_dims);
-
   // Reshape weight
   Shape wt_reshape{implicit_K, implicit_N};
   Strides wt_restride{1, implicit_K};
@@ -71,23 +51,80 @@ void explicit_gemm_conv_ND_gpu(
   wt_flags.col_contiguous = true;
   wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());
 
-  // Perform gemm
-  std::vector<array> copies = {in_unfolded};
-  return steel_matmul(
-      s,
-      d,
-      /*a = */ in_unfolded,
-      /*b = */ wt_reshaped,
-      /*c = */ out,
-      /*M = */ implicit_M,
-      /*N = */ implicit_N,
-      /*K = */ implicit_K,
-      /*batch_size_out = */ 1,
-      /*a_cols = */ implicit_K,
-      /*b_cols = */ implicit_K,
-      /*a_transposed = */ false,
-      /*b_transposed = */ true,
-      /*copies = */ copies);
+  // Prepare output 2D view for tiling
+  Strides out_2d_strides{
+      out.strides()[out.ndim() - 2], out.strides()[out.ndim() - 1]};
+  array out_2d({implicit_M, implicit_N}, out.dtype(), nullptr, {});
+  out_2d.copy_shared_buffer(out, out_2d_strides, out.flags(), out.data_size());
+
+  size_t itemsize = in.itemsize();
+  size_t max_buf = d.mtl_device()->maxBufferLength();
+  size_t row_bytes = static_cast<size_t>(implicit_K) * itemsize;
+  size_t max_rows = row_bytes == 0 ? 0 : (max_buf / row_bytes);
+
+  if (max_rows == 0) {
+    std::ostringstream msg;
+    msg << "[conv] explicit GEMM requires an unfolding buffer row of "
+        << row_bytes << " bytes, which exceeds Metal maxBufferLength "
+        << max_buf << " bytes.";
+    throw std::runtime_error(msg.str());
+  }
+
+  auto run_tile = [&](int row_offset, int tile_rows) {
+    // Prepare unfolding array for this tile of rows
+    Shape unfolded_shape{tile_rows, implicit_K};
+    array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
+    in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
+
+    // Re-bind the unfold pipeline and its inputs on every tile: the GEMM
+    // dispatched below replaces the encoder state in between tiles.
+    compute_encoder.set_compute_pipeline_state(kernel);
+    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_output_array(in_unfolded, 1);
+    compute_encoder.set_bytes(conv_params, 2);
+    compute_encoder.set_bytes(row_offset, 3);
+
+    // Launch unfolding kernel
+    size_t tgp_x = std::min(conv_params.C, 64);
+    tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
+    size_t tgp_y = 256 / tgp_x;
+
+    MTL::Size grid_dims = MTL::Size(
+        conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
+    MTL::Size group_dims = MTL::Size(
+        std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
+
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
+
+    // View of the output rows this tile produces
+    array out_tile({tile_rows, implicit_N}, out.dtype(), nullptr, {});
+    int64_t offset = static_cast<int64_t>(row_offset) * out_2d_strides[0];
+    out_tile.copy_shared_buffer(
+        out_2d, out_2d_strides, out_2d.flags(), out_2d.data_size(), offset);
+
+    // Perform gemm
+    std::vector<array> copies = {in_unfolded};
+    return steel_matmul(
+        s,
+        d,
+        /*a = */ in_unfolded,
+        /*b = */ wt_reshaped,
+        /*c = */ out_tile,
+        /*M = */ tile_rows,
+        /*N = */ implicit_N,
+        /*K = */ implicit_K,
+        /*batch_size_out = */ 1,
+        /*a_cols = */ implicit_K,
+        /*b_cols = */ implicit_K,
+        /*a_transposed = */ false,
+        /*b_transposed = */ true,
+        /*copies = */ copies);
+  };
+
+  if (static_cast<size_t>(implicit_M) <= max_rows) {
+    return run_tile(0, implicit_M);
+  }
+
+  for (int row_offset = 0; row_offset < implicit_M;
+       row_offset += static_cast<int>(max_rows)) {
+    int tile_rows =
+        std::min(implicit_M - row_offset, static_cast<int>(max_rows));
+    run_tile(row_offset, tile_rows);
+  }
 }
 
 template <int N>
@@ -111,37 +148,17 @@ void explicit_gemm_conv_group_ND_gpu(
     kernel_size *= conv_params.wS[i];
   }
 
-  // Prepare unfolding array
-  Shape unfolded_shape{implicit_M, implicit_K * groups};
-  array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
-  in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
-
   // Prepare unfolding kernel
   std::string kname;
   kname.reserve(32);
   concatenate(
-      kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
+      kname, "naive_unfold_transpose_nd_", type_to_name(in), "_", N);
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kname);
   compute_encoder.set_compute_pipeline_state(kernel);
 
   compute_encoder.set_input_array(in, 0);
-  compute_encoder.set_output_array(in_unfolded, 1);
   compute_encoder.set_bytes(conv_params, 2);
 
-  // Launch unfolding kernel
-  size_t tgp_x = std::min(conv_params.C, 64);
-  tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
-  size_t tgp_y = 256 / tgp_x;
-
-  MTL::Size grid_dims = MTL::Size(
-      conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
-  MTL::Size group_dims = MTL::Size(
-      std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
-
-  compute_encoder.dispatch_threads(grid_dims, group_dims);
-
   // Transpose kernel weights so that we can slice them by contiguous chunks
   // of channel groups.
   array wt_view(
@@ -152,29 +169,87 @@ void explicit_gemm_conv_group_ND_gpu(
 
   // Materialize
   array wt_transpose = contiguous_copy_gpu(wt_view, s);
 
-  // Perform gemm
-  std::vector<array> copies = {in_unfolded, wt_transpose};
-  return steel_matmul_regular(
-      /* const Stream& s = */ s,
-      /* Device& d = */ d,
-      /* const array& a = */ in_unfolded,
-      /* const array& b = */ wt_transpose,
-      /* array& c = */ out,
-      /* int M = */ implicit_M,
-      /* int N = */ implicit_N,
-      /* int K = */ implicit_K,
-      /* int batch_size_out = */ groups,
-      /* int lda = */ implicit_K * groups,
-      /* int ldb = */ implicit_K,
-      /* int ldd = */ implicit_N * groups,
-      /* bool transpose_a = */ false,
-      /* bool transpose_b = */ true,
-      /* std::vector<array>& copies = */ copies,
-      /* Shape batch_shape = */ {1},
-      /* Strides batch_strides = */ {0},
-      /* int64_t A_batch_strides = */ int64_t(implicit_K),
-      /* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
-      /* int64_t matrix_stride_out = */ int64_t(implicit_N));
+  // Prepare output 2D view for tiling
+  Strides out_2d_strides{
+      out.strides()[out.ndim() - 2], out.strides()[out.ndim() - 1]};
+  array out_2d({implicit_M, conv_params.O}, out.dtype(), nullptr, {});
+  out_2d.copy_shared_buffer(out, out_2d_strides, out.flags(), out.data_size());
+
+  size_t itemsize = in.itemsize();
+  size_t max_buf = d.mtl_device()->maxBufferLength();
+  size_t row_bytes =
+      static_cast<size_t>(implicit_K) * static_cast<size_t>(groups) * itemsize;
+  size_t max_rows = row_bytes == 0 ? 0 : (max_buf / row_bytes);
+
+  if (max_rows == 0) {
+    std::ostringstream msg;
+    msg << "[conv] explicit GEMM requires an unfolding buffer row of "
+        << row_bytes << " bytes, which exceeds Metal maxBufferLength "
+        << max_buf << " bytes.";
+    throw std::runtime_error(msg.str());
+  }
+
+  auto run_tile = [&](int row_offset, int tile_rows) {
+    // Prepare unfolding array for this tile of rows
+    Shape unfolded_shape{tile_rows, implicit_K * groups};
+    array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
+    in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));
+
+    // Re-bind the unfold pipeline and its inputs on every tile: the GEMM
+    // dispatched below replaces the encoder state in between tiles.
+    compute_encoder.set_compute_pipeline_state(kernel);
+    compute_encoder.set_input_array(in, 0);
+    compute_encoder.set_output_array(in_unfolded, 1);
+    compute_encoder.set_bytes(conv_params, 2);
+    compute_encoder.set_bytes(row_offset, 3);
+
+    // Launch unfolding kernel
+    size_t tgp_x = std::min(conv_params.C, 64);
+    tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
+    size_t tgp_y = 256 / tgp_x;
+
+    MTL::Size grid_dims = MTL::Size(
+        conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
+    MTL::Size group_dims = MTL::Size(
+        std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);
+
+    compute_encoder.dispatch_threads(grid_dims, group_dims);
+
+    // View of the output rows this tile produces
+    array out_tile({tile_rows, conv_params.O}, out.dtype(), nullptr, {});
+    int64_t offset = static_cast<int64_t>(row_offset) * out_2d_strides[0];
+    out_tile.copy_shared_buffer(
+        out_2d, out_2d_strides, out_2d.flags(), out_2d.data_size(), offset);
+
+    // Perform gemm
+    std::vector<array> copies = {in_unfolded, wt_transpose};
+    return steel_matmul_regular(
+        /* const Stream& s = */ s,
+        /* Device& d = */ d,
+        /* const array& a = */ in_unfolded,
+        /* const array& b = */ wt_transpose,
+        /* array& c = */ out_tile,
+        /* int M = */ tile_rows,
+        /* int N = */ implicit_N,
+        /* int K = */ implicit_K,
+        /* int batch_size_out = */ groups,
+        /* int lda = */ implicit_K * groups,
+        /* int ldb = */ implicit_K,
+        /* int ldd = */ implicit_N * groups,
+        /* bool transpose_a = */ false,
+        /* bool transpose_b = */ true,
+        /* std::vector<array>& copies = */ copies,
+        /* Shape batch_shape = */ {1},
+        /* Strides batch_strides = */ {0},
+        /* int64_t A_batch_strides = */ int64_t(implicit_K),
+        /* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
+        /* int64_t matrix_stride_out = */ int64_t(implicit_N));
+  };
+
+  if (static_cast<size_t>(implicit_M) <= max_rows) {
+    return run_tile(0, implicit_M);
+  }
+
+  for (int row_offset = 0; row_offset < implicit_M;
+       row_offset += static_cast<int>(max_rows)) {
+    int tile_rows =
+        std::min(implicit_M - row_offset, static_cast<int>(max_rows));
+    run_tile(row_offset, tile_rows);
+  }
 }
 
 void implicit_gemm_conv_2D_gpu(
diff --git a/mlx/backend/metal/kernels/conv.metal b/mlx/backend/metal/kernels/conv.metal
index 50a14ad3eb..244e211e7d 100644
--- a/mlx/backend/metal/kernels/conv.metal
+++ b/mlx/backend/metal/kernels/conv.metal
@@ -20,6 +20,7 @@ template <typename T, int N>
     const device T* in [[buffer(0)]],
     device T* out [[buffer(1)]],
     const constant MLXConvParams<N>* params [[buffer(2)]],
+    const constant int* row_offset_ptr [[buffer(3)]],
     uint3 gid [[thread_position_in_grid]]) {
   int filter_size = params->C;
   for (short i = 0; i < N; i++)
@@ -39,8 +40,10 @@ template <typename T, int N>
   // gid.y: wS (Filter location to unfold input)
   // gid.x: C (channel)
 
-  int n = (gid.z) / out_pixels;
-  int oS = (gid.z) % out_pixels;
+  int row_offset = *row_offset_ptr;
+  int global_row = row_offset + int(gid.z);
+  int n = global_row / out_pixels;
+  int oS = global_row % out_pixels;
   int wS = gid.y;
 
   bool valid = n < params->N;
@@ -83,6 +86,7 @@ template <typename T, int N>
     const device T* in [[buffer(0)]],
    device T* out [[buffer(1)]],
     const constant MLXConvParams<N>* params [[buffer(2)]],
+    const constant int* row_offset_ptr [[buffer(3)]],
     uint3 gid [[thread_position_in_grid]]) {
   int filter_size = params->C;
   for (short i = 0; i < N; i++)
@@ -102,8 +106,10 @@ template <typename T, int N>
   // gid.y: wS (Filter location to unfold input)
   // gid.x: C (channel)
 
-  int n = (gid.z) / out_pixels;
-  int oS = (gid.z) % out_pixels;
+  int row_offset = *row_offset_ptr;
+  int global_row = row_offset + int(gid.z);
+  int n = global_row / out_pixels;
+  int oS = global_row % out_pixels;
   int wS = gid.y;
 
   bool valid = n < params->N;
@@ -149,6 +155,7 @@ template <typename T, int N>
       const device itype* in [[buffer(0)]],                              \
       device itype* out [[buffer(1)]],                                   \
       const constant MLXConvParams<n>* params [[buffer(2)]],             \
+      const constant int* row_offset_ptr [[buffer(3)]],                  \
       uint3 gid [[thread_position_in_grid]]);                            \
   template                                                               \
   [[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] [[kernel]] void \
@@ -156,6 +163,7 @@ template <typename T, int N>
       const device itype* in [[buffer(0)]],                              \
       device itype* out [[buffer(1)]],                                   \
       const constant MLXConvParams<n>* params [[buffer(2)]],             \
+      const constant int* row_offset_ptr [[buffer(3)]],                  \
       uint3 gid [[thread_position_in_grid]]);
 
 #define instantiate_naive_unfold_nd_dims(name, itype) \
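
Note (not part of the patch): the tile size above is derived purely from the device limit. One unfolded
row is implicit_K elements (times groups in the grouped path), so the largest tile that still fits in a
single Metal buffer is maxBufferLength divided by the byte size of one row, and the driver loop walks the
output rows in chunks of that size. Below is a minimal standalone C++ sketch of that arithmetic only; the
sizes are made up and the MLX/Metal types are replaced with plain integers, so it illustrates the tiling
math rather than the actual backend code.

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  // Hypothetical sizes standing in for implicit_M, implicit_K, itemsize and
  // the limit a device would report from maxBufferLength().
  long long implicit_M = 6000000;     // unfolded rows: batch * output pixels
  long long implicit_K = 3 * 3 * 512; // row length: kernel window * channels
  std::size_t itemsize = 4;           // float32
  std::size_t max_buf = 8ull << 30;   // e.g. an 8 GiB buffer limit

  // One unfolded row must fit in a buffer, otherwise no tile size works.
  std::size_t row_bytes = static_cast<std::size_t>(implicit_K) * itemsize;
  std::size_t max_rows = row_bytes == 0 ? 0 : max_buf / row_bytes;
  if (max_rows == 0) {
    std::fprintf(stderr, "a single unfolded row exceeds the buffer limit\n");
    return 1;
  }

  // Same loop shape as run_tile(): each tile unfolds at most max_rows rows.
  for (long long row = 0; row < implicit_M;
       row += static_cast<long long>(max_rows)) {
    long long tile = std::min<long long>(implicit_M - row, max_rows);
    std::printf("tile at row %lld: %lld rows\n", row, tile);
  }
  return 0;
}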