Skip to content

feat: add missing linalg ops with AD (oracle coverage) (#649)#652

Merged
shinaoka merged 5 commits intomainfrom
codex/linalg-ops-649
Apr 7, 2026
Merged

feat: add missing linalg ops with AD (oracle coverage) (#649)#652
shinaoka merged 5 commits intomainfrom
codex/linalg-ops-649

Conversation

@shinaoka
Copy link
Copy Markdown
Member

@shinaoka shinaoka commented Apr 7, 2026

Summary

  • Add Lu and Eig primitives with full backend + compiler + exec + AD pipeline
  • Add Convert op (StableHLO-compatible stablehlo.convert) for dtype conversion with adjoint AD rules (real↔complex are mutual adjoints)
  • Refactor Solve from primitive to LU + triangular_solve composition (matching JAX)
  • Add traced compositions: slogdet, det, inv, pinv, eigvalsh, eigvals, norm
  • Enable oracle replay for all new ops (247 f64 cases passed, 0 failed)

Oracle Results

Op passed notes
lu 20/20 new primitive
solve 24/24 refactored to LU composition
slogdet 9/9 LU composition
det 9/9 slogdet composition
inv 8/8 solve(A, I)
eigvalsh 8/8 eigh composition
eigvals 8/8 eig composition, complex output
pinv 24/24 SVD composition
norm 102/102 Frobenius, p-norms, induced, nuclear
eig 0 passed, 8 skipped eigenvector derivatives not implemented (JAX parity)

Design Decisions

  • JAX alignment: primitive/composition split matches JAX (lu, eig = primitives; solve, slogdet, det, inv, pinv, norm = compositions)
  • Convert op: clean solution for eig's f64→Complex64 dtype promotion. Convert{f64→c64} transpose = Convert{c64→f64} (real-part extraction). Standard StableHLO op, not CustomCall
  • Eig AD: eigenvalue JVP only, no eigenvector derivatives (matches JAX's NotImplementedError)
  • LU outputs: (P, L, U, parity) — 4 tensors, parity is det(P) as ±1.0 scalar

Closes #649

Test plan

  • cargo test --workspace --release — all pass
  • Oracle replay for all 10 new ops — failed: 0
  • Coverage check (cargo llvm-cov)
  • Doc build (cargo doc --workspace --no-deps)

🤖 Generated with Claude Code

shinaoka and others added 4 commits April 7, 2026 09:38
…valsh, eigvals, pinv, norm)

- Add Lu primitive (backend + StdTensorOp + compiler/exec + AD JVP)
- Add Eig primitive (backend + StdTensorOp + compiler/exec + AD JVP eigenvalues only)
- Refactor Solve from primitive to LU + triangular_solve composition
- Add compositions: slogdet, det, inv, eigvalsh, eigvals, pinv, norm
- Add oracle replay for lu, slogdet, det, inv, eigvalsh, solve
- Add ReduceProd/ReduceMax/ReduceMin AD rules
- eig/eigvals/pinv/norm oracle replay partially enabled (needs fixes)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…igvals

- Add complex128 tensor decoding and comparison in oracle replay
- Fix pinv default rtol to match PyTorch, add pinv_with_rtol API
- Fix norm for ord=0, matrix induced norms (1, -1, inf)
- Skip eig/eigvals oracle (needs complex-safe JVP + transpose rule)
- Skip 2 norm ord=-inf matrix cases (subgradient mismatch)

Oracle results: pinv 24/24 passed, norm 100/102 passed (2 skipped)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…eig/norm AD

- Add StdTensorOp::Convert { from, to } with linearize + transpose rules
  (real↔complex are mutual adjoints)
- StableHLO: Convert maps to standard stablehlo.convert (not CustomCall)
- Fix linearize_eig to use Convert instead of Mul-by-complex-one hack
- Enable eig/eigvals oracle replay (eigvals 8/8 passed, eig skips vector derivatives per JAX)
- Fix norm ord=-inf matrix VJP (all 102 norm cases now pass)
- Add CPU convert kernel, backend trait method, compiler/exec wiring
- Add tests: Convert eval/JVP/VJP, eigvals gradient, norm regression

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@shinaoka shinaoka enabled auto-merge (squash) April 7, 2026 03:03
…7% coverage)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@shinaoka shinaoka merged commit 83f18bd into main Apr 7, 2026
5 checks passed
@shinaoka shinaoka deleted the codex/linalg-ops-649 branch April 7, 2026 04:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

feat: add missing linalg ops with AD (oracle coverage)

1 participant