Add bounds checking on GPU #3091
Draft
Proposed changes
I propose to slow things down a bit. This is a draft implementation of array bounds checking in the Metal backend. The scatter and gather benchmarks show a 2% median runtime overhead on an 8-core M4 GPU (see results below). Is this something you'd be interested in merging? I'll probably be working on it for fun, so I'd appreciate your thoughts either way :)
The low overhead is due to @athas's strategy described in *Bounds Checking on GPU*: errors are propagated from the GPU to the CPU by writing a single global 32-bit integer. This is just a draft to play around with the essence of it. I've matched the behaviour against PyTorch's MPS backend, which appears to prioritize performance over safety. I don't yet report the out-of-bounds index value or the axis size, but it'd be straightforward to add this. I also assume that gathers and scatters are never fused. A non-draft PR should add bounds checking to the CPU and CUDA backends too, of course.
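To make the strategy concrete, here is a minimal sketch of the kernel-side check. It is illustrative only, under assumed names: the kernel signature, buffer layout, and `error_flag` buffer are hypothetical, not the actual MLX kernels.

```metal
// Minimal sketch (hypothetical names, not the actual MLX kernels): a gather
// that records an out-of-bounds index in a single device-global 32-bit flag
// instead of reading out of bounds.
#include <metal_stdlib>
using namespace metal;

kernel void gather_checked(
    const device float* src        [[buffer(0)]],
    const device int*   indices    [[buffer(1)]],
    device float*       dst        [[buffer(2)]],
    device atomic_uint* error_flag [[buffer(3)]],  // shared with the CPU
    constant uint&      src_size   [[buffer(4)]],
    uint gid [[thread_position_in_grid]]) {
  int idx = indices[gid];
  if (idx < 0 || uint(idx) >= src_size) {
    // "There is at least one error"; the CPU reads and resets this later.
    atomic_store_explicit(error_flag, 1u, memory_order_relaxed);
    dst[gid] = 0.0f;  // write a defined value instead of touching OOB memory
    return;
  }
  dst[gid] = src[idx];
}
```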
Details
The strategy/implementation has limitations:
- The global error is written/read/thrown asynchronously and really means "there is at least one error", so handling the thrown error may leave us in a bad state. I check the global error in `array::wait()` and reset it right before throwing (otherwise the error would persist and be thrown on every eval when MLX is used in the Python REPL). So the program must terminate on index errors. A sketch of this check follows the list.
- The paper prevents memory corruption by checking the global failure flag in the prelude of every kernel. I didn't implement this, so kernels that were enqueued before the global failure was set may operate on invalid inputs such as uninitialized arrays. This means all arrays have to be considered invalid after an index error. Still, there can be no out-of-bounds reads or writes, and the index error will get reported to the user. I'm unsure whether PyTorch allows memory corruption; I did not study their implementation.
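Roughly, the host-side check-and-reset described above could look like the sketch below. The names (`check_and_reset_global_error`, `error_word` for the CPU-visible copy of the flag) are assumptions for illustration, not the actual MLX code.

```cpp
// Hypothetical sketch of the host-side check (illustrative names, not the
// actual MLX implementation): after waiting on the stream, read the shared
// 32-bit error word, reset it, and throw exactly once.
#include <atomic>
#include <cstdint>
#include <stdexcept>

void check_and_reset_global_error(std::atomic<uint32_t>& error_word) {
  // exchange() reads the flag and clears it in one step, so the error is not
  // re-thrown on every subsequent eval (e.g. in the Python REPL).
  if (error_word.exchange(0) != 0) {
    throw std::out_of_range(
        "Gather/scatter index out of bounds (detected on GPU).");
  }
}
```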
Possible improvements:
- Skip checks for `jvp`/`vjp` operations (they are safe if the corresponding primal has been checked).
- Add bounds checking to `gather_mm`.

Related to #206.
Evaluation
I ran `benchmarks/python/gather_bench.py` and `benchmarks/python/scatter_bench.py` to get the results below. Summary, disregarding the -16% test (which I guess is just noise):
- Average overhead: 2.85%
- Max overhead: 8.61%
- Median overhead: 2.34%
Checklist
Put an `x` in the boxes that apply.

- [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes