From 3bc395b3934a21fdb3f75a6b99d55dc3dbd5902a Mon Sep 17 00:00:00 2001 From: Nikolaj Hey Hinnerskov Date: Sun, 1 Feb 2026 23:36:54 +0100 Subject: [PATCH] Add array bounds checking in the Metal backend. --- mlx/CMakeLists.txt | 1 + mlx/array.cpp | 7 +- mlx/backend/gpu/failure.h | 10 ++ mlx/backend/metal/CMakeLists.txt | 1 + mlx/backend/metal/eval.cpp | 6 + mlx/backend/metal/failure.cpp | 72 +++++++++++ mlx/backend/metal/failure.h | 11 ++ mlx/backend/metal/indexing.cpp | 55 +++++++-- mlx/backend/metal/jit/indexing.h | 12 +- mlx/backend/metal/kernels/indexing/gather.h | 6 + .../metal/kernels/indexing/gather_axis.h | 7 ++ .../metal/kernels/indexing/gather_front.h | 5 + mlx/backend/metal/kernels/indexing/indexing.h | 10 ++ mlx/backend/metal/kernels/indexing/scatter.h | 6 + .../metal/kernels/indexing/scatter_axis.h | 7 ++ mlx/failure.cpp | 24 ++++ mlx/failure.h | 18 +++ mlx/mlx.h | 1 + python/src/indexing.cpp | 18 ++- python/tests/test_array.py | 66 +++++++++- tests/CMakeLists.txt | 1 + tests/bounds_checks_tests.cpp | 116 ++++++++++++++++++ 22 files changed, 442 insertions(+), 18 deletions(-) create mode 100644 mlx/backend/gpu/failure.h create mode 100644 mlx/backend/metal/failure.cpp create mode 100644 mlx/backend/metal/failure.h create mode 100644 mlx/failure.cpp create mode 100644 mlx/failure.h create mode 100644 tests/bounds_checks_tests.cpp diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index 9c0fd38899..bc8e6e9267 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -7,6 +7,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/dtype_utils.cpp ${CMAKE_CURRENT_SOURCE_DIR}/export.cpp ${CMAKE_CURRENT_SOURCE_DIR}/einsum.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/failure.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fast.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp diff --git a/mlx/array.cpp b/mlx/array.cpp index 1769f0f84b..d7e67aa9df 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -3,6 +3,7 @@ #include #include "mlx/array.h" +#include "mlx/failure.h" 
#include "mlx/ops.h" #include "mlx/primitives.h" #include "mlx/transforms.h" @@ -132,7 +133,7 @@ bool array::is_available() const { if (status() == Status::available) { return true; } else if ( - status() == Status::evaluated && + status() == Status::evaluated && !global_failure() && (!event().valid() || event().is_signaled())) { set_status(Status::available); return true; @@ -146,6 +147,10 @@ void array::wait() { event().wait(); detach_event(); } + if (global_failure()) { + reset_global_failure(); + throw std::out_of_range("Array index out of bounds"); + } set_status(Status::available); } } diff --git a/mlx/backend/gpu/failure.h b/mlx/backend/gpu/failure.h new file mode 100644 index 0000000000..1e5e7a1b60 --- /dev/null +++ b/mlx/backend/gpu/failure.h @@ -0,0 +1,10 @@ +// Copyright © 2026 Apple Inc. + +#pragma once + +namespace mlx::core::gpu { + +void reset_failure(); +bool has_failure(); + +} // namespace mlx::core::gpu diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt index 4074e7b1e9..7511f9ce2a 100644 --- a/mlx/backend/metal/CMakeLists.txt +++ b/mlx/backend/metal/CMakeLists.txt @@ -112,6 +112,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp ${CMAKE_CURRENT_SOURCE_DIR}/event.cpp ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/failure.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp diff --git a/mlx/backend/metal/eval.cpp b/mlx/backend/metal/eval.cpp index bd58a691a9..919127188e 100644 --- a/mlx/backend/metal/eval.cpp +++ b/mlx/backend/metal/eval.cpp @@ -4,6 +4,7 @@ #include "mlx/backend/gpu/eval.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/utils.h" +#include "mlx/failure.h" #include "mlx/primitives.h" #include "mlx/scheduler.h" @@ -25,12 +26,17 @@ inline void check_error(MTL::CommandBuffer* cbuf) { } void eval(array& arr) { + if (global_failure()) { + return; + } + auto pool = metal::new_scoped_memory_pool(); 
auto s = arr.primitive().stream(); auto& d = metal::device(s.device); auto command_buffer = d.get_command_buffer(s.index); auto outputs = arr.outputs(); + { // If the array is a tracer hold a reference // to its inputs so they don't get donated diff --git a/mlx/backend/metal/failure.cpp b/mlx/backend/metal/failure.cpp new file mode 100644 index 0000000000..dd090fb9b3 --- /dev/null +++ b/mlx/backend/metal/failure.cpp @@ -0,0 +1,72 @@ +// Copyright © 2026 Apple Inc. + +#include +#include + +#include "mlx/allocator.h" +#include "mlx/backend/gpu/failure.h" +#include "mlx/backend/metal/failure.h" +#include "mlx/failure.h" + +namespace mlx::core { + +namespace { + +class Failure { + public: + static Failure& get() { + static Failure instance; + return instance; + } + + Failure(const Failure&) = delete; + Failure& operator=(const Failure&) = delete; + + allocator::Buffer& buffer() { + return buffer_; + } + + void reset() { + atomic_ptr()->store(FailureCode::NoFailure, std::memory_order_relaxed); + } + + FailureCode value() { + return atomic_ptr()->load(std::memory_order_relaxed); + } + + private: + Failure() : buffer_(allocator::malloc(sizeof(int32_t))) { + reset(); + } + ~Failure() = default; + + std::atomic* atomic_ptr() { + return reinterpret_cast*>(buffer_.raw_ptr()); + } + + allocator::Buffer buffer_; +}; + +} // namespace + +namespace gpu { + +void reset_failure() { + Failure::get().reset(); +} + +bool has_failure() { + return Failure::get().value() != FailureCode::NoFailure; +} + +} // namespace gpu + +namespace metal { + +MTL::Buffer* get_failure_buffer() { + return static_cast(Failure::get().buffer().ptr()); +} + +} // namespace metal + +} // namespace mlx::core diff --git a/mlx/backend/metal/failure.h b/mlx/backend/metal/failure.h new file mode 100644 index 0000000000..b51e56bdd7 --- /dev/null +++ b/mlx/backend/metal/failure.h @@ -0,0 +1,11 @@ +// Copyright © 2026 Apple Inc. 
+ +#pragma once + +#include "Metal/MTLBuffer.hpp" + +namespace mlx::core::metal { + +MTL::Buffer* get_failure_buffer(); + +} // namespace mlx::core::metal diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp index 48311dc6ae..e552d912d9 100644 --- a/mlx/backend/metal/indexing.cpp +++ b/mlx/backend/metal/indexing.cpp @@ -6,12 +6,14 @@ #include "mlx/backend/common/utils.h" #include "mlx/backend/gpu/copy.h" #include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/failure.h" #include "mlx/backend/metal/jit/includes.h" #include "mlx/backend/metal/jit/indexing.h" #include "mlx/backend/metal/kernels.h" #include "mlx/backend/metal/scan.h" #include "mlx/backend/metal/utils.h" #include "mlx/dtype.h" +#include "mlx/failure.h" #include "mlx/primitives.h" #include "mlx/utils.h" @@ -67,6 +69,8 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; + auto* global_failure = metal::get_failure_buffer(); + if (src.flags().row_contiguous && nidx == 1 && axes_[0] == 0 && inputs[1].flags().row_contiguous && slice_size == src.strides()[0]) { int work_per_thread = (slice_size > 8 && src.dtype().size() < 4) ? 
2 : 1; @@ -80,7 +84,10 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { std::string lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { - std::string kernel_source = metal::utils(); + std::string kernel_source = fmt::format( + "#define BOUNDS_FAILURE {}\n", + static_cast(FailureCode::BoundsFailure)); + kernel_source += metal::utils(); kernel_source += metal::gather_front(); kernel_source += get_template_definition( kernel_name, @@ -107,6 +114,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_output_array(out, 2); compute_encoder.set_bytes(slice_size, 3); compute_encoder.set_bytes(src.shape(0), 4); + compute_encoder.set_buffer(global_failure, 5, 0); compute_encoder.dispatch_threads(grid_dims, group_dims); return; @@ -125,7 +133,10 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { std::string lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { - std::string kernel_source = metal::utils(); + std::string kernel_source = fmt::format( + "#define BOUNDS_FAILURE {}\n", + static_cast(FailureCode::BoundsFailure)); + kernel_source += metal::utils(); kernel_source += metal::gather(); std::string out_type_str = get_type_string(out.dtype()); std::string idx_type_str = @@ -193,14 +204,17 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_vector_bytes(slice_sizes_, 5); compute_encoder.set_vector_bytes(axes_, 6); + // Set failure buffer + compute_encoder.set_buffer(global_failure, 7, 0); + // Set index info // // We don't need to check for empty idx_shapes because gather has a // idx_ndim == 0 specialization - compute_encoder.set_vector_bytes(idx_shapes, 7); - compute_encoder.set_vector_bytes(idx_strides, 8); - compute_encoder.set_vector_bytes(idx_contigs, 9); - compute_encoder.set_bytes(idx_ndim, 10); + compute_encoder.set_vector_bytes(idx_shapes, 8); + compute_encoder.set_vector_bytes(idx_strides, 9); + 
compute_encoder.set_vector_bytes(idx_contigs, 10); + compute_encoder.set_bytes(idx_ndim, 11); // Set index buffers for (int i = 0; i < nidx; ++i) { @@ -301,7 +315,10 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { std::string lib_name = kernel_name; auto lib = d.get_library(lib_name, [&]() { - std::string kernel_source = metal::utils(); + std::string kernel_source = fmt::format( + "#define BOUNDS_FAILURE {}\n", + static_cast(FailureCode::BoundsFailure)); + kernel_source += metal::utils(); concatenate(kernel_source, metal::reduce_utils(), metal::scatter()); std::string out_type_str = get_type_string(out.dtype()); @@ -352,6 +369,8 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_compute_pipeline_state(kernel); + auto* global_failure = metal::get_failure_buffer(); + // Set all the buffers compute_encoder.set_input_array(upd, 1); compute_encoder.set_output_array(out, 2); @@ -421,6 +440,7 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_vector_bytes(idx_contigs, 13); compute_encoder.set_bytes(idx_ndim, 14); compute_encoder.set_bytes(idx_size, 15); + compute_encoder.set_buffer(global_failure, 16, 0); // Set index buffers for (int i = 0; i < nidx; ++i) { @@ -465,7 +485,10 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { kernel_name += idx.flags().row_contiguous ? 
"c" : "nc"; auto lib = d.get_library(lib_name, [&]() { - std::string kernel_source = metal::utils(); + std::string kernel_source = fmt::format( + "#define BOUNDS_FAILURE {}\n", + static_cast(FailureCode::BoundsFailure)); + kernel_source += metal::utils(); kernel_source += metal::gather_axis(); std::string out_type_str = get_type_string(out.dtype()); std::string idx_type_str = get_type_string(idx.dtype()); @@ -517,6 +540,9 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_bytes(src.strides(axis_), 9); compute_encoder.set_bytes(idx.strides(axis_), 10); + auto* global_failure = metal::get_failure_buffer(); + compute_encoder.set_buffer(global_failure, 11, 0); + compute_encoder.dispatch_threads(grid_dims, group_dims); } @@ -569,7 +595,10 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { kernel_name += idx.flags().row_contiguous ? "c" : "nc"; auto lib = d.get_library(lib_name, [&]() { - std::string kernel_source = metal::utils(); + std::string kernel_source = fmt::format( + "#define BOUNDS_FAILURE {}\n", + static_cast(FailureCode::BoundsFailure)); + kernel_source += metal::utils(); kernel_source += metal::reduce_utils(); kernel_source += metal::scatter_axis(); std::string out_type_str = get_type_string(out.dtype()); @@ -641,6 +670,9 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { compute_encoder.set_bytes(upd.strides(axis_), 9); compute_encoder.set_bytes(idx.strides(axis_), 10); + auto* global_failure = metal::get_failure_buffer(); + compute_encoder.set_buffer(global_failure, 11, 0); + compute_encoder.dispatch_threads(grid_dims, group_dims); } @@ -695,7 +727,10 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { fmt::format("{}_{}_{}", kBaseName, dtype_tag, contiguous); auto lib = d.get_library(kernel_name, [&]() { - std::string source = metal::utils(); + std::string source = fmt::format( + "#define BOUNDS_FAILURE {}\n", + static_cast(FailureCode::BoundsFailure)); + 
source += metal::utils(); source += metal::masked_scatter(); source += fmt::format(masked_assign_kernel, kernel_name, value_type, contiguous); diff --git a/mlx/backend/metal/jit/indexing.h b/mlx/backend/metal/jit/indexing.h index fa141fccf5..ade0a67652 100644 --- a/mlx/backend/metal/jit/indexing.h +++ b/mlx/backend/metal/jit/indexing.h @@ -9,10 +9,11 @@ constexpr std::string_view gather_kernels = R"( const constant size_t& src_ndim [[buffer(4)]], const constant int* slice_sizes [[buffer(5)]], const constant int* axes [[buffer(6)]], - const constant int* idx_shapes [[buffer(7)]], - const constant int64_t* idx_strides [[buffer(8)]], - const constant bool* idx_contigs [[buffer(9)]], - const constant int& idx_ndim [[buffer(10)]], + device atomic* global_failure [[buffer(7)]], + const constant int* idx_shapes [[buffer(8)]], + const constant int64_t* idx_strides [[buffer(9)]], + const constant bool* idx_contigs [[buffer(10)]], + const constant int& idx_ndim [[buffer(11)]], {4} uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) {{ @@ -27,6 +28,7 @@ constexpr std::string_view gather_kernels = R"( src_ndim, slice_sizes, axes, + global_failure, idxs, index, grid_dim); @@ -50,6 +52,7 @@ constexpr std::string_view scatter_kernels = R"( const constant bool* idx_contigs [[buffer(13)]], const constant int& idx_ndim [[buffer(14)]], const constant size_t& idx_size [[buffer(15)]], + device atomic* global_failure [[buffer(16)]], {5} uint2 gid [[thread_position_in_grid]]) {{ Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}}; @@ -66,6 +69,7 @@ constexpr std::string_view scatter_kernels = R"( out_ndim, axes, idx_size, + global_failure, idxs, gid); }} diff --git a/mlx/backend/metal/kernels/indexing/gather.h b/mlx/backend/metal/kernels/indexing/gather.h index 8b93c01679..1d52ed0a78 100644 --- a/mlx/backend/metal/kernels/indexing/gather.h +++ b/mlx/backend/metal/kernels/indexing/gather.h @@ -13,6 +13,7 @@ METAL_FUNC void 
gather_impl( const constant size_t& src_ndim [[buffer(4)]], const constant int* slice_sizes [[buffer(5)]], const constant int* axes [[buffer(6)]], + device atomic* global_failure [[buffer(7)]], const thread Indices& indices, uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { @@ -35,6 +36,11 @@ METAL_FUNC void gather_impl( } auto ax = axes[i]; auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]); + + if (!check_bounds(idx_val, src_shape[ax], global_failure)) { + return; + } + src_idx += static_cast(idx_val) * static_cast(src_strides[ax]); } diff --git a/mlx/backend/metal/kernels/indexing/gather_axis.h b/mlx/backend/metal/kernels/indexing/gather_axis.h index bf490ade06..7ad623fe79 100644 --- a/mlx/backend/metal/kernels/indexing/gather_axis.h +++ b/mlx/backend/metal/kernels/indexing/gather_axis.h @@ -2,6 +2,8 @@ #pragma once +#include "mlx/backend/metal/kernels/indexing/indexing.h" + template [[kernel]] void gather_axis( const device T* src [[buffer(0)]], @@ -15,6 +17,7 @@ template const constant int& axis_size [[buffer(8)]], const constant size_t& src_ax_stride [[buffer(9)]], const constant size_t& idx_ax_stride [[buffer(10)]], + device atomic* global_failure [[buffer(11)]], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { LocT elem_idx = index.z * static_cast(grid_dim.x); @@ -32,6 +35,10 @@ template idx_val = (idx_val < 0) ? 
idx_val + axis_size : idx_val; } + if (!check_bounds(idx_val, axis_size, global_failure)) { + return; + } + LocT src_idx = idx_val * static_cast(src_ax_stride); if (SrcC) { src_idx += elem_idx * axis_size + index.x; diff --git a/mlx/backend/metal/kernels/indexing/gather_front.h b/mlx/backend/metal/kernels/indexing/gather_front.h index 1389e4c621..1ea15652de 100644 --- a/mlx/backend/metal/kernels/indexing/gather_front.h +++ b/mlx/backend/metal/kernels/indexing/gather_front.h @@ -11,12 +11,17 @@ template device T* out, const constant int64_t& stride, const constant int& size, + device atomic* global_failure, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { auto idx = offset_neg_idx(indices[index.y], size); LocT src_idx = static_cast(stride) * idx; LocT out_idx = static_cast(stride) * index.y; + if (!check_bounds(idx, size, global_failure)) { + return; + } + int s_idx = N * index.x; for (int i = 0; i < N && s_idx < stride; ++i, ++s_idx) { out[out_idx + s_idx] = src[src_idx + s_idx]; diff --git a/mlx/backend/metal/kernels/indexing/indexing.h b/mlx/backend/metal/kernels/indexing/indexing.h index 2a4b4f9298..b0817935c0 100644 --- a/mlx/backend/metal/kernels/indexing/indexing.h +++ b/mlx/backend/metal/kernels/indexing/indexing.h @@ -21,3 +21,13 @@ METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) { return (idx < 0) ? 
idx + size : idx; } } + +template +METAL_FUNC bool +check_bounds(IdxT idx, int size, device atomic* global_failure) { + if (idx < 0 || idx >= size) { + atomic_store_explicit(global_failure, BOUNDS_FAILURE, memory_order_relaxed); + return false; + } + return true; +} diff --git a/mlx/backend/metal/kernels/indexing/scatter.h b/mlx/backend/metal/kernels/indexing/scatter.h index f0217b3369..0c10661f82 100644 --- a/mlx/backend/metal/kernels/indexing/scatter.h +++ b/mlx/backend/metal/kernels/indexing/scatter.h @@ -24,6 +24,7 @@ METAL_FUNC void scatter_impl( const constant size_t& out_ndim, const constant int* axes, const constant size_t& idx_size, + device atomic* global_failure, const thread Indices& indices, uint2 gid [[thread_position_in_grid]]) { Op op; @@ -47,6 +48,11 @@ METAL_FUNC void scatter_impl( indices.ndim); auto ax = axes[i]; auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]); + + if (!check_bounds(idx_val, out_shape[ax], global_failure)) { + return; + } + out_idx += static_cast(idx_val) * static_cast(out_strides[ax]); } diff --git a/mlx/backend/metal/kernels/indexing/scatter_axis.h b/mlx/backend/metal/kernels/indexing/scatter_axis.h index 73fd7ab4a3..652a584a16 100644 --- a/mlx/backend/metal/kernels/indexing/scatter_axis.h +++ b/mlx/backend/metal/kernels/indexing/scatter_axis.h @@ -2,6 +2,8 @@ #pragma once +#include "mlx/backend/metal/kernels/indexing/indexing.h" + template < typename T, typename IdxT, @@ -21,6 +23,7 @@ template < const constant int& out_axis_size [[buffer(8)]], const constant size_t& upd_ax_stride [[buffer(9)]], const constant size_t& idx_ax_stride [[buffer(10)]], + device atomic* global_failure [[buffer(11)]], uint3 index [[thread_position_in_grid]], uint3 grid_dim [[threads_per_grid]]) { Op op; @@ -39,6 +42,10 @@ template < idx_val = (idx_val < 0) ? 
idx_val + out_axis_size : idx_val; } + if (!check_bounds(idx_val, out_axis_size, global_failure)) { + return; + } + LocT upd_idx = index.y * static_cast(upd_ax_stride); if (UpdC) { upd_idx += elem_idx * grid_dim.y + index.x; diff --git a/mlx/failure.cpp b/mlx/failure.cpp new file mode 100644 index 0000000000..49b3ca4747 --- /dev/null +++ b/mlx/failure.cpp @@ -0,0 +1,24 @@ +// Copyright © 2026 Apple Inc. + +#include "mlx/backend/gpu/failure.h" +#include "mlx/backend/gpu/device_info.h" +#include "mlx/failure.h" + +namespace mlx::core { + +void reset_global_failure() { + // TODO also reset CPU failure. + if (gpu::is_available()) { + gpu::reset_failure(); + } +} + +bool global_failure() { + // TODO also check CPU failure. + if (gpu::is_available()) { + return gpu::has_failure(); + } + return false; +} + +} // namespace mlx::core diff --git a/mlx/failure.h b/mlx/failure.h new file mode 100644 index 0000000000..eec3b7d541 --- /dev/null +++ b/mlx/failure.h @@ -0,0 +1,18 @@ +// Copyright © 2026 Apple Inc. + +#pragma once + +#include + +namespace mlx::core { + +enum class FailureCode : int32_t { + NoFailure = -1, + BoundsFailure, +}; + +void reset_global_failure(); + +bool global_failure(); + +} // namespace mlx::core diff --git a/mlx/mlx.h b/mlx/mlx.h index eda7333d5a..0c244cc4a2 100644 --- a/mlx/mlx.h +++ b/mlx/mlx.h @@ -12,6 +12,7 @@ #include "mlx/distributed/ops.h" #include "mlx/einsum.h" #include "mlx/export.h" +#include "mlx/failure.h" #include "mlx/fast.h" #include "mlx/fft.h" #include "mlx/io.h" diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 564c4cb45b..b4870d06d9 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -813,7 +813,14 @@ auto mlx_slice_update( throw std::invalid_argument(msg.str()); } auto idx = nb::cast(obj); - idx = idx < 0 ? 
idx + stops[0] : idx; + auto axis_size = stops[0]; + if (idx < -axis_size || idx >= axis_size) { + std::ostringstream msg; + msg << "Index " << idx << " is out of bounds for axis 0 with size " + << axis_size << "."; + throw std::out_of_range(msg.str()); + } + idx = idx < 0 ? idx + axis_size : idx; starts[0] = idx; stops[0] = idx + 1; auto out = slice_update( @@ -874,7 +881,14 @@ auto mlx_slice_update( upd_ax--; } else if (nb::isinstance(pyidx)) { int st = nb::cast(pyidx); - st = (st < 0) ? st + src.shape(i) : st; + int axis_size = src.shape(ax); + if (st < -axis_size || st >= axis_size) { + std::ostringstream msg; + msg << "Index " << st << " is out of bounds for axis " << ax + << " with size " << axis_size << "."; + throw std::out_of_range(msg.str()); + } + st = (st < 0) ? st + axis_size : st; starts[ax] = st; stops[ax] = st + 1; if (upd_ax >= 0) { diff --git a/python/tests/test_array.py b/python/tests/test_array.py index 4efed9dac9..f731b80868 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1100,6 +1100,46 @@ def check_slices(arr_np, *idx_np): a_mlx = mx.array(a_np) self.assertTrue(np.array_equal(a_np[2:-1, 0], np.array(a_mlx[2:-1, 0]))) + # Out of bounds indexing + a = mx.arange(10) + with self.assertRaises(IndexError): + a[10].item() + with self.assertRaises(IndexError): + a[-11].item() + with self.assertRaises(IndexError): + mx.eval(a[1 << 30]) + + a = mx.zeros((5, 5)) + idx = mx.array([5]) + with self.assertRaises(IndexError): + mx.eval(a[idx, 0]) + with self.assertRaises(IndexError): + mx.eval(a[0, idx]) + with self.assertRaises(IndexError): + mx.eval(a[1 << 30]) + + def test_scatter(self): + a = mx.array([1, 2, 3]) + with self.assertRaises(IndexError): + mx.eval(a.at[mx.array([100])].add(mx.array([10]))) + with self.assertRaises(IndexError): + mx.eval(a.at[mx.array([-6])].add(mx.array([10]))) + + a = mx.array([[1, 2, 3], [4, 5, 6]]) + indices = mx.array([0, 10]) + updates = mx.array([[0, 0, 0], [0, 0, 0]]) + with 
self.assertRaises(IndexError): + mx.eval(a.at[indices].add(updates)) + with self.assertRaises(IndexError): + a[indices] = updates + mx.eval(a) + + def test_take_along_axis_bounds(self): + a = mx.array([[1, 2, 3], [4, 5, 6]]) + indices = mx.array([[0, 10], [0, 0]]) + with self.assertRaises(IndexError): + mx.eval(mx.take_along_axis(a, indices, axis=1)) + def test_indexing_grad(self): x = mx.array([[1, 2], [3, 4]]).astype(mx.float32) ind = mx.array([0, 1, 0]).astype(mx.float32) @@ -1125,6 +1165,14 @@ def test_setitem(self): a[-1] = 2 self.assertEqual(a.tolist(), [2, 2, 2]) + with self.assertRaises(IndexError): + a[4] = 1 + mx.eval(a) + + with self.assertRaises(IndexError): + a[-4] = 1 + mx.eval(a) + a[0] = mx.array([[[1]]]) self.assertEqual(a.tolist(), [1, 2, 2]) @@ -1140,7 +1188,11 @@ def test_setitem(self): a[0:2] = 3 self.assertEqual(a.tolist(), [3, 3, 1]) - a[0:3] = 4 + a[0:3] = 5 + self.assertEqual(a.tolist(), [5, 5, 5]) + + # Slices clip, so this should not raise IndexError + a[0:4] = 4 self.assertEqual(a.tolist(), [4, 4, 4]) a[0:1] = mx.array(0) @@ -1158,6 +1210,14 @@ def test_setitem(self): a[:] = mx.array([[[[1, 1, 1]]]]) self.assertEqual(a.tolist(), [1, 1, 1]) + a = mx.array([[1, 1, 1], [1, 1, 1]]) + with self.assertRaises(IndexError): + a[0, 10] = 0 + mx.eval(a) + with self.assertRaises(IndexError): + a[None, 2] = 1 + mx.eval(a) + # Array slices def check_slices(arr_np, update_np, *idx_np): arr_mlx = mx.array(arr_np) @@ -1379,6 +1439,10 @@ def test_array_at(self): a = a.at[1].add(2) self.assertEqual(a.tolist(), [0, 3, 2]) + with self.assertRaises(IndexError): + a = a.at[4].add(2) + mx.eval(a) + a = a.at[mx.array([0, 0, 0, 0])].add(1) self.assertEqual(a.tolist(), [4, 3, 2]) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 2a4a41c6b6..830b41519b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -18,6 +18,7 @@ target_sources( array_tests.cpp arg_reduce_tests.cpp autograd_tests.cpp + bounds_checks_tests.cpp blas_tests.cpp 
compile_tests.cpp custom_vjp_tests.cpp diff --git a/tests/bounds_checks_tests.cpp b/tests/bounds_checks_tests.cpp new file mode 100644 index 0000000000..8b9ad3c9fd --- /dev/null +++ b/tests/bounds_checks_tests.cpp @@ -0,0 +1,116 @@ +// Copyright © 2026 Apple Inc. + +#include "doctest/doctest.h" +#include "mlx/mlx.h" + +using namespace mlx::core; + +TEST_CASE("test bounds checks gather") { + if (!metal::is_available()) { + return; + } + + auto stream = default_stream(Device::gpu); + reset_global_failure(); + + // 1. 1D Array failure (gather_front kernel) + { + auto src = array({10, 20, 30, 40, 50}, {5}); + auto bad_indices = array({0, 2, 5}, {3}); + auto bad_op = gather(src, bad_indices, 0, {1}); + CHECK_THROWS_AS(eval(bad_op), std::out_of_range); + reset_global_failure(); + } + + // 2. 2D Array failure (general gather kernel) + { + auto src_2d = array({1, 2, 3, 4, 5, 6}, {2, 3}); + auto bad_indices = array({0, 30}, {2}); + auto bad_op = gather(src_2d, bad_indices, 0, {1, 2}); + CHECK_THROWS_AS(eval(bad_op), std::out_of_range); + reset_global_failure(); + } + + // 3. 
Valid 2D gather + { + auto src_2d = array({1, 2, 3, 4, 5, 6}, {2, 3}); + auto valid_indices = array({0, 1, 0}, {3}); + auto op = gather(src_2d, valid_indices, 0, {1, 3}); + eval(op); + CHECK(op.is_available()); + } +} + +TEST_CASE("test bounds checks dependent op failure propagation") { + if (!metal::is_available()) { + return; + } + + auto stream = default_stream(Device::gpu); + reset_global_failure(); + + auto src = array({1, 2, 3, 4, 5}, {5}); + + auto bad_indices = array({100}, {1}); + auto fail_op = gather(src, bad_indices, 0, {3}); + auto next_op = gather(fail_op, array({0}, {1}), 0, {1, 3}); + + CHECK_THROWS_AS(eval(next_op), std::out_of_range); + + reset_global_failure(); + auto valid_op = gather(src, array({0}, {1}), 0, {3}); + eval(valid_op); + CHECK(valid_op.is_available()); +} + +TEST_CASE("test bounds checks scatter") { + if (!metal::is_available()) { + return; + } + + auto stream = default_stream(Device::gpu); + reset_global_failure(); + + auto dst = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {5}); + + // 1. 1D Scatter failure + { + auto indices = array({100}, {1}); + auto updates = array({10.0f}, {1, 1}); + auto bad_op = scatter(dst, indices, updates, 0); + CHECK_THROWS_AS(eval(bad_op), std::out_of_range); + reset_global_failure(); + } + + // 2. 2D Scatter failure + { + auto dst_2d = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3}); + auto indices = std::vector{array({0, 10}, {2})}; + auto updates = reshape( + array({10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f}, {2, 3}), {2, 1, 3}); + auto bad_op = scatter(dst_2d, indices, updates, {0}); + CHECK_THROWS_AS(eval(bad_op), std::out_of_range); + reset_global_failure(); + } + + // 3. 
Valid scatter + { + auto updates = reshape(array({10.0f, 20.0f}, {2}), {2, 1}); + auto indices = array({0, 4}, {2}); + auto op = scatter(dst, indices, updates, 0); + eval(op); + CHECK(op.is_available()); + auto expected = std::vector{10.0f, 2.0f, 3.0f, 4.0f, 20.0f}; + CHECK(array_equal(op, array(expected.data(), {5}, float32)).item()); + } + + // 4. scatter_add failure + { + auto dst = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {5}); + auto indices = array({100}, {1}); + auto updates = array({1.0f}, {1, 1}); + auto bad_op = scatter_add(dst, indices, updates, 0); + CHECK_THROWS_AS(eval(bad_op), std::out_of_range); + reset_global_failure(); + } +}