perf: avoid finalize copy in einsum2 by deferring output permutation (lazy permute) #128

@shinaoka

Summary

einsum2_dispatch always materializes the output via finalize_into when the output permutation makes the C dimensions non-fusable. This causes a full copy of the output tensor at the end of every contraction step with non-trivial output permutation.

A "lazy permute" approach (as implemented in tenferro-einsum) can avoid this copy entirely for intermediate steps by returning a non-contiguous view instead of a contiguous array. The next step's prepare_input_view or prepare_input_owned then handles the non-contiguous input, and if that input happens to be fusable for the next GEMM, no copy is made at all.

Current behavior

In einsum2_dispatch (lib.rs):

// 3. Prepare output
let c_op = prepare_output_view(&mut c_perm, n_lo, n_ro, beta, ...)?;
// 4. GEMM
B::bgemm_contiguous_into(&mut c_op, &a_op, &b_op, ...)?;
// 5. Finalize — copies temp → c_perm when output was non-fusable
c_op.finalize_into(&mut c_perm)?;

When prepare_output_view detects non-fusable strides, it allocates a temp buffer and finalize_into copies the GEMM result back. For intermediate steps (beta=0), the copy-in is skipped but the copy-back always happens.
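To make "non-fusable" concrete, here is a minimal sketch of a fusability check on a group of dimensions. It is not the check used in strided-einsum2; `is_fusable` is a hypothetical helper, and it assumes the group is listed fastest-varying dimension first (column-major style). A group is fusable when each stride equals the previous stride times the previous extent, so the whole group can be walked as one dimension.

```rust
/// Hypothetical helper, not part of strided-einsum2.
/// Returns true when the dimension group can be treated as one fused
/// dimension. Assumption: dims are listed fastest-varying first.
fn is_fusable(dims: &[usize], strides: &[isize]) -> bool {
    assert_eq!(dims.len(), strides.len());
    // Size-1 dimensions never constrain the layout, so skip them.
    let mut active = dims
        .iter()
        .zip(strides.iter())
        .filter(|(d, _)| **d != 1)
        .map(|(d, s)| (*d, *s));
    let Some((mut d_prev, mut s_prev)) = active.next() else {
        return true; // empty group (or all extents 1) is trivially fusable
    };
    for (d, s) in active {
        // The next dimension must start exactly where the previous one ends.
        if s != s_prev * d_prev as isize {
            return false;
        }
        d_prev = d;
        s_prev = s;
    }
    true
}
```

With dims `[2, 3, 4]` and strides `[1, 2, 6]` the group fuses into a single dimension of extent 24. Swapping two axes, e.g. dims `[3, 2]` with strides `[2, 1]`, breaks the chain, which is the situation where a temp buffer and a copy-back would be needed.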

Proposed behavior

For intermediate contraction steps (where the output is consumed by a subsequent step), skip finalize_into and instead return the GEMM output as a non-contiguous view with rearranged strides (lazy permute). The downstream step's prepare_input_owned already handles non-contiguous inputs.
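A minimal sketch of what "rearranged strides" means, under stated assumptions: `ViewDesc` is a hypothetical descriptor (dims, strides, offset) and not the strided-rs view type, and `permuted` does not validate that `perm` is a real permutation. The point is only that the permutation touches metadata, never the buffer.

```rust
/// Hypothetical strided view descriptor; not the strided-rs API.
#[derive(Debug, Clone, PartialEq)]
struct ViewDesc {
    dims: Vec<usize>,
    strides: Vec<isize>,
    offset: isize,
}

impl ViewDesc {
    /// Column-major contiguous layout for the given dims (assumed layout).
    fn col_major(dims: &[usize]) -> Self {
        let mut strides = Vec::with_capacity(dims.len());
        let mut acc = 1isize;
        for &d in dims {
            strides.push(acc);
            acc *= d as isize;
        }
        ViewDesc { dims: dims.to_vec(), strides, offset: 0 }
    }

    /// Lazy permute: axis k of the result is axis perm[k] of self.
    /// Only dims and strides are reordered; no element is moved.
    fn permuted(&self, perm: &[usize]) -> Self {
        assert_eq!(perm.len(), self.dims.len());
        ViewDesc {
            dims: perm.iter().map(|&p| self.dims[p]).collect(),
            strides: perm.iter().map(|&p| self.strides[p]).collect(),
            offset: self.offset,
        }
    }

    /// Linear buffer position of a multi-index under this view.
    fn position(&self, idx: &[usize]) -> isize {
        self.offset
            + idx
                .iter()
                .zip(&self.strides)
                .map(|(&i, &s)| i as isize * s)
                .sum::<isize>()
    }
}
```

For a GEMM result stored with dims `[2, 3, 4]`, `permuted(&[2, 0, 1])` yields dims `[4, 2, 3]` with strides `[6, 1, 2]`, and index `[3, 1, 2]` in the permuted view addresses the same element as `[1, 2, 3]` in the original. Whether the next step can consume that view without a copy then depends on how its index groups line up with these strides.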

Benchmark evidence

Comparing strided-rs vs tenferro-einsum (which implements lazy permute) on gm_queen5_5_3.wcsp:

| Strategy | strided-rs | tenferro | Ratio (tenferro / strided-rs) |
|---|---|---|---|
| opt_flops (148 steps, large intermediates) | 8116 ms | 7083 ms | 0.87x |
| opt_size (159 steps, small intermediates) | 2426 ms | 2753 ms | 1.13x |

  • opt_flops: tenferro takes 13% less time. Large intermediates make the finalize copy expensive, and lazy permute avoids on the order of gigabytes of copies in total across the 148 steps.
  • opt_size: tenferro takes 13% more time. Small intermediates make finalize copies cheap, so tenferro's multi-layer dispatch overhead (~2 ms/step) dominates.

The opt_flops result demonstrates that eliminating finalize copies can yield significant speedups for workloads with large intermediate tensors.

Affected code

  • strided-einsum2/src/lib.rs: einsum2_dispatch — the finalize_into call
  • strided-einsum2/src/contiguous.rs: ContiguousOperandMut::finalize_into
  • strided-opteinsum/src/expr.rs: eval_pair_alloc — would need to support returning non-contiguous results
