-
Notifications
You must be signed in to change notification settings - Fork 1.5k
[Feature] Kernel for sparse matrix-dense matrix multiplication #2796
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
mercush
wants to merge
19
commits into
ml-explore:main
Choose a base branch
from
mercush:main
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
- Add MLX_API macro to sparse_matmul_csr declaration in ops.h - Add Python binding for sparse_matmul_csr in python/src/ops.cpp
Expose sparse_matmul_csr to Python API
These parameters can be inferred from the inputs: - n_rows = row_ptr.shape(0) - 1 (CSR format invariant) - n_cols = dense_b.shape(1) (output columns match dense matrix) This prevents bugs where incorrect parameters could cause out-of-bounds memory access.
Remove n_rows and n_cols parameters from test calls to match the updated function signature that infers these values from inputs.
Remove redundant n_rows and n_cols parameters from sparse_matmul_csr
Author
|
Metal, Cuda, and cpu backends are passing tests. |
Author
|
Benchmark: spmm can be a lot faster for large matrices with a lot of sparsity. Metal backend (M2 Pro): Sparse matmul CSR benchmark (1024x1024, 5% nonzero)
float16:
Timing sparse (float16) ... 0.42803 msec
Timing dense (float16) ... 0.65754 msec
bfloat16:
Timing sparse (bfloat16) ... 0.38357 msec
Timing dense (bfloat16) ... 0.64950 msec
float32:
Timing sparse (float32) ... 0.49228 msec
Timing dense (float32) ... 0.60805 msec
Sparse matmul CSR benchmark (4096x4096, 1% nonzero)
float16:
Timing sparse (float16) ... 6.89379 msec
Timing dense (float16) ... 22.62993 msec
bfloat16:
Timing sparse (bfloat16) ... 6.82182 msec
Timing dense (bfloat16) ... 26.86476 msec
float32:
Timing sparse (float32) ... 15.21455 msec
Timing dense (float32) ... 27.31441 msecCUDA backend (RTX 5090): Sparse matmul CSR benchmark (1024x1024, 5% nonzero)
float16:
Timing sparse (float16) ... 0.05121 msec
Timing dense (float16) ... 0.05960 msec
bfloat16:
Timing sparse (bfloat16) ... 0.04969 msec
Timing dense (bfloat16) ... 0.06564 msec
float32:
Timing sparse (float32) ... 0.06870 msec
Timing dense (float32) ... 0.07317 msec
Sparse matmul CSR benchmark (4096x4096, 1% nonzero)
float16:
Timing sparse (float16) ... 0.24195 msec
Timing dense (float16) ... 0.72532 msec
bfloat16:
Timing sparse (bfloat16) ... 0.23913 msec
Timing dense (bfloat16) ... 0.84904 msec
float32:
Timing sparse (float32) ... 0.44731 msec
Timing dense (float32) ... 1.44518 msec |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Proposed changes
Feature: added function metal kernel
sparse_matmul_csrfor sparse matrix-dense matrix multiplication. The function takes as input a sparse matrix (represented as three arrays representing the row pointers, the column indices, and the values at the nonzero entries) and a dense matrix represented as a dense array with two dimensions. I have implemented the backends for metal and for the cpu but not cuda.Please let me know if there is any documentation I should update.
Checklist
Put an
xin the boxes that apply.pre-commit run --all-filesto format my code / installed pre-commit prior to committing changes