239 changes: 157 additions & 82 deletions mlx/backend/metal/conv.cpp
@@ -31,37 +31,17 @@ void explicit_gemm_conv_ND_gpu(
int implicit_M = out.size() / conv_params.O;
int implicit_K = wt.size() / conv_params.O;
int implicit_N = conv_params.O;
// Prepare unfolding array
Shape unfolded_shape{implicit_M, implicit_K};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});

in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));

// Prepare unfolding kernel
std::string kname;
kname.reserve(32);
concatenate(kname, "naive_unfold_nd_", type_to_name(in_unfolded), "_", N);
concatenate(kname, "naive_unfold_nd_", type_to_name(in), "_", N);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);

compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1);

compute_encoder.set_bytes(conv_params, 2);

// Launch unfolding kernel
size_t tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
size_t tgp_y = 256 / tgp_x;

MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
MTL::Size group_dims = MTL::Size(
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);

compute_encoder.dispatch_threads(grid_dims, group_dims);

// Reshape weight
Shape wt_reshape{implicit_K, implicit_N};
Strides wt_restride{1, implicit_K};
@@ -71,23 +51,80 @@ void explicit_gemm_conv_ND_gpu(
wt_flags.col_contiguous = true;
wt_reshaped.copy_shared_buffer(wt, wt_restride, wt_flags, wt.data_size());

// Perform gemm
std::vector<array> copies = {in_unfolded};
return steel_matmul(
s,
d,
/*a = */ in_unfolded,
/*b = */ wt_reshaped,
/*c = */ out,
/*M = */ implicit_M,
/*N = */ implicit_N,
/*K = */ implicit_K,
/*batch_size_out = */ 1,
/*a_cols = */ implicit_K,
/*b_cols = */ implicit_K,
/*a_transposed = */ false,
/*b_transposed = */ true,
/*copies = */ copies);
// Prepare output 2D view for tiling
Strides out_2d_strides{
out.strides()[out.ndim() - 2], out.strides()[out.ndim() - 1]};
array out_2d({implicit_M, implicit_N}, out.dtype(), nullptr, {});
out_2d.copy_shared_buffer(out, out_2d_strides, out.flags(), out.data_size());

size_t itemsize = in.itemsize();
size_t max_buf = d.mtl_device()->maxBufferLength();
size_t row_bytes = static_cast<size_t>(implicit_K) * itemsize;
size_t max_rows = row_bytes == 0 ? 0 : (max_buf / row_bytes);

if (max_rows == 0) {
std::ostringstream msg;
msg << "[conv] explicit GEMM requires an unfolding buffer row of "
<< row_bytes << " bytes which exceeds Metal maxBufferLength "
<< max_buf << " bytes.";
throw std::runtime_error(msg.str());
}

auto run_tile = [&](int row_offset, int tile_rows) {
// Prepare unfolding array
Shape unfolded_shape{tile_rows, implicit_K};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));

compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder.set_bytes(row_offset, 3);

// Launch unfolding kernel
size_t tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
size_t tgp_y = 256 / tgp_x;

MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
MTL::Size group_dims = MTL::Size(
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);

compute_encoder.dispatch_threads(grid_dims, group_dims);

array out_tile({tile_rows, implicit_N}, out.dtype(), nullptr, {});
int64_t offset = static_cast<int64_t>(row_offset) * out_2d_strides[0];
out_tile.copy_shared_buffer(
out_2d, out_2d_strides, out_2d.flags(), out_2d.data_size(), offset);

// Perform gemm
std::vector<array> copies = {in_unfolded};
return steel_matmul(
s,
d,
/*a = */ in_unfolded,
/*b = */ wt_reshaped,
/*c = */ out_tile,
/*M = */ tile_rows,
/*N = */ implicit_N,
/*K = */ implicit_K,
/*batch_size_out = */ 1,
/*a_cols = */ implicit_K,
/*b_cols = */ implicit_K,
/*a_transposed = */ false,
/*b_transposed = */ true,
/*copies = */ copies);
};

if (static_cast<size_t>(implicit_M) <= max_rows) {
return run_tile(0, implicit_M);
}

for (int row_offset = 0; row_offset < implicit_M;
row_offset += static_cast<int>(max_rows)) {
int tile_rows =
std::min<int>(implicit_M - row_offset, static_cast<int>(max_rows));
run_tile(row_offset, tile_rows);
}
}

template <int N>
@@ -111,37 +148,17 @@ void explicit_gemm_conv_group_ND_gpu(
kernel_size *= conv_params.wS[i];
}

// Prepare unfolding array
Shape unfolded_shape{implicit_M, implicit_K * groups};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));

// Prepare unfolding kernel
std::string kname;
kname.reserve(32);
concatenate(
kname, "naive_unfold_transpose_nd_", type_to_name(in_unfolded), "_", N);
kname, "naive_unfold_transpose_nd_", type_to_name(in), "_", N);
auto& compute_encoder = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname);
compute_encoder.set_compute_pipeline_state(kernel);

compute_encoder.set_input_array(in, 0);
compute_encoder.set_output_array(in_unfolded, 1);

compute_encoder.set_bytes(conv_params, 2);

// Launch unfolding kernel
size_t tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
size_t tgp_y = 256 / tgp_x;

MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
MTL::Size group_dims = MTL::Size(
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);

compute_encoder.dispatch_threads(grid_dims, group_dims);

// Transpose kernel weights so that we can slice them by contiguous chunks
// of channel groups.
array wt_view(
@@ -152,29 +169,87 @@ void explicit_gemm_conv_group_ND_gpu(
// Materialize
array wt_transpose = contiguous_copy_gpu(wt_view, s);

// Perform gemm
std::vector<array> copies = {in_unfolded, wt_transpose};
return steel_matmul_regular(
/* const Stream& s = */ s,
/* Device& d = */ d,
/* const array& a = */ in_unfolded,
/* const array& b = */ wt_transpose,
/* array& c = */ out,
/* int M = */ implicit_M,
/* int N = */ implicit_N,
/* int K = */ implicit_K,
/* int batch_size_out = */ groups,
/* int lda = */ implicit_K * groups,
/* int ldb = */ implicit_K,
/* int ldd = */ implicit_N * groups,
/* bool transpose_a = */ false,
/* bool transpose_b = */ true,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ {1},
/* Strides batch_strides = */ {0},
/* int64_t A_batch_strides = */ int64_t(implicit_K),
/* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
/* int64_t matrix_stride_out = */ int64_t(implicit_N));
// Prepare output 2D view for tiling
Strides out_2d_strides{
out.strides()[out.ndim() - 2], out.strides()[out.ndim() - 1]};
array out_2d({implicit_M, conv_params.O}, out.dtype(), nullptr, {});
out_2d.copy_shared_buffer(out, out_2d_strides, out.flags(), out.data_size());

size_t itemsize = in.itemsize();
size_t max_buf = d.mtl_device()->maxBufferLength();
size_t row_bytes =
static_cast<size_t>(implicit_K) * static_cast<size_t>(groups) * itemsize;
size_t max_rows = row_bytes == 0 ? 0 : (max_buf / row_bytes);

if (max_rows == 0) {
std::ostringstream msg;
msg << "[conv] explicit GEMM requires an unfolding buffer row of "
<< row_bytes << " bytes which exceeds Metal maxBufferLength "
<< max_buf << " bytes.";
throw std::runtime_error(msg.str());
}

auto run_tile = [&](int row_offset, int tile_rows) {
// Prepare unfolding array
Shape unfolded_shape{tile_rows, implicit_K * groups};
array in_unfolded(unfolded_shape, in.dtype(), nullptr, {});
in_unfolded.set_data(allocator::malloc(in_unfolded.nbytes()));

compute_encoder.set_output_array(in_unfolded, 1);
compute_encoder.set_bytes(row_offset, 3);

// Launch unfolding kernel
size_t tgp_x = std::min(conv_params.C, 64);
tgp_x = 32 * ((tgp_x + 32 - 1) / 32);
size_t tgp_y = 256 / tgp_x;

MTL::Size grid_dims = MTL::Size(
conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);
MTL::Size group_dims = MTL::Size(
std::min(tgp_x, grid_dims.width), std::min(tgp_y, grid_dims.height), 1);

compute_encoder.dispatch_threads(grid_dims, group_dims);

array out_tile({tile_rows, conv_params.O}, out.dtype(), nullptr, {});
int64_t offset = static_cast<int64_t>(row_offset) * out_2d_strides[0];
out_tile.copy_shared_buffer(
out_2d, out_2d_strides, out_2d.flags(), out_2d.data_size(), offset);

// Perform gemm
std::vector<array> copies = {in_unfolded, wt_transpose};
return steel_matmul_regular(
/* const Stream& s = */ s,
/* Device& d = */ d,
/* const array& a = */ in_unfolded,
/* const array& b = */ wt_transpose,
/* array& c = */ out_tile,
/* int M = */ tile_rows,
/* int N = */ implicit_N,
/* int K = */ implicit_K,
/* int batch_size_out = */ groups,
/* int lda = */ implicit_K * groups,
/* int ldb = */ implicit_K,
/* int ldd = */ implicit_N * groups,
/* bool transpose_a = */ false,
/* bool transpose_b = */ true,
/* std::vector<array>& copies = */ copies,
/* Shape batch_shape = */ {1},
/* Strides batch_strides = */ {0},
/* int64_t A_batch_strides = */ int64_t(implicit_K),
/* int64_t B_batch_strides = */ int64_t(implicit_N) * implicit_K,
/* int64_t matrix_stride_out = */ int64_t(implicit_N));
};

if (static_cast<size_t>(implicit_M) <= max_rows) {
return run_tile(0, implicit_M);
}

for (int row_offset = 0; row_offset < implicit_M;
row_offset += static_cast<int>(max_rows)) {
int tile_rows =
std::min<int>(implicit_M - row_offset, static_cast<int>(max_rows));
run_tile(row_offset, tile_rows);
}
}

void implicit_gemm_conv_2D_gpu(
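
Note: the change above splits the explicit-GEMM convolution into row tiles so that the unfolded input buffer (implicit_M x implicit_K elements) never exceeds the device's maxBufferLength. Below is a minimal standalone sketch of that tiling arithmetic, using hypothetical names (plan_row_tiles is not part of MLX); it only illustrates how max_rows and the per-tile (row_offset, tile_rows) pairs are derived.

#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Split implicit_M GEMM rows into tiles whose unfolded input
// (tile_rows x implicit_K elements of size itemsize) fits in max_buffer_length bytes.
std::vector<std::pair<int, int>> plan_row_tiles(
    int implicit_M, int implicit_K, size_t itemsize, size_t max_buffer_length) {
  size_t row_bytes = static_cast<size_t>(implicit_K) * itemsize;
  size_t max_rows = row_bytes == 0 ? 0 : max_buffer_length / row_bytes;
  if (max_rows == 0) {
    throw std::runtime_error("A single unfolded row exceeds maxBufferLength.");
  }
  std::vector<std::pair<int, int>> tiles; // (row_offset, tile_rows)
  for (int row_offset = 0; row_offset < implicit_M;
       row_offset += static_cast<int>(max_rows)) {
    int tile_rows =
        std::min<int>(implicit_M - row_offset, static_cast<int>(max_rows));
    tiles.emplace_back(row_offset, tile_rows);
  }
  return tiles;
}

As an illustrative calculation: with implicit_K = 4096 float16 values (8192 bytes per row) and a 4 GiB maxBufferLength, max_rows is 524288, so an implicit_M of 10,000,000 rows would run in 20 tiles instead of one ~82 GB unfolding allocation. Each tile writes into a strided view of the output (out_tile) offset by row_offset rows, so the result needs no extra copy.
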
16 changes: 12 additions & 4 deletions mlx/backend/metal/kernels/conv.metal
@@ -20,6 +20,7 @@ template <typename T, int N>
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
const constant MLXConvParams<N>* params [[buffer(2)]],
const constant int* row_offset_ptr [[buffer(3)]],
uint3 gid [[thread_position_in_grid]]) {
int filter_size = params->C;
for (short i = 0; i < N; i++)
@@ -39,8 +40,10 @@ template <typename T, int N>
// gid.y: wS (Filter location to unfold input)
// gid.x: C (channel)

int n = (gid.z) / out_pixels;
int oS = (gid.z) % out_pixels;
int row_offset = *row_offset_ptr;
int global_row = row_offset + int(gid.z);
int n = (global_row) / out_pixels;
int oS = (global_row) % out_pixels;
int wS = gid.y;

bool valid = n < params->N;
@@ -83,6 +86,7 @@ template <typename T, int N>
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
const constant MLXConvParams<N>* params [[buffer(2)]],
const constant int* row_offset_ptr [[buffer(3)]],
uint3 gid [[thread_position_in_grid]]) {
int filter_size = params->C;
for (short i = 0; i < N; i++)
@@ -102,8 +106,10 @@ template <typename T, int N>
// gid.y: wS (Filter location to unfold input)
// gid.x: C (channel)

int n = (gid.z) / out_pixels;
int oS = (gid.z) % out_pixels;
int row_offset = *row_offset_ptr;
int global_row = row_offset + int(gid.z);
int n = (global_row) / out_pixels;
int oS = (global_row) % out_pixels;
int wS = gid.y;

bool valid = n < params->N;
@@ -149,13 +155,15 @@ template <typename T, int N>
const device itype* in [[buffer(0)]], \
device itype* out [[buffer(1)]], \
const constant MLXConvParams<n>* params [[buffer(2)]], \
const constant int* row_offset_ptr [[buffer(3)]], \
uint3 gid [[thread_position_in_grid]]); \
template \
[[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] [[kernel]] void \
naive_unfold_transpose_Nd( \
const device itype* in [[buffer(0)]], \
device itype* out [[buffer(1)]], \
const constant MLXConvParams<n>* params [[buffer(2)]], \
const constant int* row_offset_ptr [[buffer(3)]], \
uint3 gid [[thread_position_in_grid]]);

#define instantiate_naive_unfold_nd_dims(name, itype) \
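
Note: on the kernel side, both unfold variants gain a fourth argument, a constant int bound at buffer(3) that carries the tile's starting row. The thread's gid.z is now tile-local; adding row_offset recovers the global GEMM row before it is split into a batch index and an output spatial location. A tiny plain-C++ sketch of that index math follows (decode_row and RowIndices are illustrative names, not MLX symbols).

#include <cassert>

struct RowIndices {
  int n;  // batch element
  int oS; // output spatial position within that element
};

// gid_z: thread z position within the current tile's grid.
// row_offset: value bound at buffer(3) for that dispatch.
RowIndices decode_row(int gid_z, int row_offset, int out_pixels) {
  int global_row = row_offset + gid_z;
  return {global_row / out_pixels, global_row % out_pixels};
}

int main() {
  // With 100 output pixels per batch element and a tile starting at row 250,
  // thread gid.z == 7 unfolds batch element 2, output pixel 57.
  RowIndices r = decode_row(/*gid_z=*/7, /*row_offset=*/250, /*out_pixels=*/100);
  assert(r.n == 2 && r.oS == 57);
  return 0;
}

When the whole problem fits in a single tile, row_offset is 0 and the kernel behaves exactly as before.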