
einsum2 batched GEMM has 3.5x overhead vs Julia for small-m contractions #115

@shinaoka

Description

Summary

einsum2_into_owned is 3.5x slower than Julia's direct BLAS.gemm! for the GEMM phase (after data is already contiguous), when m is small (m=4).

Benchmark (AMD EPYC 7713P)

Step 408 of tensornetwork_permutation_light_415: m=4, k=256, n=8192, 8 batches.

| GEMM (contiguous data) | 1T | 4T |
| --- | --- | --- |
| Rust `einsum2_into_owned` (blas) | 63 ms | 62 ms |
| Julia `mul!` / `BLAS.gemm!` | 18 ms | 5 ms |

At 4 threads, the Rust GEMM shows essentially no speedup (63 ms to 62 ms), while Julia gets a 3.6x speedup (18 ms to 5 ms).

Reproduction

# Rust (measures "einsum2 (contiguous, ~GEMM)" line)
RAYON_NUM_THREADS=1 OMP_NUM_THREADS=1 \
  cargo run --release --no-default-features --features blas --bin step408_bench

# Julia (measures "BLAS gemm only" line)
OPENBLAS_NUM_THREADS=1 julia --project=. micro_bench/step408_fair.jl

Analysis: two layers of overhead

Layer 1: einsum2_into_owned wrapper cost (both backends)

On every invocation, einsum2_into_owned performs:

  • Einsum2Plan::new — axis classification and permutation computation via linear scans
  • validate_dimensions — scans all axis groups against all operand label arrays
  • prepare_input_owned — try_fuse_group, REQUIRES_UNIT_STRIDE checks, and allocation

This overhead is independent of the GEMM backend (faer or blas) and dominates for small contractions.

Layer 2: per-batch GEMM dispatch (backend-specific)

The 8 batches are dispatched as 8 separate GEMM calls. Each call has backend-specific overhead.

Approach: Adopt OMEinsum.jl's strategy (both backends)

OMEinsum.jl (BatchedRoutines.jl):

for batch in 1:nb
    ccall(dgemm_, ..., ptrA + batch_offset, lda, ptrB + batch_offset, ldb, ...)
end

Key: plan is built once, then GEMM is a tight loop over raw pointers with stride parameters.

Proposed fix

  1. Reduce per-call overhead (both backends): Cache or hoist Einsum2Plan construction and validation outside the batch loop. The current bgemm_contiguous_into already receives pre-prepared operands, but the per-batch dispatch in the backend still has overhead.

  2. Batch loop with stride passthrough (backend-specific):

    • blas backend: Loop over batches calling cblas_dgemm directly with lda/ldb stride parameters, advancing raw pointers by batch stride.
    • faer backend: Loop over batches calling faer::linalg::matmul::matmul with MatRef/MatMut (which natively carry stride info), advancing by batch stride. faer already handles non-unit strides via its MatRef API.
  3. Skip multi-threading for small m: When m ≤ threshold, disable BLAS/faer threading for GEMM (the thread dispatch overhead exceeds the benefit).

Context

Related: #114 (copy_into performance), #116 (pre-permutation dim fusion — the highest impact fix).
