
Conversation


@mercush mercush commented Nov 19, 2025

Proposed changes

Feature: added a Metal kernel sparse_matmul_csr for sparse matrix-dense matrix multiplication. The function takes as input a sparse matrix in CSR format (three arrays: the row pointers, the column indices, and the values of the nonzero entries) and a dense matrix represented as a two-dimensional array. I have implemented the Metal and CPU backends, but not CUDA.
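For reviewers unfamiliar with CSR, a minimal pure-Python sketch of the semantics this kernel implements is below. The function name mirrors the PR, but the argument names and exact MLX signature are assumptions, not the actual API:

```python
# Reference semantics for CSR sparse x dense matmul (illustrative only;
# the real kernel runs on Metal/CPU and the MLX signature may differ).
def sparse_matmul_csr(row_ptr, col_idx, values, dense_b):
    n_rows = len(row_ptr) - 1          # CSR invariant: n_rows + 1 pointers
    n_cols = len(dense_b[0])           # output width matches the dense operand
    out = [[0.0] * n_cols for _ in range(n_rows)]
    for i in range(n_rows):
        # nonzeros of row i live in values[row_ptr[i]:row_ptr[i + 1]]
        for k in range(row_ptr[i], row_ptr[i + 1]):
            j, v = col_idx[k], values[k]
            for c in range(n_cols):
                out[i][c] += v * dense_b[j][c]
    return out

# A = [[1, 0], [0, 2]] in CSR form, multiplied by a dense 2x2 matrix
row_ptr, col_idx, values = [0, 1, 2], [0, 1], [1.0, 2.0]
B = [[3.0, 4.0], [5.0, 6.0]]
print(sparse_matmul_csr(row_ptr, col_idx, values, B))  # [[3.0, 4.0], [10.0, 12.0]]
```

Each output row only touches the nonzeros of the corresponding sparse row, which is where the speedup over a dense matmul comes from at high sparsity.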

Please let me know if there is any documentation I should update.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@mercush mercush marked this pull request as draft November 19, 2025 16:56
@mercush mercush changed the title Kernel for sparse matrix-dense matrix multiplication [Experiment] Kernel for sparse matrix-dense matrix multiplication Nov 19, 2025
mercush and others added 9 commits November 29, 2025 15:10
- Add MLX_API macro to sparse_matmul_csr declaration in ops.h
- Add Python binding for sparse_matmul_csr in python/src/ops.cpp
- Remove redundant n_rows and n_cols parameters from sparse_matmul_csr; these can be inferred from the inputs:
  - n_rows = row_ptr.shape(0) - 1 (CSR format invariant)
  - n_cols = dense_b.shape(1) (output columns match the dense matrix)
  This prevents bugs where incorrect parameters could cause out-of-bounds memory access.
- Remove n_rows and n_cols parameters from test calls to match the updated function signature that infers these values from the inputs.
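The inference invariants from the commit messages above can be stated as a tiny sketch (names are illustrative, not the actual MLX API):

```python
# Hypothetical helper showing why n_rows and n_cols need not be passed:
# they are fully determined by the CSR row pointers and the dense operand.
def infer_spmm_shape(row_ptr, dense_b):
    n_rows = len(row_ptr) - 1    # CSR stores n_rows + 1 row pointers
    n_cols = len(dense_b[0])     # output columns match the dense matrix
    return n_rows, n_cols

row_ptr = [0, 2, 3, 5]           # 3 sparse rows, 5 nonzeros
B = [[1, 2, 3, 4]] * 8           # 8 x 4 dense matrix
print(infer_spmm_shape(row_ptr, B))  # (3, 4)
```

Deriving the shape from the inputs removes the possibility of a caller passing a mismatched n_rows that would index past the end of row_ptr.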

mercush commented Jan 26, 2026

The Metal, CUDA, and CPU backends are all passing tests.

@mercush mercush changed the title [Experiment] Kernel for sparse matrix-dense matrix multiplication Kernel for sparse matrix-dense matrix multiplication Jan 26, 2026
@mercush mercush marked this pull request as ready for review January 26, 2026 21:32
@mercush mercush changed the title Kernel for sparse matrix-dense matrix multiplication [Feature] Kernel for sparse matrix-dense matrix multiplication Jan 30, 2026

mercush commented Feb 1, 2026

Benchmark: SpMM can be substantially faster than dense matmul for large matrices with high sparsity.

Metal backend (M2 Pro):

Sparse matmul CSR benchmark (1024x1024, 5% nonzero)
  float16:
Timing sparse (float16) ... 0.42803 msec
Timing dense  (float16) ... 0.65754 msec
  bfloat16:
Timing sparse (bfloat16) ... 0.38357 msec
Timing dense  (bfloat16) ... 0.64950 msec
  float32:
Timing sparse (float32) ... 0.49228 msec
Timing dense  (float32) ... 0.60805 msec
Sparse matmul CSR benchmark (4096x4096, 1% nonzero)
  float16:
Timing sparse (float16) ... 6.89379 msec
Timing dense  (float16) ... 22.62993 msec
  bfloat16:
Timing sparse (bfloat16) ... 6.82182 msec
Timing dense  (bfloat16) ... 26.86476 msec
  float32:
Timing sparse (float32) ... 15.21455 msec
Timing dense  (float32) ... 27.31441 msec

CUDA backend (RTX 5090):

Sparse matmul CSR benchmark (1024x1024, 5% nonzero)
  float16:
Timing sparse (float16) ... 0.05121 msec
Timing dense  (float16) ... 0.05960 msec
  bfloat16:
Timing sparse (bfloat16) ... 0.04969 msec
Timing dense  (bfloat16) ... 0.06564 msec
  float32:
Timing sparse (float32) ... 0.06870 msec
Timing dense  (float32) ... 0.07317 msec
Sparse matmul CSR benchmark (4096x4096, 1% nonzero)
  float16:
Timing sparse (float16) ... 0.24195 msec
Timing dense  (float16) ... 0.72532 msec
  bfloat16:
Timing sparse (bfloat16) ... 0.23913 msec
Timing dense  (bfloat16) ... 0.84904 msec
  float32:
Timing sparse (float32) ... 0.44731 msec
Timing dense  (float32) ... 1.44518 msec
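For anyone reproducing numbers like these, here is a hedged sketch of the benchmark structure: time a dense matmul against a CSR SpMM on a randomly sparsified matrix. This is plain Python at a small size purely to show the shape of the comparison; the actual MLX benchmark uses the library's kernels, larger matrices, and device synchronization, none of which is reproduced here.

```python
import random
import time

def bench(fn, *args, reps=3):
    """Return the best-of-reps wall time for fn(*args), in msec."""
    best = float("inf")
    for _ in range(reps):
        t0 = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - t0)
    return best * 1e3

def dense_matmul(A, B):
    n, k, m = len(A), len(B), len(B[0])
    return [[sum(A[i][p] * B[p][c] for p in range(k)) for c in range(m)]
            for i in range(n)]

def csr_matmul(row_ptr, col_idx, values, B):
    m = len(B[0])
    out = [[0.0] * m for _ in range(len(row_ptr) - 1)]
    for i in range(len(row_ptr) - 1):
        for kk in range(row_ptr[i], row_ptr[i + 1]):
            j, v = col_idx[kk], values[kk]
            for c in range(m):
                out[i][c] += v * B[j][c]
    return out

def to_csr(A):
    """Convert a dense row-major matrix to (row_ptr, col_idx, values)."""
    row_ptr, col_idx, values = [0], [], []
    for row in A:
        for j, v in enumerate(row):
            if v != 0.0:
                col_idx.append(j)
                values.append(v)
        row_ptr.append(len(values))
    return row_ptr, col_idx, values

random.seed(0)
n = 64
# ~5% nonzero, matching the density of the first benchmark above
A = [[random.random() if random.random() < 0.05 else 0.0 for _ in range(n)]
     for _ in range(n)]
B = [[random.random() for _ in range(n)] for _ in range(n)]
row_ptr, col_idx, values = to_csr(A)
print(f"dense : {bench(dense_matmul, A, B):.3f} msec")
print(f"sparse: {bench(csr_matmul, row_ptr, col_idx, values, B):.3f} msec")
```

Best-of-N timing reduces noise from a single run; the real benchmark additionally has to force evaluation of lazy arrays before stopping the clock.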

