Skip to content

Conversation

@nhey
Copy link

@nhey nhey commented Feb 2, 2026

Proposed changes

I propose to slow things down a bit. This is a draft implementation of array bounds checking in the Metal backend. The scatter and gather benchmarks show a 2% median runtime overhead on an 8-core M4 GPU (see results below). Is this something you'd be interested in merging? I'll probably be working on it for fun, so I'd appreciate your thoughts either way :)

The low overhead is due to @athas's strategy described in Bounds Checking on GPU: errors are propagated from the GPU to the CPU by writing a single global 32-bit integer. This is just a draft to play around with the essence of it. I've matched the behaviour against pytorch's MPS backend, which appears to prioritize performance over safety. I don't yet report the out-of-bounds index value or the axis size, but it'd be straightforward to add this. Also I assume that gathers and scatters are never fused. A non-draft PR should add bounds checking to the CPU and CUDA backends too, of course.

Details

The strategy/implementation has limitations:

  1. Only one error is recorded and, if there are multiple errors at runtime, it's indeterministic which one.
  2. Handling index errors is disallowed.

The global error is written/read/thrown asynchronously and really means "there is at least one error", so handling the thrown error may leave us in a bad state. I check the global error in array::wait() and reset it right before throwing. (Otherwise the error would persist, being thrown on every eval when MLX is used in the Python REPL.) So the program must terminate on index errors.

  1. My implementation doesn't make indexing safe.

The paper prevents memory corruption by checking global failure in the prelude of every kernel. I didn't implement this, so kernels that have been enqueued before the global failure was set may operate on invalid inputs such as uninitialized arrays. This means all arrays have to be considered invalid on an index error. Still, there can be no out-of-bounds reading or writing and an index error will get reported to the user. I'm unsure whether pytorch allows memory corruption, I did not study their implementation.

Possible improvements:

  • Better error information.
  • Optimization: disable bounds checking for safe gather/scatter calls (e.g., I imagine some jvp/vjp operations are safe if the corresponding primal has been checked).
  • Bounds checking specialized fused kernels like gather_mm.
  • Preventing memory corruption by checking for global failure in every GPU kernel. The paper demonstrates low overhead.
  • As the paper mentions, you can use this method to report any kind of error (e.g., division by zero).

Related to #206.

Evaluation

I ran benchmarks/python/gather_bench.py and benchmarks/python/scatter_bench.py to get the below results.

Benchmark main (ms) bounds (ms) Overhead
Gather: X(100, 64) 3.331 3.409 +2%
Gather: X(100, 1024) 4.903 4.928 +0.5%
Gather: X(4, 1000000) 0.312 0.311 -0.3%
Scatter: Dst(10, 64) 3.800 4.127 +8%
Scatter: Dst(100000, 64) 7.373 7.595 +3%
Scatter: Dst(1000000, 64) 1.185 1.209 +2%
Scatter: Dst(100000,) 0.408 0.403 -1%
Scatter: Dst(200000,) 3.030 3.000 -1%
Scatter: Dst(20000000,) 31.614 31.682 +0.2%
Scatter: Dst(10000, 64) 4.071 4.315 +6%
Scatter: Dst(100, 64) 35.631 38.184 +7%
Scatter: Dst(100, 10000, 64) 49.503 52.397 +6%
Scatter: Dst(10, 100, 100, 21) 210.873 219.162 +4%
Scatter: Dst(1000, 1000, 10) 0.910 0.761 -16%

Summary disregarding the -16% test (which I guess is just noise):

Average Overhead: 2.85%
Max Overhead: 8.61%
Median Overhead: 2.34%

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@awni
Copy link
Member

awni commented Feb 2, 2026

I think a better and much simpler compromise would be to use Metal logging and only enable this in debug mode. See e.g https://ml-explore.github.io/mlx/build/html/dev/metal_logging.html

@nhey
Copy link
Author

nhey commented Feb 3, 2026

It's a different game if it's a debug only feature. It seems straightforward to change my commit to use logging. I could do that if you want me to.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants