Use cblas_dgemm_batch when available instead of manual batch loop #125
Summary
`strided-einsum2/src/bgemm_blas.rs` currently loops over the batch dimensions and calls `cblas_dgemm` one slice at a time (the `do_batch` closure). OpenBLAS 0.3.29+ provides `cblas_dgemm_batch` (the pointer-array variant), which handles the loop internally and may enable BLAS-level optimizations.
Current state
- `cblas_dgemm_batch` (pointer-array version): available in OpenBLAS >= 0.3.29.
- `cblas_dgemm_batch_strided`: NOT in 0.3.29 (added later), but not needed, since the pointer-array version is more flexible.
- `cblas-sys` crate (0.2.0 / 0.3.0): does NOT export `cblas_dgemm_batch` bindings (it is a BLAS-like extension, not standard CBLAS).
- `cblas-inject` feature: provides a runtime-registered CBLAS fallback, which currently only has single GEMM.
Proposal
- Add `cblas_dgemm_batch` / `cblas_zgemm_batch` via `extern "C"` declarations in `bgemm_blas.rs`.
- Add `cblas_dgemm_batch` / `cblas_zgemm_batch` fallbacks to `cblas-inject`, implemented as a loop over individual GEMM calls, so that the batch API is always available regardless of the underlying BLAS library.
- Use the batch API uniformly in `bgemm_contiguous_into()`; no feature gates or runtime detection are needed, since `cblas-inject` provides the fallback.
- For real BLAS libraries (OpenBLAS, MKL), the native `cblas_dgemm_batch` is called; under `cblas-inject`, the fallback loop runs.
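A minimal sketch of the `extern "C"` declarations the first proposal item calls for. It assumes an LP64 build (32-bit `MKL_INT` / `blasint`) and CBLAS enums passed as C ints with the reference `cblas.h` values; the type aliases and constant names are hypothetical, not taken from `bgemm_blas.rs` or `cblas-sys`.

```rust
use std::os::raw::{c_int, c_void};

// Assumption: LP64 build, so MKL_INT / OpenBLAS blasint is a 32-bit int.
pub type BlasInt = c_int;
// Assumption: CBLAS enums are passed as C ints (values from reference cblas.h).
pub type CblasLayout = c_int;
pub type CblasTranspose = c_int;

pub const CBLAS_ROW_MAJOR: CblasLayout = 101;
pub const CBLAS_COL_MAJOR: CblasLayout = 102;
pub const CBLAS_NO_TRANS: CblasTranspose = 111;
pub const CBLAS_TRANS: CblasTranspose = 112;

extern "C" {
    // Per-group arrays have group_count entries; a/b/c pointer arrays
    // have sum(group_size) entries.
    pub fn cblas_dgemm_batch(
        layout: CblasLayout,
        transa_array: *const CblasTranspose,
        transb_array: *const CblasTranspose,
        m_array: *const BlasInt, n_array: *const BlasInt, k_array: *const BlasInt,
        alpha_array: *const f64,
        a_array: *const *const f64, lda_array: *const BlasInt,
        b_array: *const *const f64, ldb_array: *const BlasInt,
        beta_array: *const f64,
        c_array: *const *mut f64, ldc_array: *const BlasInt,
        group_count: BlasInt, group_size: *const BlasInt,
    );

    // Complex variant: scalars and matrices are untyped pointers to
    // interleaved (re, im) f64 pairs, as in the MKL prototype.
    pub fn cblas_zgemm_batch(
        layout: CblasLayout,
        transa_array: *const CblasTranspose,
        transb_array: *const CblasTranspose,
        m_array: *const BlasInt, n_array: *const BlasInt, k_array: *const BlasInt,
        alpha_array: *const c_void,
        a_array: *const *const c_void, lda_array: *const BlasInt,
        b_array: *const *const c_void, ldb_array: *const BlasInt,
        beta_array: *const c_void,
        c_array: *const *mut c_void, ldc_array: *const BlasInt,
        group_count: BlasInt, group_size: *const BlasInt,
    );
}
```

Linking against an ILP64 BLAS would require changing `BlasInt` to `i64`.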
API signature (OpenBLAS/MKL)
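```c
void cblas_dgemm_batch(
    CBLAS_LAYOUT layout,
    CBLAS_TRANSPOSE *transa_array,
    CBLAS_TRANSPOSE *transb_array,
    MKL_INT *m_array, MKL_INT *n_array, MKL_INT *k_array,
    double *alpha_array,
    const double **a_array, MKL_INT *lda_array,
    const double **b_array, MKL_INT *ldb_array,
    double *beta_array,
    double **c_array, MKL_INT *ldc_array,
    MKL_INT group_count,
    MKL_INT *group_size
);
```

For uniform batches (every slice shares the same transposes, shapes, alpha/beta and leading dimensions): group_count = 1, group_size = [batch_count]. The per-group parameter arrays (`transa_array`, `transb_array`, `m_array`/`n_array`/`k_array`, `alpha_array`, `beta_array` and the `ld*_array`s) then have one element each, while `a_array`, `b_array` and `c_array` hold `batch_count` pointers each.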
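For the uniform case, preparing the arguments amounts to computing one pointer per batch slice. A sketch, assuming each operand's slices start at a fixed element stride within one contiguous buffer; `UniformGroup` and `uniform_group` are hypothetical names, not existing code.

```rust
use std::os::raw::c_int;

/// Hypothetical helper: pointer arrays for a single uniform group,
/// i.e. group_count = 1 and group_size = [batch_count].
pub struct UniformGroup {
    pub a_ptrs: Vec<*const f64>,
    pub b_ptrs: Vec<*const f64>,
    pub c_ptrs: Vec<*mut f64>,
    pub group_size: [c_int; 1],
}

/// Slice `i` of each operand is assumed to start at element `i * stride`.
/// Only slice starts are bounds-checked; the caller must still ensure each
/// slice's full extent (from m, n, k and the leading dimensions) fits.
pub fn uniform_group(
    a: &[f64], stride_a: usize,
    b: &[f64], stride_b: usize,
    c: &mut [f64], stride_c: usize,
    batch_count: usize,
) -> UniformGroup {
    if batch_count > 0 {
        let last = batch_count - 1;
        assert!(last * stride_a < a.len());
        assert!(last * stride_b < b.len());
        assert!(last * stride_c < c.len());
    }
    // Derive all C pointers from one base pointer.
    let c_base = c.as_mut_ptr();
    UniformGroup {
        // wrapping_add is safe to call; the asserts keep offsets in bounds.
        a_ptrs: (0..batch_count).map(|i| a.as_ptr().wrapping_add(i * stride_a)).collect(),
        b_ptrs: (0..batch_count).map(|i| b.as_ptr().wrapping_add(i * stride_b)).collect(),
        c_ptrs: (0..batch_count).map(|i| c_base.wrapping_add(i * stride_c)).collect(),
        group_size: [batch_count as c_int],
    }
}
```

`a_ptrs.as_ptr()`, `b_ptrs.as_ptr()` and `c_ptrs.as_ptr()` would then be passed as `a_array`, `b_array` and `c_array`, and `group_size.as_ptr()` as `group_size`; the vectors must outlive the BLAS call.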
cblas-inject fallback implementation
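```rust
// In cblas-inject crate (sketch). Assumes LP64 (32-bit) BLAS integers,
// CBLAS enums passed as C ints, and a `cblas_dgemm` with the standard
// CBLAS signature already in scope.
use std::os::raw::c_int;

pub unsafe extern "C" fn cblas_dgemm_batch(
    layout: c_int, transa_array: *const c_int, transb_array: *const c_int,
    m_array: *const c_int, n_array: *const c_int, k_array: *const c_int,
    alpha_array: *const f64,
    a_array: *const *const f64, lda_array: *const c_int,
    b_array: *const *const f64, ldb_array: *const c_int,
    beta_array: *const f64,
    c_array: *const *mut f64, ldc_array: *const c_int,
    group_count: c_int, group_size: *const c_int,
) {
    // Per-group parameters (trans, m/n/k, alpha/beta, ld*) are indexed by `g`;
    // the A/B/C pointer arrays are flat across all groups, so `idx` keeps counting.
    let mut idx = 0usize;
    for g in 0..group_count as usize {
        for _ in 0..*group_size.add(g) as usize {
            cblas_dgemm(
                layout, *transa_array.add(g), *transb_array.add(g),
                *m_array.add(g), *n_array.add(g), *k_array.add(g),
                *alpha_array.add(g),
                *a_array.add(idx), *lda_array.add(g),
                *b_array.add(idx), *ldb_array.add(g),
                *beta_array.add(g),
                *c_array.add(idx), *ldc_array.add(g),
            );
            idx += 1;
        }
    }
}
```
Expected impact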
Likely small for einsum2 (batch dimensions are typically small). Main benefit is cleaner code and potential BLAS-internal parallelization of the batch loop on capable libraries.
Risk
Low. cblas-inject fallback ensures correctness everywhere; native BLAS libraries get the optimized path.
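The grouped indexing (per-group parameters vs. flat pointer arrays) is the easiest part of the fallback to get wrong, so a pure-Rust reference model could serve as a test oracle for both the `cblas-inject` fallback and native libraries. This sketch is restricted to row-major, non-transposed operands with tight leading dimensions; `Group`, `gemm_ref` and `gemm_batch_ref` are hypothetical names. Like BLAS, it does not read C when beta is zero.

```rust
/// Hypothetical: parameters for one group of identically shaped GEMMs
/// (row-major, no transpose, tight leading dimensions).
pub struct Group {
    pub m: usize, pub n: usize, pub k: usize,
    pub alpha: f64, pub beta: f64,
    pub size: usize, // number of matrices in this group
}

/// Naive row-major C = alpha * A * B + beta * C.
fn gemm_ref(m: usize, n: usize, k: usize, alpha: f64, a: &[f64], b: &[f64], beta: f64, c: &mut [f64]) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0;
            for p in 0..k {
                acc += a[i * k + p] * b[p * n + j];
            }
            // BLAS semantics: with beta == 0, C is not read (NaN/garbage ignored).
            let prev = if beta == 0.0 { 0.0 } else { beta * c[i * n + j] };
            c[i * n + j] = alpha * acc + prev;
        }
    }
}

/// Mirrors the fallback loop: group parameters are indexed per group,
/// operands by a running index across all groups.
pub fn gemm_batch_ref(groups: &[Group], a: &[Vec<f64>], b: &[Vec<f64>], c: &mut [Vec<f64>]) {
    let mut idx = 0;
    for g in groups {
        for _ in 0..g.size {
            gemm_ref(g.m, g.n, g.k, g.alpha, &a[idx], &b[idx], g.beta, &mut c[idx]);
            idx += 1;
        }
    }
    assert_eq!(idx, a.len(), "sum of group sizes must match the operand count");
}
```

For example, two 1x1 products in one group (alpha = 2, beta = 0) followed by a 1x2 times 2x1 product in a second group (alpha = 1, beta = 1) exercise both the per-group lookup and the running operand index.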