v2 backend architecture: 2-level IR with StableHLO cut point #21
Summary
Brainstorming results for the v2 backend architecture. This resolves the core design questions from #20 (semiring backend contracts, compile-cache, StdTensorOp consistency) by establishing a clean 2-level IR architecture.
Key decisions
1. StableHLO-compatible IR as the single cut point
All computation (standard and custom algebra) is lowered to a StableHLO-compatible IR (Rust structs, in-process). This is the sole interface boundary:
- XLA backend: receives the StableHLO IR directly. XLA handles all optimization internally.
- faer / custom backends: tenferro's optimizing compiler lowers StableHLO IR → low-level IR → generic execution engine.
```
Surface API → Fragment → CompiledProgram
        ↓
StableHLO IR (Rust struct, complete, standard-compliant)  ← single cut point
        ├─ XLA backend: pass directly to XLA
        └─ faer / custom backends:
             Optimizing compiler → Low-level IR → Generic engine → Backend trait
```
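To make the "Rust structs, in-process" point concrete, here is a hedged sketch of what the StableHLO-compatible IR could look like. All names (`HloOp`, `DotDimensionNumbers`, `matmul_dims`) are illustrative assumptions, not tenferro's actual definitions:

```rust
/// Dimension numbers in the style of stablehlo.dot_general.
#[derive(Debug, Clone, PartialEq)]
pub struct DotDimensionNumbers {
    pub lhs_batch: Vec<usize>,
    pub rhs_batch: Vec<usize>,
    pub lhs_contracting: Vec<usize>,
    pub rhs_contracting: Vec<usize>,
}

/// A few representative ops of the in-process IR (plain Rust enums, no MLIR).
#[derive(Debug, Clone)]
pub enum HloOp {
    DotGeneral { dims: DotDimensionNumbers },
    Transpose { permutation: Vec<usize> },
    Reshape { new_shape: Vec<usize> },
    Reduce { axes: Vec<usize> },
    BroadcastInDim { broadcast_dims: Vec<usize> },
}

/// ij,jk->ik: contract axis 1 of the lhs against axis 0 of the rhs, no batch dims.
pub fn matmul_dims() -> DotDimensionNumbers {
    DotDimensionNumbers {
        lhs_batch: vec![],
        rhs_batch: vec![],
        lhs_contracting: vec![1],
        rhs_contracting: vec![0],
    }
}

fn main() {
    let op = HloOp::DotGeneral { dims: matmul_dims() };
    println!("{op:?}");
}
```

Because the cut point is a plain Rust value, both consumers (XLA serialization and the in-process optimizing compiler) can pattern-match on it directly.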
2. Optimizing compiler (algebra-agnostic)
Transforms StableHLO IR to low-level IR. Key passes (modeled after XLA):
| Pass | Role |
|---|---|
| TransposeFolding | Absorb Transpose into DotGeneral's dimension_numbers |
| DotDecomposer | Canonicalize DotGeneral → Transpose + Reshape + BatchedGemm |
| Contiguous materialization | Insert physical copy where needed |
These passes are algebra-agnostic — they work identically for standard and custom algebras.
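As a toy illustration of the TransposeFolding idea, the pass can remap contracting axes through the transpose permutation instead of materializing a transposed operand. The helper name `fold_lhs_transpose` is hypothetical:

```rust
/// Fold `transpose(perm)` on the LHS into the contracting-axis list:
/// contracting axis `a` of the transposed operand is axis `perm[a]`
/// of the original operand, so the physical Transpose can be dropped.
fn fold_lhs_transpose(perm: &[usize], contracting: &[usize]) -> Vec<usize> {
    contracting.iter().map(|&a| perm[a]).collect()
}

fn main() {
    // "ji,jk->ik" written as transpose(ji) -> ij, then contract axis 1:
    // after folding, we contract axis 0 of the untransposed operand instead.
    let folded = fold_lhs_transpose(&[1, 0], &[1]);
    assert_eq!(folded, vec![0]);
    println!("folded contracting axes: {folded:?}");
}
```

Nothing here touches ⊕ or ⊗, which is why the pass works identically for standard and custom algebras.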
3. Low-level IR: contiguous, column-major
All buffers in the low-level IR are contiguous column-major. No stride tricks at this level.
Key instructions:
- `BatchedGemm(batch, m, n, k)` — contiguous `[batch, M, K] × [batch, K, N]`
- `ReduceSum(axes)` — contiguous input
- `Permute(perm)` / `Copy` — physical memory reordering
- `Reshape` — noop if element count unchanged, else copy
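An illustrative Rust sketch of this instruction set follows. Because every buffer is contiguous column-major, shapes alone determine layout and no strides appear in the IR; the `reshape_is_noop` helper is a hypothetical name for the element-count check mentioned above:

```rust
#[derive(Debug, Clone)]
enum LowLevelInst {
    /// [batch, M, K] × [batch, K, N] → [batch, M, N], all contiguous.
    BatchedGemm { batch: usize, m: usize, n: usize, k: usize },
    /// Reduce the given axes of a contiguous input.
    ReduceSum { axes: Vec<usize> },
    /// Physical memory reordering into a fresh contiguous buffer.
    Permute { perm: Vec<usize> },
    Copy,
    /// No-op if the element count is unchanged, otherwise a copy.
    Reshape { new_shape: Vec<usize> },
}

/// Reshape between two shapes of equal element count is a metadata-only
/// no-op, since the data is already contiguous column-major.
fn reshape_is_noop(old_shape: &[usize], new_shape: &[usize]) -> bool {
    old_shape.iter().product::<usize>() == new_shape.iter().product::<usize>()
}

fn main() {
    let prog = vec![
        LowLevelInst::Permute { perm: vec![1, 0] },
        LowLevelInst::Reshape { new_shape: vec![6] },
        LowLevelInst::BatchedGemm { batch: 1, m: 6, n: 4, k: 5 },
        LowLevelInst::ReduceSum { axes: vec![1] },
    ];
    println!("{} instructions", prog.len());
    assert!(reshape_is_noop(&[2, 3], &[6]));
}
```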
4. Generic execution engine
Interprets low-level IR step by step. Generic over backend trait. No kernel fusion — simple sequential dispatch.
```rust
fn execute_instruction<Alg, B: SemiringCore<Alg>>(inst: &LowLevelInst, ...) {
    match inst {
        BatchedGemm { .. } => {
            // Try the optional fast path first
            if let Some(fp) = fast_path {
                if fp.contract(&desc, a, b, c) { return; }
            }
            // Fall back to the required trait method
            backend.batched_gemm(plan, a, b, c);
        }
        ReduceSum { axes } => backend.reduce_sum(axes, input, output),
        Permute { perm } => { /* physical memory copy */ }
        Reshape { .. } => { /* noop or copy */ }
    }
}
```

5. Custom algebra backend: minimal contract
| Required | Optional (fast path) |
|---|---|
| `batched_gemm()` | `contract()` (direct contraction, skips Transpose+Reshape+BatchedGemm) |
| `reduce_sum()` | `elementwise_mul()` (faster than degenerate K=1 BatchedGemm) |
This matches tenferro v1's TensorSemiringCore + TensorSemiringFastPath design. v1 implementation should be reused as much as possible.
Why only these two:
- `BatchedGemm` handles all contraction (⊕ and ⊗ inside)
- `ReduceSum` handles all dimension reduction (⊕ inside)
- `Add` (standalone): not needed for the einsum pipeline; AD cotangent accumulation doesn't apply to custom algebras
- `Mul` (standalone): expressible as degenerate BatchedGemm (K=1); optional fast path for efficiency
- `DotGeneral`: decomposed by the compiler into Transpose + Reshape + BatchedGemm
- `Transpose`, `Reshape`, `BroadcastInDim`: handled by the common infrastructure (algebra-independent)
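A hedged sketch of the two-trait contract, with a standard (+, ×) backend as reference. The signatures are simplified assumptions (flat `f64` slices, `reduce_sum` specialized to one trailing reduced axis); the real traits would be generic over the algebra:

```rust
/// Required contract: enough to execute the whole einsum pipeline.
trait SemiringCore {
    /// C[b] = A[b] (M×K) times B[b] (K×N), contiguous column-major buffers.
    fn batched_gemm(&self, batch: usize, m: usize, n: usize, k: usize,
                    a: &[f64], b: &[f64], c: &mut [f64]);
    /// Input viewed as [keep, reduce] column-major; sums out the trailing axis.
    fn reduce_sum(&self, reduce: usize, input: &[f64], output: &mut [f64]);
}

/// Optional fast paths; returning false means "not handled, fall back to Core".
trait SemiringFastPath: SemiringCore {
    fn contract(&self, _a: &[f64], _b: &[f64], _c: &mut [f64]) -> bool { false }
    fn elementwise_mul(&self, _a: &[f64], _b: &[f64], _c: &mut [f64]) -> bool { false }
}

/// Standard algebra reference implementation (naive loops, not tuned).
struct StdBackend;

impl SemiringCore for StdBackend {
    fn batched_gemm(&self, batch: usize, m: usize, n: usize, k: usize,
                    a: &[f64], b: &[f64], c: &mut [f64]) {
        for bi in 0..batch {
            let (ao, bo, co) = (bi * m * k, bi * k * n, bi * m * n);
            for j in 0..n {
                for i in 0..m {
                    let mut acc = 0.0;
                    for l in 0..k {
                        // Column-major: A(i,l) = a[i + l*m], B(l,j) = b[l + j*k].
                        acc += a[ao + i + l * m] * b[bo + l + j * k];
                    }
                    c[co + i + j * m] = acc;
                }
            }
        }
    }

    fn reduce_sum(&self, reduce: usize, input: &[f64], output: &mut [f64]) {
        let keep = input.len() / reduce;
        for i in 0..keep {
            output[i] = (0..reduce).map(|r| input[i + r * keep]).sum();
        }
    }
}

impl SemiringFastPath for StdBackend {} // no fast paths: defaults decline

fn main() {
    // 1-batch 2×2 matmul against the identity leaves A unchanged.
    let a = [1.0, 3.0, 2.0, 4.0]; // column-major [[1,2],[3,4]]
    let id = [1.0, 0.0, 0.0, 1.0];
    let mut c = [0.0; 4];
    StdBackend.batched_gemm(1, 2, 2, 2, &a, &id, &mut c);
    assert_eq!(c, a);

    let mut s = [0.0; 2];
    StdBackend.reduce_sum(2, &[1.0, 2.0, 3.0, 4.0], &mut s);
    assert_eq!(s, [4.0, 6.0]); // [1+3, 2+4]
}
```

A tropical backend would implement the same two methods with max for ⊕ and + for ⊗; nothing else in the pipeline changes.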
6. Low-level IR and trait methods need NOT be 1:1
The engine can pattern-match sequences in the low-level IR and dispatch to higher-level backend methods:
- IR has `Transpose + Reshape + BatchedGemm` (3 instructions)
- If `SemiringFastPath::contract()` is available: dispatch all 3 as a single call
- Otherwise: execute each instruction sequentially
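The dispatch decision above can be sketched as a peephole scan over the instruction stream. The `Inst`/`Dispatch` names and `plan_dispatch` helper are illustrative:

```rust
#[derive(Debug, Clone, PartialEq)]
enum Inst { Transpose, Reshape, BatchedGemm, ReduceSum }

#[derive(Debug, PartialEq)]
enum Dispatch {
    /// One SemiringFastPath::contract() call covering three instructions.
    FusedContract,
    /// One required-trait call per instruction.
    Single(Inst),
}

/// Collapse each Transpose+Reshape+BatchedGemm window into a single fused
/// dispatch when the backend's optional fast path is available.
fn plan_dispatch(prog: &[Inst], has_fast_path: bool) -> Vec<Dispatch> {
    let mut out = Vec::new();
    let mut i = 0;
    while i < prog.len() {
        if has_fast_path
            && prog[i..].starts_with(&[Inst::Transpose, Inst::Reshape, Inst::BatchedGemm])
        {
            out.push(Dispatch::FusedContract);
            i += 3;
        } else {
            out.push(Dispatch::Single(prog[i].clone()));
            i += 1;
        }
    }
    out
}

fn main() {
    let prog = [Inst::Transpose, Inst::Reshape, Inst::BatchedGemm, Inst::ReduceSum];
    assert_eq!(
        plan_dispatch(&prog, true),
        vec![Dispatch::FusedContract, Dispatch::Single(Inst::ReduceSum)]
    );
    assert_eq!(plan_dispatch(&prog, false).len(), 4);
}
```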
7. Tensor allows arbitrary strides; contiguous at IR boundary
- `tenferro::Tensor` supports arbitrary strides (zero-copy permute, slice, view)
- At the StableHLO IR entry point, non-contiguous inputs are materialized to contiguous
- All computation below the IR boundary operates on contiguous column-major buffers
- `eval()` output is always contiguous
8. Column-major throughout
tenferro standardizes on column-major (Fortran) layout at all levels:
- `Tensor` storage: column-major
- Low-level IR: column-major contiguous
- StableHLO IR: layout annotation specifies column-major
- XLA backend: XLA receives column-major layout specification and handles internal layout assignment (may use row-major internally for GPU, returns column-major output)
- BLAS/LAPACK/faer: column-major native — no conversion needed
- tropical-gemm: column-major native
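A small self-contained illustration of the column-major convention, including the contiguity check conceptually applied at the IR boundary (helper names are illustrative, not tenferro's):

```rust
/// Strides of a contiguous column-major (Fortran-order) buffer:
/// axis 0 is fastest-varying with stride 1.
fn col_major_strides(shape: &[usize]) -> Vec<usize> {
    let mut s = vec![1usize; shape.len()];
    for i in 1..shape.len() {
        s[i] = s[i - 1] * shape[i - 1];
    }
    s
}

/// A tensor view can enter the IR as-is only if its strides already
/// match the contiguous column-major layout; otherwise it is copied.
fn is_contiguous_col_major(shape: &[usize], strides: &[usize]) -> bool {
    strides == col_major_strides(shape).as_slice()
}

fn main() {
    // A 2×3×4 tensor in Fortran order.
    assert_eq!(col_major_strides(&[2, 3, 4]), vec![1, 2, 6]);
    // A zero-copy permute of the first two axes yields strides [2, 1, 6],
    // which is not contiguous for shape [3, 2, 4]: materialization needed.
    assert!(!is_contiguous_col_major(&[3, 2, 4], &[2, 1, 6]));
}
```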
Relationship to existing code
| v1 (origin/main) | v2 |
|---|---|
| einsum engine's lazy permutation | Optimizing compiler (TransposeFolding) |
| `prepare_one_operand` + fusability check | `SemiringFastPath::contract()` |
| `SemiringCoreDescriptor` | Low-level IR instruction set |
| `TensorSemiringCore` | `SemiringCore` trait (required) |
| `TensorSemiringFastPath` | `SemiringFastPath` trait (optional) |
| `EinsumBackend` = Core + FastPath | Generic execution engine |
Evidence from XLA / JAX
- JAX minimizes Transpose generation by trying both operand orders and using `DotGeneral(dimension_numbers)` to encode contraction axes without a physical transpose
- XLA has a multi-pass pipeline: `DotDimensionSorter → DotDecomposer → TransposeFolding (×3) → DotMerger → AlgebraicSimplifier` to eliminate unnecessary transposes
- XLA's DotDecomposer canonicalizes arbitrary `DotGeneral` to `[batch, M, K] × [batch, K, N]` form — exactly BatchedGemm
- StableHLO's `dot_general` does NOT include a standalone ReduceSum — non-contracting dimensions always appear in the output. `ij,jk->i` requires `dot_general(ij,jk->ik)` + `reduce(k)`.
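The last point can be checked numerically: since the non-contracting dimension k always appears in `dot_general`'s output, `ij,jk->i` takes two steps. A plain-loop sketch (row-indexed arrays for readability, not the column-major runtime layout):

```rust
const A: [[f64; 2]; 2] = [[1.0, 2.0], [3.0, 4.0]]; // ij
const B: [[f64; 2]; 2] = [[5.0, 6.0], [7.0, 8.0]]; // jk

/// Step 1: dot_general(ij,jk->ik), an ordinary matmul over contracting axis j.
fn dot_general_ik(a: &[[f64; 2]; 2], b: &[[f64; 2]; 2]) -> [[f64; 2]; 2] {
    let mut ik = [[0.0; 2]; 2];
    for i in 0..2 {
        for k in 0..2 {
            for j in 0..2 {
                ik[i][k] += a[i][j] * b[j][k];
            }
        }
    }
    ik
}

/// Step 2: reduce(k), summing out the non-contracting k dimension.
fn reduce_k(ik: &[[f64; 2]; 2]) -> [f64; 2] {
    [ik[0][0] + ik[0][1], ik[1][0] + ik[1][1]]
}

fn main() {
    // Reference: compute ij,jk->i directly in one triple loop.
    let mut direct = [0.0; 2];
    for i in 0..2 {
        for j in 0..2 {
            for k in 0..2 {
                direct[i] += A[i][j] * B[j][k];
            }
        }
    }
    // The two-step StableHLO decomposition agrees with the direct result.
    assert_eq!(reduce_k(&dot_general_ik(&A, &B)), direct);
}
```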
Evidence from existing libraries
- tropical-gemm: implements only BatchedGemm (pure GEMM library). No reduce, no elementwise ops. This covers the heaviest kernel.
- omeinsum-rs: N-ary einsum engine. Its `contract_binary` is reshape → GEMM → reshape. Reduction of non-contracting dims is handled by the einsum engine decomposing into GEMM + unary reduce steps.
Resolves
Resolves the design questions in #20:
- Custom semiring minimum contract: `batched_gemm` + `reduce_sum`. Structural ops (Transpose, Reshape, BroadcastInDim) are common infrastructure.
- Compile-cache identity: tenferro Engine owns a normalized compile-cache layer above computegraph (normalizes away `InputKey`/`DiffPassId`). computegraph's `GlobalValKey`-based cache stays unchanged.
- StdTensorOp consistency: StableHLO IR is the single source of truth. `StdTensorOp` lowers to it via `lower_to_stablehlo()`. Enum definition and lowering tables must be unified against primitive-catalog.md.
- Descriptor + plan/execute as v2 pattern: yes, via the `SemiringCore` + `SemiringFastPath` traits (same pattern as v1).