Use cblas_dgemm_batch when available instead of manual batch loop #125

@shinaoka

Description

Summary

strided-einsum2/src/bgemm_blas.rs currently loops over batch dimensions and calls cblas_dgemm one slice at a time (the do_batch closure). OpenBLAS 0.3.29+ provides cblas_dgemm_batch (pointer-array variant), which handles the loop internally and may enable BLAS-level optimizations.

Current state

  • cblas_dgemm_batch (pointer-array version): available in OpenBLAS >= 0.3.29
  • cblas_dgemm_batch_strided: NOT in 0.3.29 (added later), but not needed — the pointer-array version is more flexible
  • cblas-sys crate (0.2.0 / 0.3.0): does NOT export cblas_dgemm_batch bindings (it's a BLAS-like extension, not standard CBLAS)
  • cblas-inject feature: provides runtime-registered CBLAS fallback — currently only has single GEMM

Proposal

  1. Add cblas_dgemm_batch / cblas_zgemm_batch via extern "C" declarations in bgemm_blas.rs
  2. Add cblas_dgemm_batch / cblas_zgemm_batch fallback to cblas-inject — implement as a loop over individual GEMM calls, so that the batch API is always available regardless of the underlying BLAS library
  3. Use the batch API uniformly in bgemm_contiguous_into() — no need for feature gates or runtime detection since cblas-inject provides the fallback
  4. For real BLAS libraries (OpenBLAS, MKL), the native cblas_dgemm_batch is called; for cblas-inject, the fallback loop runs

API signature (OpenBLAS/MKL)

void cblas_dgemm_batch(
    const CBLAS_LAYOUT layout,
    const CBLAS_TRANSPOSE *transa_array,
    const CBLAS_TRANSPOSE *transb_array,
    const MKL_INT *m_array, const MKL_INT *n_array, const MKL_INT *k_array,
    const double *alpha_array,
    const double **a_array, const MKL_INT *lda_array,
    const double **b_array, const MKL_INT *ldb_array,
    const double *beta_array,
    double **c_array, const MKL_INT *ldc_array,
    const MKL_INT group_count,
    const MKL_INT *group_size
);

For uniform batches (all matrices sharing the same shape and leading dimensions): group_count=1 and group_size=[batch_count], a one-element array.

cblas-inject fallback implementation

// In the cblas-inject crate; integer width assumed c_int (LP64).
pub unsafe extern "C" fn cblas_dgemm_batch(
    layout: CBLAS_LAYOUT,
    transa_array: *const CBLAS_TRANSPOSE, transb_array: *const CBLAS_TRANSPOSE,
    m_array: *const c_int, n_array: *const c_int, k_array: *const c_int,
    alpha_array: *const f64,
    a_array: *const *const f64, lda_array: *const c_int,
    b_array: *const *const f64, ldb_array: *const c_int,
    beta_array: *const f64,
    c_array: *const *mut f64, ldc_array: *const c_int,
    group_count: c_int, group_size: *const c_int,
) {
    // Loop over groups, then over the matrices in each group, calling
    // cblas_dgemm once per multiply. Per-group parameters (trans, m,
    // n, k, ld, alpha, beta) are indexed by group; the a/b/c pointer
    // arrays use a flat matrix index that runs across all groups.
    let mut idx = 0usize;
    for g in 0..group_count as usize {
        for _ in 0..*group_size.add(g) {
            cblas_dgemm(
                layout, *transa_array.add(g), *transb_array.add(g),
                *m_array.add(g), *n_array.add(g), *k_array.add(g),
                *alpha_array.add(g), *a_array.add(idx), *lda_array.add(g),
                *b_array.add(idx), *ldb_array.add(g), *beta_array.add(g),
                *c_array.add(idx), *ldc_array.add(g),
            );
            idx += 1;
        }
    }
}

Expected impact

Likely small for einsum2 (batch dimensions are typically small). Main benefit is cleaner code and potential BLAS-internal parallelization of the batch loop on capable libraries.

Risk

Low. cblas-inject fallback ensures correctness everywhere; native BLAS libraries get the optimized path.
