Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp
${CMAKE_CURRENT_SOURCE_DIR}/export.cpp
${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp
${CMAKE_CURRENT_SOURCE_DIR}/failure.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
Expand Down
7 changes: 6 additions & 1 deletion mlx/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <unordered_map>

#include "mlx/array.h"
#include "mlx/failure.h"
#include "mlx/ops.h"
#include "mlx/primitives.h"
#include "mlx/transforms.h"
Expand Down Expand Up @@ -132,7 +133,7 @@ bool array::is_available() const {
if (status() == Status::available) {
return true;
} else if (
status() == Status::evaluated &&
status() == Status::evaluated && !global_failure() &&
(!event().valid() || event().is_signaled())) {
set_status(Status::available);
return true;
Expand All @@ -146,6 +147,10 @@ void array::wait() {
event().wait();
detach_event();
}
if (global_failure()) {
reset_global_failure();
throw std::out_of_range("Array index out of bounds");
}
set_status(Status::available);
}
}
Expand Down
10 changes: 10 additions & 0 deletions mlx/backend/gpu/failure.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright © 2026 Apple Inc.

#pragma once

namespace mlx::core::gpu {

// Clears the GPU backend's failure flag. The Metal implementation stores
// FailureCode::NoFailure into its shared failure buffer.
void reset_failure();

// Returns true when the GPU backend's failure flag holds a code other than
// FailureCode::NoFailure (behaviour of the Metal implementation; other
// backends are presumably expected to match — confirm).
bool has_failure();

} // namespace mlx::core::gpu
1 change: 1 addition & 0 deletions mlx/backend/metal/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/failure.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
Expand Down
6 changes: 6 additions & 0 deletions mlx/backend/metal/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "mlx/backend/gpu/eval.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/failure.h"
#include "mlx/primitives.h"
#include "mlx/scheduler.h"

Expand All @@ -25,12 +26,17 @@ inline void check_error(MTL::CommandBuffer* cbuf) {
}

void eval(array& arr) {
if (global_failure()) {
return;
}

auto pool = metal::new_scoped_memory_pool();
auto s = arr.primitive().stream();
auto& d = metal::device(s.device);
auto command_buffer = d.get_command_buffer(s.index);

auto outputs = arr.outputs();

{
// If the array is a tracer hold a reference
// to its inputs so they don't get donated
Expand Down
72 changes: 72 additions & 0 deletions mlx/backend/metal/failure.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright © 2026 Apple Inc.

#include <atomic>
#include <cstdint>

#include "mlx/allocator.h"
#include "mlx/backend/gpu/failure.h"
#include "mlx/backend/metal/failure.h"
#include "mlx/failure.h"

namespace mlx::core {

namespace {

class Failure {
public:
static Failure& get() {
static Failure instance;
return instance;
}

Failure(const Failure&) = delete;
Failure& operator=(const Failure&) = delete;

allocator::Buffer& buffer() {
return buffer_;
}

void reset() {
atomic_ptr()->store(FailureCode::NoFailure, std::memory_order_relaxed);
}

FailureCode value() {
return atomic_ptr()->load(std::memory_order_relaxed);
}

private:
Failure() : buffer_(allocator::malloc(sizeof(int32_t))) {
reset();
}
~Failure() = default;

std::atomic<FailureCode>* atomic_ptr() {
return reinterpret_cast<std::atomic<FailureCode>*>(buffer_.raw_ptr());
}

allocator::Buffer buffer_;
};

} // namespace

namespace gpu {

void reset_failure() {
Failure::get().reset();
}

bool has_failure() {
return Failure::get().value() != FailureCode::NoFailure;
}

} // namespace gpu

namespace metal {

// Exposes the allocation behind the failure flag as an MTL::Buffer so it can
// be bound to a compute encoder.
MTL::Buffer* get_failure_buffer() {
  allocator::Buffer& storage = Failure::get().buffer();
  return static_cast<MTL::Buffer*>(storage.ptr());
}

} // namespace metal

} // namespace mlx::core
11 changes: 11 additions & 0 deletions mlx/backend/metal/failure.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Copyright © 2026 Apple Inc.

#pragma once

#include "Metal/MTLBuffer.hpp"

namespace mlx::core::metal {

// Returns the Metal buffer that backs the backend's failure flag. The
// indexing primitives bind it to their kernels with set_buffer() so the
// kernels can record a failure code in it.
MTL::Buffer* get_failure_buffer();

} // namespace mlx::core::metal
55 changes: 45 additions & 10 deletions mlx/backend/metal/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
#include "mlx/backend/common/utils.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/failure.h"
#include "mlx/backend/metal/jit/includes.h"
#include "mlx/backend/metal/jit/indexing.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/scan.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/dtype.h"
#include "mlx/failure.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"

Expand Down Expand Up @@ -67,6 +69,8 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {

std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";

auto* global_failure = metal::get_failure_buffer();

if (src.flags().row_contiguous && nidx == 1 && axes_[0] == 0 &&
inputs[1].flags().row_contiguous && slice_size == src.strides()[0]) {
int work_per_thread = (slice_size > 8 && src.dtype().size() < 4) ? 2 : 1;
Expand All @@ -80,7 +84,10 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
std::string lib_name = kernel_name;

auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
std::string kernel_source = fmt::format(
"#define BOUNDS_FAILURE {}\n",
static_cast<int>(FailureCode::BoundsFailure));
kernel_source += metal::utils();
kernel_source += metal::gather_front();
kernel_source += get_template_definition(
kernel_name,
Expand All @@ -107,6 +114,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_output_array(out, 2);
compute_encoder.set_bytes(slice_size, 3);
compute_encoder.set_bytes(src.shape(0), 4);
compute_encoder.set_buffer(global_failure, 5, 0);
compute_encoder.dispatch_threads(grid_dims, group_dims);

return;
Expand All @@ -125,7 +133,10 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
std::string lib_name = kernel_name;

auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
std::string kernel_source = fmt::format(
"#define BOUNDS_FAILURE {}\n",
static_cast<int>(FailureCode::BoundsFailure));
kernel_source += metal::utils();
kernel_source += metal::gather();
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str =
Expand Down Expand Up @@ -193,14 +204,17 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_vector_bytes(slice_sizes_, 5);
compute_encoder.set_vector_bytes(axes_, 6);

// Set failure buffer
compute_encoder.set_buffer(global_failure, 7, 0);

// Set index info
//
// We don't need to check for empty idx_shapes because gather has an
// idx_ndim == 0 specialization
compute_encoder.set_vector_bytes(idx_shapes, 7);
compute_encoder.set_vector_bytes(idx_strides, 8);
compute_encoder.set_vector_bytes(idx_contigs, 9);
compute_encoder.set_bytes(idx_ndim, 10);
compute_encoder.set_vector_bytes(idx_shapes, 8);
compute_encoder.set_vector_bytes(idx_strides, 9);
compute_encoder.set_vector_bytes(idx_contigs, 10);
compute_encoder.set_bytes(idx_ndim, 11);

// Set index buffers
for (int i = 0; i < nidx; ++i) {
Expand Down Expand Up @@ -301,7 +315,10 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
std::string lib_name = kernel_name;

auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
std::string kernel_source = fmt::format(
"#define BOUNDS_FAILURE {}\n",
static_cast<int>(FailureCode::BoundsFailure));
kernel_source += metal::utils();
concatenate(kernel_source, metal::reduce_utils(), metal::scatter());

std::string out_type_str = get_type_string(out.dtype());
Expand Down Expand Up @@ -352,6 +369,8 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {

compute_encoder.set_compute_pipeline_state(kernel);

auto* global_failure = metal::get_failure_buffer();

// Set all the buffers
compute_encoder.set_input_array(upd, 1);
compute_encoder.set_output_array(out, 2);
Expand Down Expand Up @@ -421,6 +440,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_vector_bytes(idx_contigs, 13);
compute_encoder.set_bytes(idx_ndim, 14);
compute_encoder.set_bytes(idx_size, 15);
compute_encoder.set_buffer(global_failure, 16, 0);

// Set index buffers
for (int i = 0; i < nidx; ++i) {
Expand Down Expand Up @@ -465,7 +485,10 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel_name += idx.flags().row_contiguous ? "c" : "nc";

auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
std::string kernel_source = fmt::format(
"#define BOUNDS_FAILURE {}\n",
static_cast<int>(FailureCode::BoundsFailure));
kernel_source += metal::utils();
kernel_source += metal::gather_axis();
std::string out_type_str = get_type_string(out.dtype());
std::string idx_type_str = get_type_string(idx.dtype());
Expand Down Expand Up @@ -517,6 +540,9 @@ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_bytes(src.strides(axis_), 9);
compute_encoder.set_bytes(idx.strides(axis_), 10);

auto* global_failure = metal::get_failure_buffer();
compute_encoder.set_buffer(global_failure, 11, 0);

compute_encoder.dispatch_threads(grid_dims, group_dims);
}

Expand Down Expand Up @@ -569,7 +595,10 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
kernel_name += idx.flags().row_contiguous ? "c" : "nc";

auto lib = d.get_library(lib_name, [&]() {
std::string kernel_source = metal::utils();
std::string kernel_source = fmt::format(
"#define BOUNDS_FAILURE {}\n",
static_cast<int>(FailureCode::BoundsFailure));
kernel_source += metal::utils();
kernel_source += metal::reduce_utils();
kernel_source += metal::scatter_axis();
std::string out_type_str = get_type_string(out.dtype());
Expand Down Expand Up @@ -641,6 +670,9 @@ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.set_bytes(upd.strides(axis_), 9);
compute_encoder.set_bytes(idx.strides(axis_), 10);

auto* global_failure = metal::get_failure_buffer();
compute_encoder.set_buffer(global_failure, 11, 0);

compute_encoder.dispatch_threads(grid_dims, group_dims);
}

Expand Down Expand Up @@ -695,7 +727,10 @@ void MaskedScatter::eval_gpu(const std::vector<array>& inputs, array& out) {
fmt::format("{}_{}_{}", kBaseName, dtype_tag, contiguous);

auto lib = d.get_library(kernel_name, [&]() {
std::string source = metal::utils();
std::string source = fmt::format(
"#define BOUNDS_FAILURE {}\n",
static_cast<int>(FailureCode::BoundsFailure));
source += metal::utils();
source += metal::masked_scatter();
source +=
fmt::format(masked_assign_kernel, kernel_name, value_type, contiguous);
Expand Down
12 changes: 8 additions & 4 deletions mlx/backend/metal/jit/indexing.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ constexpr std::string_view gather_kernels = R"(
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
const constant int* idx_shapes [[buffer(7)]],
const constant int64_t* idx_strides [[buffer(8)]],
const constant bool* idx_contigs [[buffer(9)]],
const constant int& idx_ndim [[buffer(10)]],
device atomic<int32_t>* global_failure [[buffer(7)]],
const constant int* idx_shapes [[buffer(8)]],
const constant int64_t* idx_strides [[buffer(9)]],
const constant bool* idx_contigs [[buffer(10)]],
const constant int& idx_ndim [[buffer(11)]],
{4}
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {{
Expand All @@ -27,6 +28,7 @@ constexpr std::string_view gather_kernels = R"(
src_ndim,
slice_sizes,
axes,
global_failure,
idxs,
index,
grid_dim);
Expand All @@ -50,6 +52,7 @@ constexpr std::string_view scatter_kernels = R"(
const constant bool* idx_contigs [[buffer(13)]],
const constant int& idx_ndim [[buffer(14)]],
const constant size_t& idx_size [[buffer(15)]],
device atomic<int32_t>* global_failure [[buffer(16)]],
{5}
uint2 gid [[thread_position_in_grid]]) {{
Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}};
Expand All @@ -66,6 +69,7 @@ constexpr std::string_view scatter_kernels = R"(
out_ndim,
axes,
idx_size,
global_failure,
idxs,
gid);
}}
Expand Down
6 changes: 6 additions & 0 deletions mlx/backend/metal/kernels/indexing/gather.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ METAL_FUNC void gather_impl(
const constant size_t& src_ndim [[buffer(4)]],
const constant int* slice_sizes [[buffer(5)]],
const constant int* axes [[buffer(6)]],
device atomic<int32_t>* global_failure [[buffer(7)]],
const thread Indices<IdxT, NIDX>& indices,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
Expand All @@ -35,6 +36,11 @@ METAL_FUNC void gather_impl(
}
auto ax = axes[i];
auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]);

if (!check_bounds(idx_val, src_shape[ax], global_failure)) {
return;
}

src_idx += static_cast<LocT>(idx_val) * static_cast<LocT>(src_strides[ax]);
}

Expand Down
Loading