perf: avoid finalize copy in einsum2 by deferring output permutation (lazy permute) #128
Summary
`einsum2_dispatch` always materializes the output via `finalize_into` when the output permutation makes the C dimensions non-fusable. This costs a full copy of the output tensor at the end of every contraction step whose output permutation is non-fusable.
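"Non-fusable" here means the C dimensions cannot be collapsed into a single GEMM dimension with one stride. A minimal sketch of such a check, assuming dimensions are listed outermost first and size-1 dimensions are ignored; `is_fusable` is a hypothetical helper, not the check strided-rs actually uses:

```rust
/// Hypothetical helper: can the dimensions `dims` (outermost first) with
/// element strides `strides` be viewed as one dimension of size prod(dims)
/// with a single stride? Size-1 dimensions are ignored because their stride
/// never affects addressing.
fn is_fusable(dims: &[usize], strides: &[isize]) -> bool {
    assert_eq!(dims.len(), strides.len());
    // Keep only dimensions that actually contribute to addressing.
    let active: Vec<(usize, isize)> = dims
        .iter()
        .copied()
        .zip(strides.iter().copied())
        .filter(|&(d, _)| d != 1)
        .collect();
    // Each outer stride must equal the inner stride times the inner size.
    active
        .windows(2)
        .all(|w| w[0].1 == w[1].1 * w[1].0 as isize)
}

fn main() {
    // Row-major 3x4 block: strides (4, 1) chain, so it fuses into 12 elements.
    println!("{}", is_fusable(&[3, 4], &[4, 1])); // true
    // The same buffer viewed with the two dims swapped (a transpose):
    // dims (4, 3) with strides (1, 4) do not chain, so the group is non-fusable.
    println!("{}", is_fusable(&[4, 3], &[1, 4])); // false
}
```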
A "lazy permute" approach (as implemented in tenferro-einsum) can avoid this copy entirely for intermediate steps by returning a non-contiguous view instead of a contiguous array. The next step's `prepare_input_view`/`prepare_input_owned` already handles non-contiguous inputs, and if the input happens to be fusable for the next GEMM, the copy is eliminated completely.
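The core of lazy permute is that permuting axes only reorders the (size, stride) pairs of a view; no element moves. A minimal sketch with a hypothetical `StridedView` type (not the strided-rs or tenferro API), assuming a row-major GEMM output buffer:

```rust
/// Hypothetical minimal strided view over a borrowed buffer.
struct StridedView<'a> {
    data: &'a [f64],
    dims: Vec<usize>,
    strides: Vec<usize>,
}

impl<'a> StridedView<'a> {
    /// Row-major (C-order) view of a contiguous buffer.
    fn contiguous(data: &'a [f64], dims: &[usize]) -> Self {
        let mut strides = vec![1; dims.len()];
        for i in (0..dims.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * dims[i + 1];
        }
        StridedView { data, dims: dims.to_vec(), strides }
    }

    /// Lazy permute: new axis `k` is old axis `perm[k]`. O(ndim), no data copy.
    fn permute(&self, perm: &[usize]) -> StridedView<'a> {
        StridedView {
            data: self.data,
            dims: perm.iter().map(|&p| self.dims[p]).collect(),
            strides: perm.iter().map(|&p| self.strides[p]).collect(),
        }
    }

    /// Read one element by walking the strides.
    fn get(&self, idx: &[usize]) -> f64 {
        let offset: usize = idx.iter().zip(&self.strides).map(|(i, s)| i * s).sum();
        self.data[offset]
    }
}

fn main() {
    // GEMM output laid out as (i, j, k) = (2, 3, 4), row-major.
    let buf: Vec<f64> = (0..24).map(|x| x as f64).collect();
    let c = StridedView::contiguous(&buf, &[2, 3, 4]);
    // The einsum wants the output as (k, i, j): permute lazily instead of copying.
    let out = c.permute(&[2, 0, 1]);
    println!("dims {:?}, strides {:?}", out.dims, out.strides); // dims [4, 2, 3], strides [1, 12, 4]
    // out[k][i][j] reads the same element as c[i][j][k].
    assert_eq!(out.get(&[3, 1, 2]), c.get(&[1, 2, 3]));
}
```

Whether the consumer can then skip its own copy depends on whether these permuted strides happen to be fusable for the next GEMM's dimension grouping; when they are not, the copy moves into the next step's input preparation instead of disappearing.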
Current behavior
In `einsum2_dispatch` (`lib.rs`):

```rust
// 3. Prepare output
let mut c_op = prepare_output_view(&mut c_perm, n_lo, n_ro, beta, ...)?;
// 4. GEMM
B::bgemm_contiguous_into(&mut c_op, &a_op, &b_op, ...)?;
// 5. Finalize: copies temp → c_perm when output was non-fusable
c_op.finalize_into(&mut c_perm)?;
```

When `prepare_output_view` detects non-fusable strides, it allocates a temporary contiguous buffer for the GEMM to write into, and `finalize_into` then copies the result back into `c_perm`. For intermediate steps (`beta = 0`), the copy-in is skipped, but the copy-back always happens.
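A simplified model of this eager path, assuming a single 2-D output, a naive GEMM stand-in that writes row-major into a fresh temporary, and a destination whose strides (column-major here) do not match the GEMM layout. `naive_gemm` and `finalize_into_strided` are hypothetical names; the point is that the copy-back touches every element of C on every such step:

```rust
/// Stand-in for the contiguous GEMM: C_tmp (m x n, row-major) = A (m x k) * B (k x n).
fn naive_gemm(a: &[f64], b: &[f64], m: usize, k: usize, n: usize) -> Vec<f64> {
    let mut c = vec![0.0; m * n]; // beta = 0: the old C is never copied in
    for i in 0..m {
        for p in 0..k {
            let aip = a[i * k + p];
            for j in 0..n {
                c[i * n + j] += aip * b[p * n + j];
            }
        }
    }
    c
}

/// Eager finalize: scatter the contiguous temp into a destination with
/// strides (rs, cs). Returns the number of elements copied.
fn finalize_into_strided(tmp: &[f64], m: usize, n: usize, dst: &mut [f64], rs: usize, cs: usize) -> usize {
    for i in 0..m {
        for j in 0..n {
            dst[i * rs + j * cs] = tmp[i * n + j];
        }
    }
    m * n
}

fn main() {
    let (m, k, n) = (2, 3, 2);
    let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // 2x3, row-major
    let b = [1.0, 0.0, 0.0, 1.0, 1.0, 1.0]; // 3x2, row-major
    let tmp = naive_gemm(&a, &b, m, k, n);
    // Destination wants C in column-major order: strides (1, m).
    let mut c_perm = vec![0.0; m * n];
    let copied = finalize_into_strided(&tmp, m, n, &mut c_perm, 1, m);
    println!("tmp = {:?}, c_perm = {:?}, copied {} elements", tmp, c_perm, copied);
}
```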
Proposed behavior
For intermediate contraction steps (where the output is consumed by a subsequent step), skip `finalize_into` and instead return the GEMM output as a non-contiguous view with rearranged strides (lazy permute). The downstream step's `prepare_input_owned` already handles non-contiguous inputs.
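One way the proposed split could look, as a hedged sketch: a hypothetical `finish_step` returns the GEMM buffer untouched with permuted strides when the step is intermediate, and only pays the permuting copy for the final result. `StepOutput`, `finish_step`, and the `is_intermediate` flag are illustrative names, not existing strided-rs items:

```rust
/// Hypothetical result of one pairwise contraction step.
#[derive(Debug)]
enum StepOutput {
    /// Final result: materialized contiguously in the requested axis order.
    Materialized { data: Vec<f64>, dims: Vec<usize> },
    /// Intermediate result: the GEMM buffer as-is, with permuted strides.
    Lazy { data: Vec<f64>, dims: Vec<usize>, strides: Vec<usize> },
}

fn row_major_strides(dims: &[usize]) -> Vec<usize> {
    let mut s = vec![1; dims.len()];
    for i in (0..dims.len().saturating_sub(1)).rev() {
        s[i] = s[i + 1] * dims[i + 1];
    }
    s
}

/// `gemm_out` is row-major over `gemm_dims`; output axis k is GEMM axis perm[k].
fn finish_step(gemm_out: Vec<f64>, gemm_dims: &[usize], perm: &[usize], is_intermediate: bool) -> StepOutput {
    let src_strides = row_major_strides(gemm_dims);
    let dims: Vec<usize> = perm.iter().map(|&p| gemm_dims[p]).collect();
    let strides: Vec<usize> = perm.iter().map(|&p| src_strides[p]).collect();
    if is_intermediate {
        // Lazy permute: O(ndim) metadata change, zero element copies.
        return StepOutput::Lazy { data: gemm_out, dims, strides };
    }
    // Final step: gather into a contiguous buffer (analogous to today's finalize copy).
    let n: usize = dims.iter().product();
    let mut out = vec![0.0; n];
    let mut idx = vec![0usize; dims.len()];
    for slot in out.iter_mut() {
        let off: usize = idx.iter().zip(&strides).map(|(i, s)| i * s).sum();
        *slot = gemm_out[off];
        // Advance the output multi-index in row-major order.
        for ax in (0..dims.len()).rev() {
            idx[ax] += 1;
            if idx[ax] < dims[ax] {
                break;
            }
            idx[ax] = 0;
        }
    }
    StepOutput::Materialized { data: out, dims }
}

fn main() {
    let gemm_out: Vec<f64> = (0..6).map(|x| x as f64).collect(); // (2, 3), row-major
    // Intermediate step: returned as a transposed view, no copy.
    println!("{:?}", finish_step(gemm_out.clone(), &[2, 3], &[1, 0], true));
    // Final step: materialized contiguously in (3, 2) order.
    println!("{:?}", finish_step(gemm_out, &[2, 3], &[1, 0], false));
}
```

Real code would additionally have to thread the non-contiguous result through the pairwise evaluator so that the next step receives the view rather than a materialized array.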
Benchmark evidence
Comparing strided-rs vs tenferro-einsum (which implements lazy permute) on `gm_queen5_5_3.wcsp`:
| Strategy | strided-rs (ms) | tenferro (ms) | Ratio (tenferro / strided-rs) |
|---|---|---|---|
| opt_flops (148 steps, large intermediates) | 8116 | 7083 | 0.87x |
| opt_size (159 steps, small intermediates) | 2426 | 2753 | 1.13x |
- opt_flops: tenferro is ~13% faster (0.87x the time). Large intermediates make the finalize copy expensive, and lazy permute avoids on the order of gigabytes of copies in total across the 148 steps.
- opt_size: strided-rs is ~12% faster (tenferro takes 1.13x as long). Small intermediates make finalize copies cheap, so tenferro's multi-layer dispatch overhead (~2 ms/step) dominates.
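The Ratio column and the per-step overhead explanation can be cross-checked with simple arithmetic. The snippet uses only numbers quoted in this issue; the ~2 ms/step figure is the issue's own estimate, not a new measurement:

```rust
/// Ratio column: tenferro time / strided-rs time (< 1 means tenferro is faster).
fn ratio(tenferro_ms: f64, strided_ms: f64) -> f64 {
    tenferro_ms / strided_ms
}

/// Total dispatch overhead implied by a fixed per-step cost.
fn overhead_ms(per_step_ms: f64, steps: u32) -> f64 {
    per_step_ms * f64::from(steps)
}

fn main() {
    println!("opt_flops ratio: {:.2}", ratio(7083.0, 8116.0)); // 0.87
    println!("opt_size ratio:  {:.2}", ratio(2753.0, 2426.0)); // 1.13
    let est = overhead_ms(2.0, 159);
    let gap = 2753.0 - 2426.0;
    println!("opt_size: estimated overhead {est} ms vs observed gap {gap} ms"); // 318 vs 327
}
```

An estimated ~318 ms of dispatch overhead against an observed opt_size gap of 327 ms suggests that overhead alone plausibly accounts for nearly all of tenferro's deficit on that strategy.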
The opt_flops result indicates that eliminating finalize copies can yield significant speedups (~13% end-to-end on this instance) for workloads with large intermediate tensors.
Affected code
- `strided-einsum2/src/lib.rs:einsum2_dispatch`: the `finalize_into` call
- `strided-einsum2/src/contiguous.rs:ContiguousOperandMut::finalize_into`
- `strided-opteinsum/src/expr.rs:eval_pair_alloc`: would need to support returning non-contiguous results