# einsum2 batched GEMM has 3.5x overhead vs Julia for small-m contractions #115
## Summary
`einsum2_into_owned` is 3.5x slower than Julia's direct `BLAS.gemm!` for the GEMM phase (after the data is already contiguous) when m is small (m=4).
## Benchmark (AMD EPYC 7713P)
Step 408 of tensornetwork_permutation_light_415: m=4, k=256, n=8192, 8 batches.
| GEMM (contiguous data) | 1T | 4T |
|---|---|---|
| Rust `einsum2_into_owned` (blas) | 63 ms | 62 ms |
| Julia `mul!` → `BLAS.gemm!` | 18 ms | 5 ms |
Rust 4T shows no speedup for GEMM, while Julia gets 3.6x speedup.
## Reproduction
```shell
# Rust (measures "einsum2 (contiguous, ~GEMM)" line)
RAYON_NUM_THREADS=1 OMP_NUM_THREADS=1 \
cargo run --release --no-default-features --features blas --bin step408_bench

# Julia (measures "BLAS gemm only" line)
OPENBLAS_NUM_THREADS=1 julia --project=. micro_bench/step408_fair.jl
```

## Analysis: two layers of overhead
### Layer 1: einsum2_into_owned wrapper cost (both backends)
On every invocation, `einsum2_into_owned` performs:

- `Einsum2Plan::new` — axis classification and permutation computation via linear scans
- `validate_dimensions` — scans all axis groups against all operand label arrays
- `prepare_input_owned` — `try_fuse_group`, `REQUIRES_UNIT_STRIDE` checks, allocation
This overhead is independent of the GEMM backend (faer or blas) and dominates for small contractions.
### Layer 2: per-batch GEMM dispatch (backend-specific)
The 8 batches are dispatched as 8 separate GEMM calls. Each call has backend-specific overhead.
## Approach: adopt OMEinsum.jl's strategy (both backends)
OMEinsum.jl (BatchedRoutines.jl):

```julia
for batch in 1:nb
    ccall(dgemm_, ..., ptrA + batch_offset, lda, ptrB + batch_offset, ldb, ...)
end
```

Key: the plan is built once, then GEMM is a tight loop over raw pointers with stride parameters.
Proposed fix
-
Reduce per-call overhead (both backends): Cache or hoist
Einsum2Planconstruction and validation outside the batch loop. The currentbgemm_contiguous_intoalready receives pre-prepared operands, but the per-batch dispatch in the backend still has overhead. -
Batch loop with stride passthrough (backend-specific):
- blas backend: Loop over batches calling
cblas_dgemmdirectly withlda/ldbstride parameters, advancing raw pointers by batch stride. - faer backend: Loop over batches calling
faer::linalg::matmul::matmulwithMatRef/MatMut(which natively carry stride info), advancing by batch stride. faer already handles non-unit strides via itsMatRefAPI.
- blas backend: Loop over batches calling
-
Skip multi-threading for small m: When
m ≤ threshold, disable BLAS/faer threading for GEMM (the thread dispatch overhead exceeds the benefit).
## Context
Related: #114 (copy_into performance), #116 (pre-permutation dim fusion — the highest impact fix).