
v2 backend architecture: 2-level IR with StableHLO cut point #21

@shinaoka

Summary

Brainstorming results for the v2 backend architecture. This resolves the core design questions from #20 (semiring backend contracts, compile-cache, StdTensorOp consistency) by establishing a clean 2-level IR architecture.

Key decisions

1. StableHLO-compatible IR as the single cut point

All computation (standard and custom algebra) is lowered to a StableHLO-compatible IR (Rust structs, in-process). This is the sole interface boundary:

  • XLA backend: receives the StableHLO IR directly. XLA handles all optimization internally.
  • faer / custom backends: tenferro's optimizing compiler lowers StableHLO IR → low-level IR → generic execution engine.
Surface API → Fragment → CompiledProgram
    ↓
StableHLO IR (Rust struct, complete, standard-compliant)  ← single cut point
    ├─ XLA backend: pass directly to XLA
    └─ faer / custom backends:
         Optimizing compiler → Low-level IR → Generic engine → Backend trait
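To make the cut point concrete, the in-process IR could be plain Rust structs along the following lines. This is a minimal sketch only; the names (HloOp, HloInst, HloProgram) and field layout are hypothetical, not tenferro's actual types.

```rust
// Minimal sketch of an in-process StableHLO-compatible IR as plain Rust
// structs. All names here (HloOp, HloInst, HloProgram) are hypothetical.

#[derive(Debug, Clone)]
enum HloOp {
    Parameter { index: usize },
    DotGeneral {
        lhs_batch_dims: Vec<usize>,
        rhs_batch_dims: Vec<usize>,
        lhs_contracting_dims: Vec<usize>,
        rhs_contracting_dims: Vec<usize>,
    },
    Transpose { permutation: Vec<usize> },
    Reshape { new_shape: Vec<usize> },
    BroadcastInDim { broadcast_dims: Vec<usize> },
    Reduce { dims: Vec<usize> },
}

#[derive(Debug, Clone)]
struct HloInst {
    op: HloOp,
    operands: Vec<usize>, // indices of producing instructions
    shape: Vec<usize>,    // logical output shape; layout is column-major
}

#[derive(Debug, Default)]
struct HloProgram {
    insts: Vec<HloInst>,
}

impl HloProgram {
    /// Appends an instruction and returns its id.
    fn push(&mut self, op: HloOp, operands: Vec<usize>, shape: Vec<usize>) -> usize {
        self.insts.push(HloInst { op, operands, shape });
        self.insts.len() - 1
    }
}
```

Because the IR is a value type rather than a serialized protocol, the XLA backend can translate it to real StableHLO while the custom path consumes it directly.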

2. Optimizing compiler (algebra-agnostic)

Transforms StableHLO IR to low-level IR. Key passes (modeled after XLA):

| Pass | Role |
| --- | --- |
| TransposeFolding | Absorb Transpose into DotGeneral's dimension_numbers |
| DotDecomposer | Canonicalize DotGeneral → Transpose + Reshape + BatchedGemm |
| Contiguous materialization | Insert a physical copy where needed |

These passes are algebra-agnostic — they work identically for standard and custom algebras.
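The heart of a TransposeFolding-style pass is just an index remap: a Transpose feeding a DotGeneral is absorbed by translating the contracting/batch dims through the permutation instead of materializing the transposed buffer. A sketch (the function name is illustrative):

```rust
// TransposeFolding core idea: if T = transpose(X, permutation), then dim d
// of T is dim permutation[d] of X, so DotGeneral dims referring to T can be
// rewritten to refer to X directly and the Transpose node can be dropped.
fn fold_transpose_into_dims(permutation: &[usize], dims: &[usize]) -> Vec<usize> {
    dims.iter().map(|&d| permutation[d]).collect()
}
```

Note the algebra never appears: the pass manipulates dimension metadata only, which is why it works identically for standard and custom algebras.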

3. Low-level IR: contiguous, column-major

All buffers in the low-level IR are contiguous column-major. No stride tricks at this level.

Key instructions:

  • BatchedGemm(batch, m, n, k) — contiguous [batch, M, K] × [batch, K, N]
  • ReduceSum(axes) — contiguous input
  • Permute(perm) / Copy — physical memory reordering
  • Reshape — noop if element count unchanged, else copy
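An illustrative Rust encoding of this instruction set (BufId and the exact fields are assumptions, not the actual v2 definitions):

```rust
// Hypothetical low-level IR instruction set; buffer ids and field layout
// are illustrative only.

type BufId = usize;

#[derive(Debug)]
enum LowLevelInst {
    /// Contiguous column-major [batch, M, K] x [batch, K, N] -> [batch, M, N].
    BatchedGemm { batch: usize, m: usize, n: usize, k: usize, a: BufId, b: BufId, c: BufId },
    /// ⊕-reduce the given axes of a contiguous input.
    ReduceSum { axes: Vec<usize>, input: BufId, output: BufId },
    /// Physical memory reordering into a fresh contiguous buffer.
    Permute { perm: Vec<usize>, input: BufId, output: BufId },
    /// Noop if element count is unchanged, else a copy.
    Reshape { new_shape: Vec<usize>, input: BufId, output: BufId },
}
```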

4. Generic execution engine

Interprets low-level IR step by step. Generic over backend trait. No kernel fusion — simple sequential dispatch.

```rust
fn execute_instruction<Alg, B: SemiringCore<Alg>>(inst: &LowLevelInst, ...) {
    match inst {
        LowLevelInst::BatchedGemm { .. } => {
            // Try the optional fast path first
            if let Some(fp) = fast_path {
                if fp.contract(&desc, a, b, c) { return; }
            }
            // Fall back to the required trait method
            backend.batched_gemm(plan, a, b, c);
        }
        LowLevelInst::ReduceSum { axes } => backend.reduce_sum(axes, input, output),
        LowLevelInst::Permute { perm } => { /* physical memory copy */ }
        LowLevelInst::Reshape { .. } => { /* noop or copy */ }
    }
}
```

5. Custom algebra backend: minimal contract

| Required | Optional (fast path) |
| --- | --- |
| batched_gemm() | contract() (direct contraction; skips Transpose + Reshape + BatchedGemm) |
| reduce_sum() | elementwise_mul() (faster than a degenerate K=1 BatchedGemm) |

This matches tenferro v1's TensorSemiringCore + TensorSemiringFastPath design; the v1 implementation should be reused as much as possible.
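A hedged sketch of what the two traits could look like, specialized to f64 payloads for brevity. The signatures and the ContractDesc placeholder are assumptions for illustration; v1's actual TensorSemiringCore/TensorSemiringFastPath differ in detail.

```rust
// Sketch of the minimal backend contract. The f64 payload, signatures, and
// ContractDesc are illustrative assumptions, not the exact v1/v2 API.

/// Hypothetical descriptor for a direct-contraction fast path.
struct ContractDesc;

/// Required contract: two kernels, both over contiguous column-major buffers.
trait SemiringCore<Alg> {
    /// [batch, M, K] x [batch, K, N] -> [batch, M, N], with ⊕ and ⊗ inside.
    fn batched_gemm(&self, batch: usize, m: usize, n: usize, k: usize,
                    a: &[f64], b: &[f64], c: &mut [f64]);
    /// ⊕-reduce the given axes of a contiguous input.
    fn reduce_sum(&self, axes: &[usize], shape: &[usize],
                  input: &[f64], output: &mut [f64]);
}

/// Optional fast paths; returning false means "use the required path instead".
trait SemiringFastPath<Alg>: SemiringCore<Alg> {
    fn contract(&self, _desc: &ContractDesc,
                _a: &[f64], _b: &[f64], _c: &mut [f64]) -> bool {
        false
    }
    fn elementwise_mul(&self, _a: &[f64], _b: &[f64], _c: &mut [f64]) -> bool {
        false
    }
}
```

A backend author implements SemiringCore and gets a working engine; SemiringFastPath is pure opt-in, with the default bodies declining every fast path.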

Why only these two:

  • BatchedGemm handles all contraction (⊕ and ⊗ inside)
  • ReduceSum handles all dimension reduction (⊕ inside)
  • Add (standalone): not needed for einsum pipeline; AD cotangent accumulation doesn't apply to custom algebras
  • Mul (standalone): expressible as degenerate BatchedGemm (K=1); optional fast path for efficiency
  • DotGeneral: decomposed by the compiler into Transpose + Reshape + BatchedGemm
  • Transpose, Reshape, BroadcastInDim: handled by the common infrastructure (algebra-independent)
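A concrete instance of "⊕ and ⊗ inside BatchedGemm": a (max, +) tropical kernel with exactly the same shape contract as the standard one. This is illustrative code, unrelated to the tropical-gemm crate's actual implementation.

```rust
// Tropical (max, +) gemm for a single batch entry: ⊕ is max (identity -inf),
// ⊗ is +. Buffers are contiguous column-major, so element (i, j) of an
// MxN matrix lives at j * m + i. Illustrative sketch only.
fn tropical_gemm(m: usize, n: usize, k: usize, a: &[f64], b: &[f64], c: &mut [f64]) {
    for j in 0..n {
        for i in 0..m {
            let mut acc = f64::NEG_INFINITY; // ⊕-identity
            for p in 0..k {
                acc = acc.max(a[p * m + i] + b[j * k + p]); // ⊗ is +
            }
            c[j * m + i] = acc;
        }
    }
}
```

The loop structure is identical to a standard gemm; only the scalar operations change, which is why one batched_gemm entry point suffices per algebra.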

6. Low-level IR and trait methods need NOT be 1:1

The engine can pattern-match sequences in the low-level IR and dispatch to higher-level backend methods:

  • IR has Transpose + Reshape + BatchedGemm (3 instructions)
  • If SemiringFastPath::contract() is available: dispatch all 3 as a single call
  • Otherwise: execute each instruction sequentially
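This peephole dispatch can be sketched as a window match over the instruction stream (Inst is a simplified stand-in for the low-level IR; the function name is illustrative):

```rust
// Non-1:1 dispatch: match a Transpose + Reshape + BatchedGemm window and
// consume all three with one fused contract() call when the fast path
// exists; otherwise consume one instruction at a time.

#[derive(Clone, Copy, Debug)]
enum Inst { Transpose, Reshape, BatchedGemm, ReduceSum }

/// Returns how many instructions the next dispatch step consumes.
fn dispatch_width(window: &[Inst], has_contract_fast_path: bool) -> usize {
    match window {
        [Inst::Transpose, Inst::Reshape, Inst::BatchedGemm, ..]
            if has_contract_fast_path => 3,
        _ => 1,
    }
}
```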

7. Tensor allows arbitrary strides; contiguous at IR boundary

  • tenferro::Tensor supports arbitrary strides (zero-copy permute, slice, view)
  • At the StableHLO IR entry point, non-contiguous inputs are materialized to contiguous
  • All computation below the IR boundary operates on contiguous column-major buffers
  • eval() output is always contiguous

8. Column-major throughout

tenferro standardizes on column-major (Fortran) layout at all levels:

  • Tensor storage: column-major
  • Low-level IR: column-major contiguous
  • StableHLO IR: layout annotation specifies column-major
  • XLA backend: XLA receives column-major layout specification and handles internal layout assignment (may use row-major internally for GPU, returns column-major output)
  • BLAS/LAPACK/faer: column-major native — no conversion needed
  • tropical-gemm: column-major native
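The column-major convention fixes the stride layout everywhere below the Tensor level: dim 0 strides by 1 and each later dim strides by the product of the preceding extents. A small helper pair (illustrative names) shows the canonical strides and the contiguity test applied at the IR boundary:

```rust
// Canonical column-major (Fortran) strides for a given shape:
// strides[0] = 1, strides[d] = strides[d-1] * shape[d-1].
fn col_major_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    for d in 1..shape.len() {
        strides[d] = strides[d - 1] * shape[d - 1];
    }
    strides
}

/// A view is contiguous column-major iff its strides are the canonical ones;
/// anything else (e.g. a zero-copy permute) is materialized at the boundary.
fn is_col_major_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    strides == col_major_strides(shape).as_slice()
}
```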

Relationship to existing code

| v1 (origin/main) | v2 |
| --- | --- |
| einsum engine's lazy permutation | Optimizing compiler (TransposeFolding) |
| prepare_one_operand + fusability check | SemiringFastPath::contract() |
| SemiringCoreDescriptor | Low-level IR instruction set |
| TensorSemiringCore | SemiringCore trait (required) |
| TensorSemiringFastPath | SemiringFastPath trait (optional) |
| EinsumBackend = Core + FastPath | Generic execution engine |

Evidence from XLA / JAX

  • JAX minimizes Transpose generation by trying both operand orders and using DotGeneral(dimension_numbers) to encode contraction axes without physical transpose
  • XLA has a multi-pass pipeline: DotDimensionSorter → DotDecomposer → TransposeFolding (×3) → DotMerger → AlgebraicSimplifier to eliminate unnecessary transposes
  • XLA's DotDecomposer canonicalizes arbitrary DotGeneral to [batch, M, K] × [batch, K, N] form — exactly BatchedGemm
  • StableHLO's dot_general does NOT include standalone ReduceSum — non-contracting dimensions always appear in the output. ij,jk->i requires dot_general(ij,jk->ik) + reduce(k).

Evidence from existing libraries

  • tropical-gemm: implements only BatchedGemm (pure GEMM library). No reduce, no elementwise ops. This covers the heaviest kernel.
  • omeinsum-rs: N-ary einsum engine. Its contract_binary is reshape→GEMM→reshape. Reduction of non-contracting dims is handled by the einsum engine decomposing into GEMM + unary reduce steps.

Resolves

Resolves the design questions in #20:

  1. Custom semiring minimum contract: batched_gemm + reduce_sum. Structural ops (Transpose, Reshape, BroadcastInDim) are common infrastructure.
  2. Compile-cache identity: tenferro Engine owns a normalized compile-cache layer above computegraph (normalizes away InputKey/DiffPassId). computegraph's GlobalValKey-based cache stays unchanged.
  3. StdTensorOp consistency: StableHLO IR is the single source of truth. StdTensorOp lowers to it via lower_to_stablehlo(). Enum definition and lowering tables must be unified against primitive-catalog.md.
  4. Descriptor + plan/execute as v2 pattern: Yes, via SemiringCore + SemiringFastPath traits (same pattern as v1).
