fix(metal): support argsort for arrays >1024 elements #3308
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Here is an attempt for #2570. Our ML research project requires it to work. Don't hesitate to comment or contribute in any way.
Summary
per threadgroup limit
GreaterThancomparator for descending sort support in MLX kernelsasort_bigtest on Metal, add vocabulary-size tests (2048, 4096, 32000 elements)Fixes #2570
Problem
The Metal
argsortkernel used bitonic sort withncols_padas threadgroup size. Since Metal limits threadgroups to 1024 threads,arrays with >1024 elements failed silently.
Solution
For arrays >1024 elements, use the existing
call_mlx_arg_sortfunction which implements multi-block merge sort from MLX. Thishandles arbitrary array sizes efficiently.
Supported types for large arrays: BF16, F16, F32, U8, U32, I64
Notes:
Test plan
asort_cpu- existing test passesasort_metal- existing test passesasort_big_cpu- tests 2000 elementsasort_big_metal- now passes (was skipped)asort_vocabulary_cpu- tests 2048, 4096, 32000 elementsasort_vocabulary_metal- tests vocabulary sizes used in LLMscargo test -p candle-core --test tensor_tests --features metal -- asort