
Simplify axis classification: reduce 6 groups to 4 + lazy trace #126

@shinaoka


Summary

strided-einsum2/src/plan.rs currently classifies axes into six groups: batch, lo, ro, sum, left_trace, and right_trace. The two trace groups add complexity to the plan phase (extra vectors and extra permutation computation).
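To make the six groups concrete, here is a minimal sketch. It assumes each group is determined by which of the left operand, right operand, and output contain an axis label, and that a trace axis is one that occurs only in its own operand. AxisGroup and classify6 are hypothetical names, not the actual types in plan.rs.

```rust
/// Hypothetical six-way axis groups (illustrative, not plan.rs's real types).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AxisGroup {
    Batch,      // in left, right, and output
    Lo,         // in left and output only
    Ro,         // in right and output only
    Sum,        // in left and right only (contracted by GEMM)
    LeftTrace,  // in left only (assumed: summed within the left operand)
    RightTrace, // in right only (assumed: summed within the right operand)
}

/// Classify one axis label by where it appears.
fn classify6(label: char, left: &[char], right: &[char], out: &[char]) -> Option<AxisGroup> {
    let in_l = left.contains(&label);
    let in_r = right.contains(&label);
    let in_o = out.contains(&label);
    match (in_l, in_r, in_o) {
        (true, true, true) => Some(AxisGroup::Batch),
        (true, false, true) => Some(AxisGroup::Lo),
        (false, true, true) => Some(AxisGroup::Ro),
        (true, true, false) => Some(AxisGroup::Sum),
        (true, false, false) => Some(AxisGroup::LeftTrace),
        (false, true, false) => Some(AxisGroup::RightTrace),
        // Output-only or unknown label: outside the scope of this sketch.
        (false, false, _) => None,
    }
}
```

For example, under these assumptions the contraction bijt,bjku->bik puts b in batch, i in lo, k in ro, j in sum, t in left_trace, and u in right_trace.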

Proposal

Reduce the classification to four groups (batch, lo, ro, sum) and handle trace axes lazily at runtime: detect and reduce them on demand just before the GEMM call, rather than pre-classifying them in the plan. The trace.rs module already provides reduce_trace_axes(), which can be called inline.
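A minimal sketch of the proposed split, assuming a trace axis is one whose label occurs in only one operand and not in the output: the plan records only the four groups, and trace axes are located on demand when an operand is about to be fed to GEMM. Plan4, plan4, and trace_axes are hypothetical names; in the real code the reduction step would go through the existing reduce_trace_axes(), whose signature is not shown here.

```rust
/// Hypothetical four-group plan; trace axes are deliberately not recorded.
#[derive(Debug, Default, PartialEq)]
struct Plan4 {
    batch: Vec<char>,
    lo: Vec<char>,
    ro: Vec<char>,
    sum: Vec<char>,
}

fn plan4(left: &[char], right: &[char], out: &[char]) -> Plan4 {
    let mut p = Plan4::default();
    for &c in left {
        match (right.contains(&c), out.contains(&c)) {
            (true, true) => p.batch.push(c),
            (false, true) => p.lo.push(c),
            (true, false) => p.sum.push(c),
            (false, false) => {} // left trace axis: deferred to run time
        }
    }
    for &c in right {
        // Labels in the right operand and the output but not the left are ro.
        // Right-only labels absent from the output are trace axes and are skipped.
        if out.contains(&c) && !left.contains(&c) {
            p.ro.push(c);
        }
    }
    p
}

/// Run-time detection: positions of axes in `own` that appear in neither
/// `other` nor `out`. An empty result means no reduction is needed.
fn trace_axes(own: &[char], other: &[char], out: &[char]) -> Vec<usize> {
    own.iter()
        .enumerate()
        .filter(|&(_, c)| !other.contains(c) && !out.contains(c))
        .map(|(i, _)| i)
        .collect()
}
```

With the labels ijt,jk->ik, plan4 stores lo = [i], sum = [j], ro = [k], and nothing about t; trace_axes on the left operand returns [2] (the t axis), which would then be summed out before GEMM.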

This could also eliminate trace.rs as a separate module (~40 lines).

Risk

Low. Trace axes are rare (typically 0-2 per contraction) and already handled by a separate reduction step.
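Because most contractions have no trace axes, the lazy path can be made nearly free in the common case. The sketch below assumes dense, row-major f64 data: it borrows the operand unchanged when the trace-axis list is empty and only allocates when a reduction is actually needed. sum_axes and reduce_if_needed are hypothetical names; sum_axes stands in for reduce_trace_axes(), whose real signature may differ.

```rust
use std::borrow::Cow;

/// Hypothetical stand-in for reduce_trace_axes(): sums a dense, row-major
/// tensor over the axes in `axes` and returns the reduced data and shape.
fn sum_axes(data: &[f64], shape: &[usize], axes: &[usize]) -> (Vec<f64>, Vec<usize>) {
    let new_shape: Vec<usize> = shape
        .iter()
        .enumerate()
        .filter(|&(i, _)| !axes.contains(&i))
        .map(|(_, &d)| d)
        .collect();
    let mut out = vec![0.0; new_shape.iter().product::<usize>()];
    let mut idx = vec![0usize; shape.len()];
    for &x in data {
        // Row-major flat index into `out`, skipping the reduced axes.
        let mut o = 0;
        for (ax, &i) in idx.iter().enumerate() {
            if !axes.contains(&ax) {
                o = o * shape[ax] + i;
            }
        }
        out[o] += x;
        // Advance the multi-index in row-major order (last axis fastest).
        for ax in (0..shape.len()).rev() {
            idx[ax] += 1;
            if idx[ax] < shape[ax] {
                break;
            }
            idx[ax] = 0;
        }
    }
    (out, new_shape)
}

/// Lazy entry point: with no trace axes (the common case) the operand is
/// borrowed as-is, so the only cost is an emptiness check.
fn reduce_if_needed<'a>(
    data: &'a [f64],
    shape: &[usize],
    trace_axes: &[usize],
) -> (Cow<'a, [f64]>, Vec<usize>) {
    if trace_axes.is_empty() {
        return (Cow::Borrowed(data), shape.to_vec());
    }
    let (reduced, new_shape) = sum_axes(data, shape, trace_axes);
    (Cow::Owned(reduced), new_shape)
}
```

Returning a Cow keeps the GEMM call site identical for both paths while avoiding a copy when nothing needs to be reduced.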
