Summary
strided-einsum2/src/plan.rs classifies axes into 6 groups: batch, lo, ro, sum, left_trace, right_trace. The trace groups add complexity to the plan phase (extra vectors, extra permutation computation).
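As a rough illustration of the six-way split, here is a hypothetical classifier that sorts each axis label by which of left, right and output it appears in. The membership rules are an assumption made for this sketch: it treats a "trace" axis as one that appears in only one operand and not in the output, so it must be summed out before the GEMM. plan.rs may define the groups differently (for example, repeated labels within one operand). AxisGroup and classify are illustrative names, not the crate's API.

```rust
/// Hypothetical mirror of the six groups named in plan.rs.
/// ASSUMPTION: groups are decided purely by label membership in left/right/output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AxisGroup {
    Batch,      // in left, right and output
    Lo,         // in left and output only
    Ro,         // in right and output only
    Sum,        // in left and right, not in output (contracted by the GEMM)
    LeftTrace,  // in left only: must be summed out before the GEMM
    RightTrace, // in right only: must be summed out before the GEMM
}

/// Classify one axis label given the label lists of both operands and the output.
fn classify(label: char, left: &[char], right: &[char], out: &[char]) -> Option<AxisGroup> {
    let l = left.contains(&label);
    let r = right.contains(&label);
    let o = out.contains(&label);
    match (l, r, o) {
        (true, true, true) => Some(AxisGroup::Batch),
        (true, false, true) => Some(AxisGroup::Lo),
        (false, true, true) => Some(AxisGroup::Ro),
        (true, true, false) => Some(AxisGroup::Sum),
        (true, false, false) => Some(AxisGroup::LeftTrace),
        (false, true, false) => Some(AxisGroup::RightTrace),
        // Output-only or absent labels are not valid axes of a binary contraction.
        _ => None,
    }
}
```

Under the proposal, the last two arms would disappear from the plan: those axes would already have been reduced away before classification.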
Proposal
Reduce to 4 groups (batch, lo, ro, sum). Handle trace axes lazily at runtime: detect them and reduce them on demand just before the GEMM call, rather than pre-classifying them in the plan. The trace.rs module already provides reduce_trace_axes(), which can be called inline.
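A minimal sketch of the lazy step, under the same assumption that a trace axis is one whose label appears in neither the other operand nor the output. Dense and reduce_trace_lazily are hypothetical. This is not the signature or implementation of reduce_trace_axes(), and a real version would work on strided views rather than an owned row-major buffer. The check returns the input borrowed when no trace axes are found, so the common case costs only a label scan.

```rust
use std::borrow::Cow;

/// Hypothetical dense row-major tensor, used only for this sketch.
#[derive(Debug, Clone, PartialEq)]
struct Dense {
    labels: Vec<char>,
    shape: Vec<usize>,
    data: Vec<f64>,
}

/// Sum out every axis of `t` whose label is absent from both `other` and `out`.
/// Returns `Cow::Borrowed` (no copy) when there is nothing to reduce.
fn reduce_trace_lazily<'a>(t: &'a Dense, other: &[char], out: &[char]) -> Cow<'a, Dense> {
    // keep[ax] is true when axis `ax` survives into the GEMM.
    let keep: Vec<bool> = t
        .labels
        .iter()
        .map(|l| other.contains(l) || out.contains(l))
        .collect();
    if keep.iter().all(|&k| k) {
        return Cow::Borrowed(t); // fast path: no trace axes detected
    }

    let new_labels: Vec<char> = (0..t.labels.len()).filter(|&ax| keep[ax]).map(|ax| t.labels[ax]).collect();
    let new_shape: Vec<usize> = (0..t.shape.len()).filter(|&ax| keep[ax]).map(|ax| t.shape[ax]).collect();
    let new_len: usize = new_shape.iter().product(); // empty product = 1 (scalar result)
    let mut data = vec![0.0; new_len];

    // Walk every input element with a row-major multi-index counter.
    let mut idx = vec![0usize; t.shape.len()];
    for &v in &t.data {
        // Row-major offset into the reduced tensor, using only the kept axes.
        let mut off = 0;
        for (ax, &i) in idx.iter().enumerate() {
            if keep[ax] {
                off = off * t.shape[ax] + i;
            }
        }
        data[off] += v;
        // Advance the multi-index, last axis fastest.
        for ax in (0..idx.len()).rev() {
            idx[ax] += 1;
            if idx[ax] < t.shape[ax] {
                break;
            }
            idx[ax] = 0;
        }
    }
    Cow::Owned(Dense { labels: new_labels, shape: new_shape, data })
}
```

For example, a left operand labelled "ijt" contracted with a right operand "jk" into output "ik" has t reduced away, leaving an "ij" operand whose axes the 4-group plan can classify directly.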
If reduce_trace_axes() is folded into that runtime path, trace.rs (~40 lines) could be eliminated as a separate module.
Risk
Low. Trace axes are rare (typically 0-2 per contraction) and already handled by a separate reduction step.